diff --git a/endless.go b/endless.go index 45090a5..2dd507e 100644 --- a/endless.go +++ b/endless.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "flag" "fmt" + "io" "log" "net" "net/http" @@ -60,8 +61,18 @@ func init() { DefaultHammerTime = 60 * time.Second } +type Server interface { + Serve(l net.Listener) error +} + +type CloseFunc func() error + +func (fn CloseFunc) Close() error { return fn() } + type endlessServer struct { - http.Server + Server + io.Closer + Addr string EndlessListener net.Listener SignalHooks map[int]map[os.Signal][]func() tlsInnerListener *endlessListener @@ -71,11 +82,7 @@ type endlessServer struct { state uint8 } -/* -NewServer returns an intialized endlessServer Object. Calling Serve on it will -actually "start" the server. -*/ -func NewServer(addr string, handler http.Handler) (srv *endlessServer) { +func newEndlessServer(addr string, server Server, closer io.Closer) (srv *endlessServer) { runningServerReg.Lock() defer runningServerReg.Unlock() if !flag.Parsed() { @@ -90,6 +97,9 @@ func NewServer(addr string, handler http.Handler) (srv *endlessServer) { } srv = &endlessServer{ + Server: server, + Closer: closer, + Addr: addr, wg: sync.WaitGroup{}, sigChan: make(chan os.Signal), isChild: isChild, @@ -114,18 +124,33 @@ func NewServer(addr string, handler http.Handler) (srv *endlessServer) { state: STATE_INIT, } - srv.Server.Addr = addr - srv.Server.ReadTimeout = DefaultReadTimeOut - srv.Server.WriteTimeout = DefaultWriteTimeOut - srv.Server.MaxHeaderBytes = DefaultMaxHeaderBytes - srv.Server.Handler = handler - runningServersOrder = append(runningServersOrder, addr) runningServers[addr] = srv return } +/* +NewServer returns an intialized endlessServer Object. Calling Serve on it will +actually "start" the server. +*/ +func NewServer(addr string, handler http.Handler) *endlessServer { + srv := &http.Server{ + Addr: addr, + ReadTimeout: DefaultReadTimeOut, + WriteTimeout: DefaultWriteTimeOut, + MaxHeaderBytes: DefaultMaxHeaderBytes, + Handler: handler, + } + + return newEndlessServer(addr, srv, CloseFunc(func() error { + // disable keep-alives on existing connections + srv.SetKeepAlivesEnabled(false) + + return nil + })) +} + /* ListenAndServe listens on the TCP network address addr and then calls Serve with handler to handle requests on incoming connections. Handler is typically @@ -148,6 +173,79 @@ func ListenAndServeTLS(addr string, certFile string, keyFile string, handler htt return server.ListenAndServeTLS(certFile, keyFile) } +type Handler interface { + Serve(net.Conn) +} + +type HandleFunc func(net.Conn) + +func (fn HandleFunc) Serve(conn net.Conn) { fn(conn) } + +func ListenAndServeTCP(addr string, handler Handler) error { + server := NewTcpServer(addr, handler) + return server.ListenAndServe() +} + +type tcpServer struct { + handler Handler +} + +func NewTcpServer(addr string, handler Handler) *endlessServer { + return newEndlessServer(addr, &tcpServer{handler}, CloseFunc(func() error { return nil })) +} + +func (srv *tcpServer) Serve(l net.Listener) error { + defer l.Close() + var tempDelay time.Duration // how long to sleep on accept failure + for { + rw, e := l.Accept() + if e != nil { + if ne, ok := e.(net.Error); ok && ne.Temporary() { + if tempDelay == 0 { + tempDelay = 5 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 1 * time.Second; tempDelay > max { + tempDelay = max + } + log.Printf("%d %s tcp: Accept error: %v; retrying in %v", os.Getpid(), l.(*endlessListener).Addr(), e, tempDelay) + time.Sleep(tempDelay) + continue + } + return e + } + tempDelay = 0 + + go srv.handler.Serve(rw) + } +} + +type tlsServer struct { + tcpServer + + TLSConfig *tls.Config +} + +func NewTlsServer(addr string, handler Handler, config *tls.Config) *endlessServer { + return newEndlessServer(addr, &tlsServer{tcpServer{handler}, config}, CloseFunc(func() error { return nil })) +} + +func (srv *tlsServer) Serve(l net.Listener) error { + return srv.tcpServer.Serve(tls.NewListener(l, srv.TLSConfig)) +} + +func (srv *endlessServer) TLSConfig() *tls.Config { + switch t := srv.Server.(type) { + case *http.Server: + return t.TLSConfig + case *tlsServer: + return t.TLSConfig + } + + return nil +} + /* Serve accepts incoming HTTP connections on the listener l, creating a new service goroutine for each. The service goroutines read requests and then call @@ -159,18 +257,18 @@ sync.Waitgroup so that all outstanding connections can be served before shutting down the server. */ func (srv *endlessServer) Serve() (err error) { - defer log.Println(syscall.Getpid(), "Serve() returning...") + defer log.Println(syscall.Getpid(), srv.EndlessListener.Addr(), "Serve() returning...") srv.state = STATE_RUNNING err = srv.Server.Serve(srv.EndlessListener) - log.Println(syscall.Getpid(), "Waiting for connections to finish...") + log.Println(syscall.Getpid(), srv.EndlessListener.Addr(), "Waiting for connections to finish...") srv.wg.Wait() srv.state = STATE_TERMINATE return } /* -ListenAndServe listens on the TCP network address srv.Addr and then calls Serve -to handle requests on incoming connections. If srv.Addr is blank, ":http" is +ListenAndServe listens on the TCP network address srv.EndlessListener.Addr() and then calls Serve +to handle requests on incoming connections. If srv.EndlessListener.Addr() is blank, ":http" is used. */ func (srv *endlessServer) ListenAndServe() (err error) { @@ -193,12 +291,12 @@ func (srv *endlessServer) ListenAndServe() (err error) { syscall.Kill(syscall.Getppid(), syscall.SIGTERM) } - log.Println(syscall.Getpid(), srv.Addr) + log.Println(syscall.Getpid(), srv.EndlessListener.Addr(), "ListenAndServe") return srv.Serve() } /* -ListenAndServeTLS listens on the TCP network address srv.Addr and then calls +ListenAndServeTLS listens on the TCP network address srv.EndlessListener.Addr() and then calls Serve to handle requests on incoming TLS connections. Filenames containing a certificate and matching private key for the server must @@ -206,7 +304,7 @@ be provided. If the certificate is signed by a certificate authority, the certFile should be the concatenation of the server's certificate followed by the CA's certificate. -If srv.Addr is blank, ":https" is used. +If srv.EndlessListener.Addr() is blank, ":https" is used. */ func (srv *endlessServer) ListenAndServeTLS(certFile, keyFile string) (err error) { addr := srv.Addr @@ -215,9 +313,11 @@ func (srv *endlessServer) ListenAndServeTLS(certFile, keyFile string) (err error } config := &tls.Config{} - if srv.TLSConfig != nil { - *config = *srv.TLSConfig + + if srv.TLSConfig() != nil { + *config = *srv.TLSConfig() } + if config.NextProtos == nil { config.NextProtos = []string{"http/1.1"} } @@ -243,7 +343,7 @@ func (srv *endlessServer) ListenAndServeTLS(certFile, keyFile string) (err error syscall.Kill(syscall.Getppid(), syscall.SIGTERM) } - log.Println(syscall.Getpid(), srv.Addr) + log.Println(syscall.Getpid(), srv.EndlessListener.Addr(), "ListenAndServeTLS") return srv.Serve() } @@ -298,26 +398,26 @@ func (srv *endlessServer) handleSignals() { srv.signalHooks(PRE_SIGNAL, sig) switch sig { case syscall.SIGHUP: - log.Println(pid, "Received SIGHUP. forking.") + log.Println(pid, srv.EndlessListener.Addr(), "Received SIGHUP. forking.") err := srv.fork() if err != nil { log.Println("Fork err:", err) } case syscall.SIGUSR1: - log.Println(pid, "Received SIGUSR1.") + log.Println(pid, srv.EndlessListener.Addr(), "Received SIGUSR1.") case syscall.SIGUSR2: - log.Println(pid, "Received SIGUSR2.") + log.Println(pid, srv.EndlessListener.Addr(), "Received SIGUSR2.") srv.hammerTime(0 * time.Second) case syscall.SIGINT: - log.Println(pid, "Received SIGINT.") + log.Println(pid, srv.EndlessListener.Addr(), "Received SIGINT.") srv.shutdown() case syscall.SIGTERM: - log.Println(pid, "Received SIGTERM.") + log.Println(pid, srv.EndlessListener.Addr(), "Received SIGTERM.") srv.shutdown() case syscall.SIGTSTP: - log.Println(pid, "Received SIGTSTP.") + log.Println(pid, srv.EndlessListener.Addr(), "Received SIGTSTP.") default: - log.Printf("Received %v: nothing i care about...\n", sig) + log.Printf("%d %s Received %v: nothing i care about...\n", pid, srv.EndlessListener.Addr(), sig) } srv.signalHooks(POST_SIGNAL, sig) } @@ -347,11 +447,12 @@ func (srv *endlessServer) shutdown() { if DefaultHammerTime >= 0 { go srv.hammerTime(DefaultHammerTime) } - // disable keep-alives on existing connections - srv.SetKeepAlivesEnabled(false) + + srv.Close() + err := srv.EndlessListener.Close() if err != nil { - log.Println(syscall.Getpid(), "Listener.Close() error:", err) + log.Println(syscall.Getpid(), srv.EndlessListener.Addr(), "Listener.Close() error:", err) } else { log.Println(syscall.Getpid(), srv.EndlessListener.Addr(), "Listener closed.") } @@ -405,12 +506,12 @@ func (srv *endlessServer) fork() (err error) { switch srvPtr.EndlessListener.(type) { case *endlessListener: // normal listener - files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.EndlessListener.(*endlessListener).File() + files[socketPtrOffsetMap[srvPtr.Addr]] = srvPtr.EndlessListener.(*endlessListener).File() default: // tls listener - files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File() + files[socketPtrOffsetMap[srvPtr.Addr]] = srvPtr.tlsInnerListener.File() } - orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr + orderArgs[socketPtrOffsetMap[srvPtr.Addr]] = srvPtr.Addr } // log.Println(files) diff --git a/examples/echo.go b/examples/echo.go new file mode 100644 index 0000000..08a4636 --- /dev/null +++ b/examples/echo.go @@ -0,0 +1,67 @@ +package main + +import ( + "io" + "io/ioutil" + "log" + "net" + "net/http" + "os" + "sync" + + "github.com/flier/endless" + "github.com/gorilla/mux" +) + +func handler(w http.ResponseWriter, r *http.Request) { + buf, _ := ioutil.ReadAll(r.Body) + + w.Write(buf) +} + +func main() { + var wg sync.WaitGroup + + wg.Add(2) + + go func() { + defer wg.Done() + + endless.ListenAndServeTCP("localhost:8007", endless.HandleFunc(func(conn net.Conn) { + defer conn.Close() + + var buf [4096]byte + + for { + if n, err := conn.Read(buf[:]); err != nil { + if err != io.EOF { + log.Printf("error, %s", err) + } + + break + } else if _, err := conn.Write(buf[:n]); err != nil { + log.Printf("error, %s", err) + + break + } + } + })) + }() + + go func() { + defer wg.Done() + + mux1 := mux.NewRouter() + mux1.HandleFunc("/", handler).Methods("POST") + + if err := endless.ListenAndServe("localhost:8008", mux1); err != nil { + log.Println(err) + } else { + log.Println("Server on 8007 stopped") + } + }() + + wg.Wait() + + os.Exit(0) +}