Skip to content

Commit

Permalink
Update launch.json and config.go files
Browse files Browse the repository at this point in the history
  • Loading branch information
arloor committed Mar 11, 2024
1 parent 16272ac commit 4a619a1
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 46 deletions.
2 changes: 2 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
"program": "${workspaceFolder}/cmd/${workspaceFolderBasename}",
"args": [
"--addr=localhost:7788",
"--addr=localhost:3333",
"--user=username:password",
"--refer=arloor"
,"--tls=true"
]
Expand Down
41 changes: 31 additions & 10 deletions internal/config/config.go
Original file line number Diff line number Diff line change
@@ -1,37 +1,58 @@
package config

import (
"encoding/base64"
"flag"
"fmt"
"gopkg.in/natefinch/lumberjack.v2"
"gopkg.in/yaml.v2"
"io"
"log"
"os"
)

type stringArray []string

func (i *stringArray) String() string {
return fmt.Sprint(*i)
}

// Set 方法是flag.Value接口, 设置flag Value的方法.
// 通过多个flag指定的值, 所以我们追加到最终的数组上.
func (i *stringArray) Set(value string) error {
*i = append(*i, value)
return nil
}

type Config struct {
Addr string `yaml:"addr"`
UseTls bool `yaml:"tls"`
Cert string `yaml:"cert"`
PrivKey string `yaml:"key"`
LogPath string `yaml:"log"`
WebPath string `yaml:"content"`
BasicAuth string `yaml:"auth"`
Refer string `yaml:"refer"`
Addrs stringArray `yaml:"addrs"`
UseTls bool `yaml:"tls"`
Cert string `yaml:"cert"`
PrivKey string `yaml:"key"`
LogPath string `yaml:"log"`
WebPath string `yaml:"content"`
Users stringArray `yaml:"users"`
BasicAuth map[string]bool `yaml:"auth"`
Refer string `yaml:"refer"`
}

var Instance Config

func init() {
flag.StringVar(&Instance.Addr, "addr", ":7788", "监听地址")
flag.Var(&Instance.Addrs, "addr", ":7788")
flag.BoolVar(&Instance.UseTls, "tls", false, "是否使用tls")
flag.StringVar(&Instance.Cert, "cert", "cert.pem", "tls证书")
flag.StringVar(&Instance.PrivKey, "key", "privkey.pem", "tls私钥")
flag.StringVar(&Instance.LogPath, "log", "/tmp/proxy.log", "日志文件路径")
flag.StringVar(&Instance.WebPath, "content", ".", "文件服务器目录")
flag.StringVar(&Instance.BasicAuth, "auth", "", "Basic Auth Header")
flag.Var(&Instance.Users, "user", "")
flag.StringVar(&Instance.Refer, "refer", "", "本站的referer特征")
flag.Parse()
Instance.BasicAuth = make(map[string]bool)
for _, user := range Instance.Users {
base64Encode := "Basic " + base64.StdEncoding.EncodeToString([]byte(user))
Instance.BasicAuth[base64Encode] = true
}
initLog()
out, err := yaml.Marshal(Instance)
if err != nil {
Expand Down
10 changes: 6 additions & 4 deletions internal/server/connectHandlerFunc.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ func dialContextCheckACL(network, hostPort string) (net.Conn, error) {

func connect(w http.ResponseWriter, r *http.Request) {
clientAddr := strings.Split(r.RemoteAddr, ":")[0]
if r.Header.Get("proxy-authorization") != config.Instance.BasicAuth {
log.Println("wrong basic auth from", clientAddr)
http.Error(w, "InternalServerError", http.StatusInternalServerError)
return
if len(config.Instance.BasicAuth) != 0 {
if _, ok := config.Instance.BasicAuth[r.Header.Get("proxy-authorization")]; !ok {
log.Println("wrong basic auth from", clientAddr)
http.Error(w, "InternalServerError", http.StatusInternalServerError)
return
}
}
if r.ProtoMajor == 2 {
if len(r.URL.Scheme) > 0 || len(r.URL.Path) > 0 {
Expand Down
71 changes: 39 additions & 32 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,43 +33,50 @@ func Serve() error {
http.Handle("/metrics", promhttp.Handler())
http.HandleFunc("/", fileHandlerFunc())

errors := make(chan error)

instance := config.Instance
handler := MineHandler{}
srv := &http.Server{
Addr: instance.Addr,
Handler: handler,
IdleTimeout: 30 * time.Second,
ReadHeaderTimeout: 30 * time.Second,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second, // Set idle timeout
TLSConfig: &tls.Config{
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
// Always get latest localhost.crt and localhost.key
// ex: keeping certificates file somewhere in global location where created certificates updated and this closure function can refer that
now := time.Now()
if ssl_cert == nil || now.Sub(ssl_last_cert_update) > ssl_cert_update_interval {
cert, err := tls.LoadX509KeyPair(instance.Cert, instance.PrivKey)
if err != nil {
log.Println("Error loading certificate", err)
if ssl_cert != nil {
return ssl_cert, nil
for _, addr := range instance.Addrs {
srv := &http.Server{
Addr: addr,
Handler: handler,
IdleTimeout: 30 * time.Second,
ReadHeaderTimeout: 30 * time.Second,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second, // Set idle timeout
TLSConfig: &tls.Config{
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
// Always get latest localhost.crt and localhost.key
// ex: keeping certificates file somewhere in global location where created certificates updated and this closure function can refer that
now := time.Now()
if ssl_cert == nil || now.Sub(ssl_last_cert_update) > ssl_cert_update_interval {
cert, err := tls.LoadX509KeyPair(instance.Cert, instance.PrivKey)
if err != nil {
log.Println("Error loading certificate", err)
if ssl_cert != nil {
return ssl_cert, nil
}
return nil, err
} else {
log.Println("Loaded certificate", instance.Cert, instance.PrivKey)
}
return nil, err
ssl_cert = &cert
ssl_last_cert_update = now
return &cert, nil
} else {
log.Println("Loaded certificate", instance.Cert, instance.PrivKey)
return ssl_cert, nil
}
ssl_cert = &cert
ssl_last_cert_update = now
return &cert, nil
} else {
return ssl_cert, nil
}
},
},
},
}
if !instance.UseTls {
return srv.ListenAndServe()
} else {
return srv.ListenAndServeTLS("", "")
}
go func() {
if !instance.UseTls {
errors <- srv.ListenAndServe()
} else {
errors <- srv.ListenAndServeTLS("", "")
}
}()
}
return <-errors
}

0 comments on commit 4a619a1

Please sign in to comment.