diff --git a/.vscode/launch.json b/.vscode/launch.json index 8d048f9..cf66223 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -13,6 +13,8 @@ "program": "${workspaceFolder}/cmd/${workspaceFolderBasename}", "args": [ "--addr=localhost:7788", + "--addr=localhost:3333", + "--user=username:password", "--refer=arloor" ,"--tls=true" ] diff --git a/internal/config/config.go b/internal/config/config.go index ad0cc1c..651a6f5 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,7 +1,9 @@ package config import ( + "encoding/base64" "flag" + "fmt" "gopkg.in/natefinch/lumberjack.v2" "gopkg.in/yaml.v2" "io" @@ -9,29 +11,48 @@ import ( "os" ) +type stringArray []string + +func (i *stringArray) String() string { + return fmt.Sprint(*i) +} + +// Set 方法是flag.Value接口, 设置flag Value的方法. +// 通过多个flag指定的值, 所以我们追加到最终的数组上. +func (i *stringArray) Set(value string) error { + *i = append(*i, value) + return nil +} + type Config struct { - Addr string `yaml:"addr"` - UseTls bool `yaml:"tls"` - Cert string `yaml:"cert"` - PrivKey string `yaml:"key"` - LogPath string `yaml:"log"` - WebPath string `yaml:"content"` - BasicAuth string `yaml:"auth"` - Refer string `yaml:"refer"` + Addrs stringArray `yaml:"addrs"` + UseTls bool `yaml:"tls"` + Cert string `yaml:"cert"` + PrivKey string `yaml:"key"` + LogPath string `yaml:"log"` + WebPath string `yaml:"content"` + Users stringArray `yaml:"users"` + BasicAuth map[string]bool `yaml:"auth"` + Refer string `yaml:"refer"` } var Instance Config func init() { - flag.StringVar(&Instance.Addr, "addr", ":7788", "监听地址") + flag.Var(&Instance.Addrs, "addr", ":7788") flag.BoolVar(&Instance.UseTls, "tls", false, "是否使用tls") flag.StringVar(&Instance.Cert, "cert", "cert.pem", "tls证书") flag.StringVar(&Instance.PrivKey, "key", "privkey.pem", "tls私钥") flag.StringVar(&Instance.LogPath, "log", "/tmp/proxy.log", "日志文件路径") flag.StringVar(&Instance.WebPath, "content", ".", "文件服务器目录") - flag.StringVar(&Instance.BasicAuth, "auth", "", "Basic Auth Header") + flag.Var(&Instance.Users, "user", "") flag.StringVar(&Instance.Refer, "refer", "", "本站的referer特征") flag.Parse() + Instance.BasicAuth = make(map[string]bool) + for _, user := range Instance.Users { + base64Encode := "Basic " + base64.StdEncoding.EncodeToString([]byte(user)) + Instance.BasicAuth[base64Encode] = true + } initLog() out, err := yaml.Marshal(Instance) if err != nil { diff --git a/internal/server/connectHandlerFunc.go b/internal/server/connectHandlerFunc.go index b34a585..812c4af 100644 --- a/internal/server/connectHandlerFunc.go +++ b/internal/server/connectHandlerFunc.go @@ -20,10 +20,12 @@ func dialContextCheckACL(network, hostPort string) (net.Conn, error) { func connect(w http.ResponseWriter, r *http.Request) { clientAddr := strings.Split(r.RemoteAddr, ":")[0] - if r.Header.Get("proxy-authorization") != config.Instance.BasicAuth { - log.Println("wrong basic auth from", clientAddr) - http.Error(w, "InternalServerError", http.StatusInternalServerError) - return + if len(config.Instance.BasicAuth) != 0 { + if _, ok := config.Instance.BasicAuth[r.Header.Get("proxy-authorization")]; !ok { + log.Println("wrong basic auth from", clientAddr) + http.Error(w, "InternalServerError", http.StatusInternalServerError) + return + } } if r.ProtoMajor == 2 { if len(r.URL.Scheme) > 0 || len(r.URL.Path) > 0 { diff --git a/internal/server/server.go b/internal/server/server.go index de4402d..0f728de 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -33,43 +33,50 @@ func Serve() error { http.Handle("/metrics", promhttp.Handler()) http.HandleFunc("/", fileHandlerFunc()) + errors := make(chan error) + instance := config.Instance handler := MineHandler{} - srv := &http.Server{ - Addr: instance.Addr, - Handler: handler, - IdleTimeout: 30 * time.Second, - ReadHeaderTimeout: 30 * time.Second, - ReadTimeout: 30 * time.Second, - WriteTimeout: 30 * time.Second, // Set idle timeout - TLSConfig: &tls.Config{ - GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { - // Always get latest localhost.crt and localhost.key - // ex: keeping certificates file somewhere in global location where created certificates updated and this closure function can refer that - now := time.Now() - if ssl_cert == nil || now.Sub(ssl_last_cert_update) > ssl_cert_update_interval { - cert, err := tls.LoadX509KeyPair(instance.Cert, instance.PrivKey) - if err != nil { - log.Println("Error loading certificate", err) - if ssl_cert != nil { - return ssl_cert, nil + for _, addr := range instance.Addrs { + srv := &http.Server{ + Addr: addr, + Handler: handler, + IdleTimeout: 30 * time.Second, + ReadHeaderTimeout: 30 * time.Second, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, // Set idle timeout + TLSConfig: &tls.Config{ + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + // Always get latest localhost.crt and localhost.key + // ex: keeping certificates file somewhere in global location where created certificates updated and this closure function can refer that + now := time.Now() + if ssl_cert == nil || now.Sub(ssl_last_cert_update) > ssl_cert_update_interval { + cert, err := tls.LoadX509KeyPair(instance.Cert, instance.PrivKey) + if err != nil { + log.Println("Error loading certificate", err) + if ssl_cert != nil { + return ssl_cert, nil + } + return nil, err + } else { + log.Println("Loaded certificate", instance.Cert, instance.PrivKey) } - return nil, err + ssl_cert = &cert + ssl_last_cert_update = now + return &cert, nil } else { - log.Println("Loaded certificate", instance.Cert, instance.PrivKey) + return ssl_cert, nil } - ssl_cert = &cert - ssl_last_cert_update = now - return &cert, nil - } else { - return ssl_cert, nil - } + }, }, - }, - } - if !instance.UseTls { - return srv.ListenAndServe() - } else { - return srv.ListenAndServeTLS("", "") + } + go func() { + if !instance.UseTls { + errors <- srv.ListenAndServe() + } else { + errors <- srv.ListenAndServeTLS("", "") + } + }() } + return <-errors }