diff --git a/internal/config/config.go b/internal/config/config.go index 78a61ec..f91a783 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -38,30 +38,30 @@ type Config struct { Refer string `yaml:"refer"` } -var Instance Config +var GlobalConfig Config func init() { - flag.Var(&Instance.Addrs, "addr", "监听地址,例如 :7788 。支持多个地址。默认监听 :7788") - flag.BoolVar(&Instance.UseTls, "tls", false, "是否使用tls") - flag.StringVar(&Instance.Cert, "cert", "cert.pem", "tls证书") - flag.StringVar(&Instance.PrivKey, "key", "privkey.pem", "tls私钥") - flag.StringVar(&Instance.LogPath, "log", "/tmp/proxy.log", "日志文件路径") - flag.StringVar(&Instance.WebPath, "content", ".", "文件服务器目录") - flag.Var(&Instance.Users, "user", "Basic认证的用户名密码,例如username:password") - flag.StringVar(&Instance.Refer, "refer", "", "本站的referer特征") + flag.Var(&GlobalConfig.Addrs, "addr", "监听地址,例如 :7788 。支持多个地址。默认监听 :7788") + flag.BoolVar(&GlobalConfig.UseTls, "tls", false, "是否使用tls") + flag.StringVar(&GlobalConfig.Cert, "cert", "cert.pem", "tls证书") + flag.StringVar(&GlobalConfig.PrivKey, "key", "privkey.pem", "tls私钥") + flag.StringVar(&GlobalConfig.LogPath, "log", "/tmp/proxy.log", "日志文件路径") + flag.StringVar(&GlobalConfig.WebPath, "content", ".", "文件服务器目录") + flag.Var(&GlobalConfig.Users, "user", "Basic认证的用户名密码,例如username:password") + flag.StringVar(&GlobalConfig.Refer, "refer", "", "本站的referer特征") flag.Parse() - if len(Instance.Addrs) == 0 { - Instance.Addrs = append(Instance.Addrs, ":7788") + if len(GlobalConfig.Addrs) == 0 { + GlobalConfig.Addrs = append(GlobalConfig.Addrs, ":7788") } - Instance.BasicAuth = make(map[string]string) - for _, user := range Instance.Users { + GlobalConfig.BasicAuth = make(map[string]string) + for _, user := range GlobalConfig.Users { base64Encode := "Basic " + base64.StdEncoding.EncodeToString([]byte(user)) - Instance.BasicAuth[base64Encode] = strings.Split(user, ":")[0] + GlobalConfig.BasicAuth[base64Encode] = strings.Split(user, ":")[0] } initLog() - out, err := yaml.Marshal(Instance) + out, err := yaml.Marshal(GlobalConfig) if err != nil { - log.Println("go web server config:", Instance) + log.Println("go web server config:", GlobalConfig) } else { log.Printf("go web server config: \n%s", string(out)) } @@ -117,7 +117,7 @@ func init() { //} func initLog() { - file := Instance.LogPath + file := GlobalConfig.LogPath rollingFile := &lumberjack.Logger{ Filename: file, MaxSize: 50, diff --git a/internal/server/connectHandlerFunc.go b/internal/server/connectHandlerFunc.go index 0ed36cb..d920f4b 100644 --- a/internal/server/connectHandlerFunc.go +++ b/internal/server/connectHandlerFunc.go @@ -20,10 +20,10 @@ func dialContextCheckACL(network, hostPort string) (net.Conn, error) { func connect(w http.ResponseWriter, r *http.Request) { clientAddr := strings.Split(r.RemoteAddr, ":")[0] - var username string; - if len(config.Instance.BasicAuth) != 0 { - var ok bool; - username, ok = config.Instance.BasicAuth[r.Header.Get("proxy-authorization")] + var username string + if len(config.GlobalConfig.BasicAuth) != 0 { + var ok bool + username, ok = config.GlobalConfig.BasicAuth[r.Header.Get("proxy-authorization")] if !ok { log.Println("wrong basic auth from", clientAddr) http.Error(w, "InternalServerError", http.StatusInternalServerError) @@ -56,7 +56,7 @@ func connect(w http.ResponseWriter, r *http.Request) { switch r.ProtoMajor { case 1: // http1: hijack the whole flow - _, err := serveHijack(w, targetConn, clientAddr, hostPort,username) + _, err := serveHijack(w, targetConn, clientAddr, hostPort, username) if err != nil { log.Println(err, r.RemoteAddr) } @@ -74,7 +74,7 @@ func connect(w http.ResponseWriter, r *http.Request) { w.Header().Add("Server", "go_web_server") } wFlusher.Flush() - err := dualStream(targetConn, r.Body, w, clientAddr, hostPort,username) + err := dualStream(targetConn, r.Body, w, clientAddr, hostPort, username) if err != nil { log.Println(err, r.RemoteAddr) } @@ -86,7 +86,7 @@ func connect(w http.ResponseWriter, r *http.Request) { // Hijacks the connection from ResponseWriter, writes the response and proxies data between targetConn // and hijacked connection. -func serveHijack(w http.ResponseWriter, targetConn net.Conn, clientAddr string, hostPort string,username string) (int, error) { +func serveHijack(w http.ResponseWriter, targetConn net.Conn, clientAddr string, hostPort string, username string) (int, error) { hijacker, ok := w.(http.Hijacker) if !ok { return http.StatusInternalServerError, errors.New("ResponseWriter does not implement Hijacker") @@ -124,21 +124,21 @@ func serveHijack(w http.ResponseWriter, targetConn net.Conn, clientAddr string, return http.StatusInternalServerError, errors.New("failed to send response to client: " + err.Error()) } - return 0, dualStream(targetConn, clientConn, clientConn, clientAddr, hostPort,username) + return 0, dualStream(targetConn, clientConn, clientConn, clientAddr, hostPort, username) } var bufferPool = &sync.Pool{New: func() interface{} { return make([]byte, 32*1024) }} -func dualStream(target net.Conn, clientReader io.ReadCloser, clientWriter io.Writer, clientAddr string, hostPort string,username string) error { +func dualStream(target net.Conn, clientReader io.ReadCloser, clientWriter io.Writer, clientAddr string, hostPort string, username string) error { stream := func(w io.Writer, r io.Reader) error { // copy bytes from r to w buf := bufferPool.Get().([]byte) defer bufferPool.Put(buf) buf = buf[0:cap(buf)] nw, _err := flushingIoCopy(w, r, buf) - ProxyTraffic.WithLabelValues(clientAddr, hostPort,username).Add(float64(nw)) + ProxyTraffic.WithLabelValues(clientAddr, hostPort, username).Add(float64(nw)) if closeWriter, ok := w.(interface { CloseWrite() error }); ok { diff --git a/internal/server/httpHandlerFunc.go b/internal/server/httpHandlerFunc.go index 652bfcd..1fedb58 100644 --- a/internal/server/httpHandlerFunc.go +++ b/internal/server/httpHandlerFunc.go @@ -24,7 +24,7 @@ func writeIp(w http.ResponseWriter, r *http.Request) { func fileHandlerFunc() http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if config.Instance.Refer != "" && r.Header.Get("referer") != "" && !strings.Contains(r.Header.Get("referer"), config.Instance.Refer) && (strings.HasSuffix(r.URL.Path, ".html") || strings.HasSuffix(r.URL.Path, "/")) { + if config.GlobalConfig.Refer != "" && r.Header.Get("referer") != "" && !strings.Contains(r.Header.Get("referer"), config.GlobalConfig.Refer) && (strings.HasSuffix(r.URL.Path, ".html") || strings.HasSuffix(r.URL.Path, "/")) { HttpRequst.WithLabelValues(r.Header.Get("referer"), r.URL.Path).Inc() HttpRequst.WithLabelValues("all", "all").Inc() } @@ -39,7 +39,7 @@ func fileHandlerFunc() http.HandlerFunc { http.Error(w, "invalid URL path", http.StatusBadRequest) return } - fs := http.FileServer(http.Dir(config.Instance.WebPath)) + fs := http.FileServer(http.Dir(config.GlobalConfig.WebPath)) http.StripPrefix("/", fs).ServeHTTP(w, r) }) } diff --git a/internal/server/server.go b/internal/server/server.go index 853a0e8..b992925 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -25,7 +25,7 @@ var ( ProxyTraffic = promauto.NewCounterVec(prometheus.CounterOpts{ Name: "proxy_traffic_total", Help: "num proxy_traffic", - }, []string{"client", "target","username"}) + }, []string{"client", "target", "username"}) ) func Serve() error { @@ -35,7 +35,7 @@ func Serve() error { errors := make(chan error) - instance := config.Instance + instance := config.GlobalConfig handler := MineHandler{} for _, addr := range instance.Addrs { srv := &http.Server{