Skip to content

Commit

Permalink
Refactor global configuration variables***
Browse files Browse the repository at this point in the history
  • Loading branch information
arloor committed Mar 12, 2024
1 parent 1bda99e commit fdfc915
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 31 deletions.
34 changes: 17 additions & 17 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,30 +38,30 @@ type Config struct {
Refer string `yaml:"refer"`
}

var Instance Config
var GlobalConfig Config

func init() {
flag.Var(&Instance.Addrs, "addr", "监听地址,例如 :7788 。支持多个地址。默认监听 :7788")
flag.BoolVar(&Instance.UseTls, "tls", false, "是否使用tls")
flag.StringVar(&Instance.Cert, "cert", "cert.pem", "tls证书")
flag.StringVar(&Instance.PrivKey, "key", "privkey.pem", "tls私钥")
flag.StringVar(&Instance.LogPath, "log", "/tmp/proxy.log", "日志文件路径")
flag.StringVar(&Instance.WebPath, "content", ".", "文件服务器目录")
flag.Var(&Instance.Users, "user", "Basic认证的用户名密码,例如username:password")
flag.StringVar(&Instance.Refer, "refer", "", "本站的referer特征")
flag.Var(&GlobalConfig.Addrs, "addr", "监听地址,例如 :7788 。支持多个地址。默认监听 :7788")
flag.BoolVar(&GlobalConfig.UseTls, "tls", false, "是否使用tls")
flag.StringVar(&GlobalConfig.Cert, "cert", "cert.pem", "tls证书")
flag.StringVar(&GlobalConfig.PrivKey, "key", "privkey.pem", "tls私钥")
flag.StringVar(&GlobalConfig.LogPath, "log", "/tmp/proxy.log", "日志文件路径")
flag.StringVar(&GlobalConfig.WebPath, "content", ".", "文件服务器目录")
flag.Var(&GlobalConfig.Users, "user", "Basic认证的用户名密码,例如username:password")
flag.StringVar(&GlobalConfig.Refer, "refer", "", "本站的referer特征")
flag.Parse()
if len(Instance.Addrs) == 0 {
Instance.Addrs = append(Instance.Addrs, ":7788")
if len(GlobalConfig.Addrs) == 0 {
GlobalConfig.Addrs = append(GlobalConfig.Addrs, ":7788")
}
Instance.BasicAuth = make(map[string]string)
for _, user := range Instance.Users {
GlobalConfig.BasicAuth = make(map[string]string)
for _, user := range GlobalConfig.Users {
base64Encode := "Basic " + base64.StdEncoding.EncodeToString([]byte(user))
Instance.BasicAuth[base64Encode] = strings.Split(user, ":")[0]
GlobalConfig.BasicAuth[base64Encode] = strings.Split(user, ":")[0]
}
initLog()
out, err := yaml.Marshal(Instance)
out, err := yaml.Marshal(GlobalConfig)
if err != nil {
log.Println("go web server config:", Instance)
log.Println("go web server config:", GlobalConfig)
} else {
log.Printf("go web server config: \n%s", string(out))
}
Expand Down Expand Up @@ -117,7 +117,7 @@ func init() {
//}

func initLog() {
file := Instance.LogPath
file := GlobalConfig.LogPath
rollingFile := &lumberjack.Logger{
Filename: file,
MaxSize: 50,
Expand Down
20 changes: 10 additions & 10 deletions internal/server/connectHandlerFunc.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ func dialContextCheckACL(network, hostPort string) (net.Conn, error) {

func connect(w http.ResponseWriter, r *http.Request) {
clientAddr := strings.Split(r.RemoteAddr, ":")[0]
var username string;
if len(config.Instance.BasicAuth) != 0 {
var ok bool;
username, ok = config.Instance.BasicAuth[r.Header.Get("proxy-authorization")]
var username string
if len(config.GlobalConfig.BasicAuth) != 0 {
var ok bool
username, ok = config.GlobalConfig.BasicAuth[r.Header.Get("proxy-authorization")]
if !ok {
log.Println("wrong basic auth from", clientAddr)
http.Error(w, "InternalServerError", http.StatusInternalServerError)
Expand Down Expand Up @@ -56,7 +56,7 @@ func connect(w http.ResponseWriter, r *http.Request) {

switch r.ProtoMajor {
case 1: // http1: hijack the whole flow
_, err := serveHijack(w, targetConn, clientAddr, hostPort,username)
_, err := serveHijack(w, targetConn, clientAddr, hostPort, username)
if err != nil {
log.Println(err, r.RemoteAddr)
}
Expand All @@ -74,7 +74,7 @@ func connect(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Server", "go_web_server")
}
wFlusher.Flush()
err := dualStream(targetConn, r.Body, w, clientAddr, hostPort,username)
err := dualStream(targetConn, r.Body, w, clientAddr, hostPort, username)
if err != nil {
log.Println(err, r.RemoteAddr)
}
Expand All @@ -86,7 +86,7 @@ func connect(w http.ResponseWriter, r *http.Request) {

// Hijacks the connection from ResponseWriter, writes the response and proxies data between targetConn
// and hijacked connection.
func serveHijack(w http.ResponseWriter, targetConn net.Conn, clientAddr string, hostPort string,username string) (int, error) {
func serveHijack(w http.ResponseWriter, targetConn net.Conn, clientAddr string, hostPort string, username string) (int, error) {
hijacker, ok := w.(http.Hijacker)
if !ok {
return http.StatusInternalServerError, errors.New("ResponseWriter does not implement Hijacker")
Expand Down Expand Up @@ -124,21 +124,21 @@ func serveHijack(w http.ResponseWriter, targetConn net.Conn, clientAddr string,
return http.StatusInternalServerError, errors.New("failed to send response to client: " + err.Error())
}

return 0, dualStream(targetConn, clientConn, clientConn, clientAddr, hostPort,username)
return 0, dualStream(targetConn, clientConn, clientConn, clientAddr, hostPort, username)
}

var bufferPool = &sync.Pool{New: func() interface{} {
return make([]byte, 32*1024)
}}

func dualStream(target net.Conn, clientReader io.ReadCloser, clientWriter io.Writer, clientAddr string, hostPort string,username string) error {
func dualStream(target net.Conn, clientReader io.ReadCloser, clientWriter io.Writer, clientAddr string, hostPort string, username string) error {
stream := func(w io.Writer, r io.Reader) error {
// copy bytes from r to w
buf := bufferPool.Get().([]byte)
defer bufferPool.Put(buf)
buf = buf[0:cap(buf)]
nw, _err := flushingIoCopy(w, r, buf)
ProxyTraffic.WithLabelValues(clientAddr, hostPort,username).Add(float64(nw))
ProxyTraffic.WithLabelValues(clientAddr, hostPort, username).Add(float64(nw))
if closeWriter, ok := w.(interface {
CloseWrite() error
}); ok {
Expand Down
4 changes: 2 additions & 2 deletions internal/server/httpHandlerFunc.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func writeIp(w http.ResponseWriter, r *http.Request) {
func fileHandlerFunc() http.HandlerFunc {

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if config.Instance.Refer != "" && r.Header.Get("referer") != "" && !strings.Contains(r.Header.Get("referer"), config.Instance.Refer) && (strings.HasSuffix(r.URL.Path, ".html") || strings.HasSuffix(r.URL.Path, "/")) {
if config.GlobalConfig.Refer != "" && r.Header.Get("referer") != "" && !strings.Contains(r.Header.Get("referer"), config.GlobalConfig.Refer) && (strings.HasSuffix(r.URL.Path, ".html") || strings.HasSuffix(r.URL.Path, "/")) {
HttpRequst.WithLabelValues(r.Header.Get("referer"), r.URL.Path).Inc()
HttpRequst.WithLabelValues("all", "all").Inc()
}
Expand All @@ -39,7 +39,7 @@ func fileHandlerFunc() http.HandlerFunc {
http.Error(w, "invalid URL path", http.StatusBadRequest)
return
}
fs := http.FileServer(http.Dir(config.Instance.WebPath))
fs := http.FileServer(http.Dir(config.GlobalConfig.WebPath))
http.StripPrefix("/", fs).ServeHTTP(w, r)
})
}
Expand Down
4 changes: 2 additions & 2 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ var (
ProxyTraffic = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "proxy_traffic_total",
Help: "num proxy_traffic",
}, []string{"client", "target","username"})
}, []string{"client", "target", "username"})
)

func Serve() error {
Expand All @@ -35,7 +35,7 @@ func Serve() error {

errors := make(chan error)

instance := config.Instance
instance := config.GlobalConfig
handler := MineHandler{}
for _, addr := range instance.Addrs {
srv := &http.Server{
Expand Down

0 comments on commit fdfc915

Please sign in to comment.