From 87cf624af785fa7c2f8e6dd5ba0d26b1e337da6d Mon Sep 17 00:00:00 2001 From: Janos Guljas Date: Sun, 29 Jul 2018 15:56:24 +0200 Subject: [PATCH] servers: add support for UDP servers and implement QUIC server --- servers/grpc/grpc.go | 12 ++++ servers/grpc/grpc_test.go | 6 +- servers/http/http.go | 11 +++- servers/http/http_test.go | 4 +- servers/quic/quic.go | 78 +++++++++++++++++++++++++ servers/servers.go | 120 ++++++++++++++++++++++++++++++-------- servers/servers_test.go | 10 ++-- 7 files changed, 204 insertions(+), 37 deletions(-) create mode 100644 servers/quic/quic.go diff --git a/servers/grpc/grpc.go b/servers/grpc/grpc.go index 1115b89..12268c9 100644 --- a/servers/grpc/grpc.go +++ b/servers/grpc/grpc.go @@ -7,8 +7,15 @@ package grpcServer import ( "context" + "net" "google.golang.org/grpc" + "resenje.org/web/servers" +) + +var ( + _ servers.Server = new(Server) + _ servers.TCPServer = new(Server) ) // Server wraps grpc.Server to provide methods for @@ -24,6 +31,11 @@ func New(server *grpc.Server) (s *Server) { } } +// ServeTCP serves request on TCP listener. +func (s *Server) ServeTCP(ln net.Listener) (err error) { + return s.Server.Serve(ln) +} + // Close executes grpc.Server.Stop method. func (s *Server) Close() (err error) { s.Server.Stop() diff --git a/servers/grpc/grpc_test.go b/servers/grpc/grpc_test.go index 401b6c6..cb225d6 100644 --- a/servers/grpc/grpc_test.go +++ b/servers/grpc/grpc_test.go @@ -36,7 +36,7 @@ func TestServer(t *testing.T) { addr := "localhost:" + strconv.Itoa(ln.Addr().(*net.TCPAddr).Port) go func() { - if err := s.Serve(ln); err != nil { + if err := s.ServeTCP(ln); err != nil { panic(err) } }() @@ -75,7 +75,7 @@ func TestServerShutdown(t *testing.T) { addr := "localhost:" + strconv.Itoa(ln.Addr().(*net.TCPAddr).Port) go func() { - if err := s.Serve(ln); err != nil { + if err := s.ServeTCP(ln); err != nil { if e, ok := err.(*net.OpError); !(ok && e.Op == "accept") { panic(err) } @@ -123,7 +123,7 @@ func TestServerClose(t *testing.T) { addr := "localhost:" + strconv.Itoa(ln.Addr().(*net.TCPAddr).Port) go func() { - if err := s.Serve(ln); err != nil { + if err := s.ServeTCP(ln); err != nil { if e, ok := err.(*net.OpError); !(ok && e.Op == "accept") { panic(err) } diff --git a/servers/http/http.go b/servers/http/http.go index 2b77adb..d82d1f1 100644 --- a/servers/http/http.go +++ b/servers/http/http.go @@ -10,6 +10,13 @@ import ( "net" "net/http" "time" + + "resenje.org/web/servers" +) + +var ( + _ servers.Server = new(Server) + _ servers.TCPServer = new(Server) ) // Options struct holds parameters that can be configure using @@ -47,11 +54,11 @@ func New(handler http.Handler, opts ...Option) (s *Server) { return } -// Serve executes http.Server.Serve method. +// ServeTCP executes http.Server.Serve method. // If the provided listener is net.TCPListener, keep alive // will be enabled. If server is configured with TLS, // a tls.Listener will be created with provided listener. -func (s *Server) Serve(ln net.Listener) (err error) { +func (s *Server) ServeTCP(ln net.Listener) (err error) { if l, ok := ln.(*net.TCPListener); ok { ln = tcpKeepAliveListener{TCPListener: l} } diff --git a/servers/http/http_test.go b/servers/http/http_test.go index 5bb1351..d28447a 100644 --- a/servers/http/http_test.go +++ b/servers/http/http_test.go @@ -31,7 +31,7 @@ func TestServer(t *testing.T) { addr := "http://localhost:" + strconv.Itoa(ln.Addr().(*net.TCPAddr).Port) go func() { - if err := s.Serve(ln); err != nil && err != http.ErrServerClosed { + if err := s.ServeTCP(ln); err != nil && err != http.ErrServerClosed { panic(err) } }() @@ -102,7 +102,7 @@ viBngkOY/zwTS9mYvM8ixsj16b2WWzajtjhBtihs+tur addr := "https://localhost:" + strconv.Itoa(ln.Addr().(*net.TCPAddr).Port) go func() { - if err := s.Serve(ln); err != nil && err != http.ErrServerClosed { + if err := s.ServeTCP(ln); err != nil && err != http.ErrServerClosed { panic(err) } }() diff --git a/servers/quic/quic.go b/servers/quic/quic.go new file mode 100644 index 0000000..3b86857 --- /dev/null +++ b/servers/quic/quic.go @@ -0,0 +1,78 @@ +// Copyright (c) 2018, Janoš Guljaš +// All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quicServer + +import ( + "context" + "crypto/tls" + "net" + "net/http" + + "github.com/lucas-clemente/quic-go/h2quic" + "resenje.org/web/servers" +) + +var ( + _ servers.Server = new(Server) + _ servers.UDPServer = new(Server) +) + +// Options struct holds parameters that can be configure using +// functions with prefix With. +type Options struct { + tlsConfig *tls.Config +} + +// Option is a function that sets optional parameters for +// the Server. +type Option func(*Options) + +// WithTLSConfig sets a TLS configuration for the HTTP server +// and creates a TLS listener. +func WithTLSConfig(tlsConfig *tls.Config) Option { return func(o *Options) { o.tlsConfig = tlsConfig } } + +// Server wraps h2quic.Server to provide methods for +// resenje.org/web/servers.Server interface. +type Server struct { + *h2quic.Server +} + +// New creates a new instance of Server. +func New(handler http.Handler, opts ...Option) (s *Server) { + o := &Options{} + for _, opt := range opts { + opt(o) + } + s = &Server{ + Server: &h2quic.Server{ + Server: &http.Server{ + Handler: handler, + TLSConfig: o.tlsConfig, + }, + }, + } + return +} + +// ServeUDP serves requests over UDP connection. +func (s *Server) ServeUDP(conn *net.UDPConn) (err error) { + s.Server.Server.Addr = conn.LocalAddr().String() + return s.Server.Serve(conn) +} + +// Shutdown calls h2quic.Server.Close method. +func (s *Server) Shutdown(_ context.Context) (err error) { + return s.Server.Close() +} + +// QuicHeadersHandler should be used as a middleware to set +// quic related headers to TCP server that suggest alternative svc. +func (s *Server) QuicHeadersHandler(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.SetQuicHeaders(w.Header()) + h.ServeHTTP(w, r) + }) +} diff --git a/servers/servers.go b/servers/servers.go index d5ecf27..3a054ce 100644 --- a/servers/servers.go +++ b/servers/servers.go @@ -65,10 +65,9 @@ func New(opts ...Option) (s *Servers) { // Server defines required methods for a type that can be added to // the Servers. +// In addition to this methods, a Server should implement TCPServer +// or UDPServer to be able to serve requests. type Server interface { - // Serve should start server responding to requests. - // The listener is initialized and already listening. - Serve(ln net.Listener) error // Close should stop server from serving all existing requests // and stop accepting new ones. // The listener provided in Serve method must stop listening. @@ -80,11 +79,26 @@ type Server interface { Shutdown(ctx context.Context) error } +// TCPServer defines methods for a server that accepts requests +// over TCP listener. +type TCPServer interface { + // Serve should start server responding to requests. + // The listener is initialized and already listening. + ServeTCP(ln net.Listener) error +} + +// UDPServer defines methods for a server that accepts requests +// over UDP listener. +type UDPServer interface { + ServeUDP(conn *net.UDPConn) error +} + type server struct { Server name string address string tcpAddr *net.TCPAddr + udpAddr *net.UDPAddr } func (s *server) label() string { @@ -94,6 +108,16 @@ func (s *server) label() string { return s.name + " server" } +func (s *server) isTCP() (srv TCPServer, yes bool) { + srv, yes = s.Server.(TCPServer) + return +} + +func (s *server) isUDP() (srv UDPServer, yes bool) { + srv, yes = s.Server.(UDPServer) + return +} + // Add adds a new server instance by a custom name and with // address to listen to. func (s *Servers) Add(name, address string, srv Server) { @@ -110,43 +134,73 @@ func (s *Servers) Add(name, address string, srv Server) { // New new servers must be added after this methid is called. func (s *Servers) Serve() (err error) { lns := make([]net.Listener, len(s.servers)) + conns := make([]*net.UDPConn, len(s.servers)) for i, srv := range s.servers { - ln, err := net.Listen("tcp", srv.address) - if err != nil { - for _, l := range lns { - if l == nil { - continue - } - if err := l.Close(); err != nil { - s.logger.Errorf("%s listener %q close: %v", srv.label(), srv.address, err) + if _, yes := srv.isTCP(); yes { + ln, err := net.Listen("tcp", srv.address) + if err != nil { + for _, l := range lns { + if l == nil { + continue + } + if err := l.Close(); err != nil { + s.logger.Errorf("%s tcp listener %q close: %v", srv.label(), srv.address, err) + } } + return fmt.Errorf("%s tcp listener %q: %v", srv.label(), srv.address, err) } - return fmt.Errorf("%s listener %q: %v", srv.label(), srv.address, err) + lns[i] = ln + } + if _, yes := srv.isUDP(); yes { + addr, err := net.ResolveUDPAddr("udp", srv.address) + if err != nil { + return fmt.Errorf("%s resolve udp address %q: %v", srv.label(), srv.address, err) + } + conn, err := net.ListenUDP("udp", addr) + if err != nil { + return fmt.Errorf("%s udp listener %q: %v", srv.label(), srv.address, err) + } + conns[i] = conn } - lns[i] = ln } for i, srv := range s.servers { - go func(srv *server, ln net.Listener) { - defer s.recover() + if tcpSrv, yes := srv.isTCP(); yes { + go func(srv *server, ln net.Listener) { + defer s.recover() - s.mu.Lock() - srv.tcpAddr = ln.Addr().(*net.TCPAddr) - s.mu.Unlock() + s.mu.Lock() + srv.tcpAddr = ln.Addr().(*net.TCPAddr) + s.mu.Unlock() - s.logger.Infof("%s listening on %q", srv.label(), srv.tcpAddr.String()) - if err = srv.Serve(ln); err != nil { - s.logger.Errorf("%s serve %q: %v", srv.label(), srv.tcpAddr.String(), err) - } - }(srv, lns[i]) + s.logger.Infof("%s listening on %q", srv.label(), srv.tcpAddr.String()) + if err = tcpSrv.ServeTCP(ln); err != nil { + s.logger.Errorf("%s serve %q: %v", srv.label(), srv.tcpAddr.String(), err) + } + }(srv, lns[i]) + } + if udpSrv, yes := srv.isUDP(); yes { + go func(srv *server, conn *net.UDPConn) { + defer s.recover() + + s.mu.Lock() + srv.udpAddr = conn.LocalAddr().(*net.UDPAddr) + s.mu.Unlock() + + s.logger.Infof("%s listening on %q", srv.label(), srv.tcpAddr.String()) + if err = udpSrv.ServeUDP(conn); err != nil { + s.logger.Errorf("%s serve %q: %v", srv.label(), srv.tcpAddr.String(), err) + } + }(srv, conns[i]) + } } return nil } -// Addr returns a TCP address of the listener that a server +// TCPAddr returns a TCP address of the listener that a server // with a specific name is using. If there are more servers // with the same name, the address of the first started server // is returned. -func (s *Servers) Addr(name string) (a *net.TCPAddr) { +func (s *Servers) TCPAddr(name string) (a *net.TCPAddr) { s.mu.Lock() defer s.mu.Unlock() @@ -158,6 +212,22 @@ func (s *Servers) Addr(name string) (a *net.TCPAddr) { return nil } +// UDPAddr returns a UDP address of the listener that a server +// with a specific name is using. If there are more servers +// with the same name, the address of the first started server +// is returned. +func (s *Servers) UDPAddr(name string) (a *net.UDPAddr) { + s.mu.Lock() + defer s.mu.Unlock() + + for _, srv := range s.servers { + if srv.name == name { + return srv.udpAddr + } + } + return nil +} + // Close stops all servers, by calling Close method on each of them. func (s *Servers) Close() { wg := &sync.WaitGroup{} diff --git a/servers/servers_test.go b/servers/servers_test.go index 81a4b5d..5ce0fb9 100644 --- a/servers/servers_test.go +++ b/servers/servers_test.go @@ -39,7 +39,7 @@ func newMockServer() *mockServer { } } -func (s *mockServer) Serve(ln net.Listener) error { +func (s *mockServer) ServeTCP(ln net.Listener) error { s.ln = ln s.serving <- struct{}{} if s.fail { @@ -164,7 +164,7 @@ func newPanicServer() *panicServer { } } -func (s *panicServer) Serve(_ net.Listener) error { +func (s *panicServer) ServeTCP(_ net.Listener) error { s.serving <- struct{}{} panic("") } @@ -366,7 +366,7 @@ func TestServerFailure(t *testing.T) { } } -func TestServerAddr(t *testing.T) { +func TestServerTCPAddr(t *testing.T) { var buf Buffer log.SetOutput(&buf) @@ -382,12 +382,12 @@ func TestServerAddr(t *testing.T) { <-m.serving - a := s.Addr("mock").String() + a := s.TCPAddr("mock").String() if a != m.ln.Addr().String() { t.Errorf("got %q, expected %q", a, m.ln.Addr().String()) } - u := s.Addr("unknown") + u := s.TCPAddr("unknown") if u != nil { t.Errorf("got %v, expected %v", u, nil) }