Skip to content

Commit

Permalink
servers: add support for UDP servers and implement QUIC server
Browse files Browse the repository at this point in the history
  • Loading branch information
janos committed Jul 30, 2018
1 parent e0c16b0 commit 87cf624
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 37 deletions.
12 changes: 12 additions & 0 deletions servers/grpc/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@ package grpcServer

import (
"context"
"net"

"google.golang.org/grpc"
"resenje.org/web/servers"
)

var (
_ servers.Server = new(Server)
_ servers.TCPServer = new(Server)
)

// Server wraps grpc.Server to provide methods for
Expand All @@ -24,6 +31,11 @@ func New(server *grpc.Server) (s *Server) {
}
}

// ServeTCP serves request on TCP listener.
func (s *Server) ServeTCP(ln net.Listener) (err error) {
return s.Server.Serve(ln)
}

// Close executes grpc.Server.Stop method.
func (s *Server) Close() (err error) {
s.Server.Stop()
Expand Down
6 changes: 3 additions & 3 deletions servers/grpc/grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestServer(t *testing.T) {
addr := "localhost:" + strconv.Itoa(ln.Addr().(*net.TCPAddr).Port)

go func() {
if err := s.Serve(ln); err != nil {
if err := s.ServeTCP(ln); err != nil {
panic(err)
}
}()
Expand Down Expand Up @@ -75,7 +75,7 @@ func TestServerShutdown(t *testing.T) {
addr := "localhost:" + strconv.Itoa(ln.Addr().(*net.TCPAddr).Port)

go func() {
if err := s.Serve(ln); err != nil {
if err := s.ServeTCP(ln); err != nil {
if e, ok := err.(*net.OpError); !(ok && e.Op == "accept") {
panic(err)
}
Expand Down Expand Up @@ -123,7 +123,7 @@ func TestServerClose(t *testing.T) {
addr := "localhost:" + strconv.Itoa(ln.Addr().(*net.TCPAddr).Port)

go func() {
if err := s.Serve(ln); err != nil {
if err := s.ServeTCP(ln); err != nil {
if e, ok := err.(*net.OpError); !(ok && e.Op == "accept") {
panic(err)
}
Expand Down
11 changes: 9 additions & 2 deletions servers/http/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ import (
"net"
"net/http"
"time"

"resenje.org/web/servers"
)

var (
_ servers.Server = new(Server)
_ servers.TCPServer = new(Server)
)

// Options struct holds parameters that can be configure using
Expand Down Expand Up @@ -47,11 +54,11 @@ func New(handler http.Handler, opts ...Option) (s *Server) {
return
}

// Serve executes http.Server.Serve method.
// ServeTCP executes http.Server.Serve method.
// If the provided listener is net.TCPListener, keep alive
// will be enabled. If server is configured with TLS,
// a tls.Listener will be created with provided listener.
func (s *Server) Serve(ln net.Listener) (err error) {
func (s *Server) ServeTCP(ln net.Listener) (err error) {
if l, ok := ln.(*net.TCPListener); ok {
ln = tcpKeepAliveListener{TCPListener: l}
}
Expand Down
4 changes: 2 additions & 2 deletions servers/http/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestServer(t *testing.T) {
addr := "http://localhost:" + strconv.Itoa(ln.Addr().(*net.TCPAddr).Port)

go func() {
if err := s.Serve(ln); err != nil && err != http.ErrServerClosed {
if err := s.ServeTCP(ln); err != nil && err != http.ErrServerClosed {
panic(err)
}
}()
Expand Down Expand Up @@ -102,7 +102,7 @@ viBngkOY/zwTS9mYvM8ixsj16b2WWzajtjhBtihs+tur
addr := "https://localhost:" + strconv.Itoa(ln.Addr().(*net.TCPAddr).Port)

go func() {
if err := s.Serve(ln); err != nil && err != http.ErrServerClosed {
if err := s.ServeTCP(ln); err != nil && err != http.ErrServerClosed {
panic(err)
}
}()
Expand Down
78 changes: 78 additions & 0 deletions servers/quic/quic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright (c) 2018, Janoš Guljaš <[email protected]>
// All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package quicServer

import (
"context"
"crypto/tls"
"net"
"net/http"

"github.com/lucas-clemente/quic-go/h2quic"
"resenje.org/web/servers"
)

var (
_ servers.Server = new(Server)
_ servers.UDPServer = new(Server)
)

// Options struct holds parameters that can be configure using
// functions with prefix With.
type Options struct {
tlsConfig *tls.Config
}

// Option is a function that sets optional parameters for
// the Server.
type Option func(*Options)

// WithTLSConfig sets a TLS configuration for the HTTP server
// and creates a TLS listener.
func WithTLSConfig(tlsConfig *tls.Config) Option { return func(o *Options) { o.tlsConfig = tlsConfig } }

// Server wraps h2quic.Server to provide methods for
// resenje.org/web/servers.Server interface.
type Server struct {
*h2quic.Server
}

// New creates a new instance of Server.
func New(handler http.Handler, opts ...Option) (s *Server) {
o := &Options{}
for _, opt := range opts {
opt(o)
}
s = &Server{
Server: &h2quic.Server{
Server: &http.Server{
Handler: handler,
TLSConfig: o.tlsConfig,
},
},
}
return
}

// ServeUDP serves requests over UDP connection.
func (s *Server) ServeUDP(conn *net.UDPConn) (err error) {
s.Server.Server.Addr = conn.LocalAddr().String()
return s.Server.Serve(conn)
}

// Shutdown calls h2quic.Server.Close method.
func (s *Server) Shutdown(_ context.Context) (err error) {
return s.Server.Close()
}

// QuicHeadersHandler should be used as a middleware to set
// quic related headers to TCP server that suggest alternative svc.
func (s *Server) QuicHeadersHandler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.SetQuicHeaders(w.Header())
h.ServeHTTP(w, r)
})
}
120 changes: 95 additions & 25 deletions servers/servers.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,9 @@ func New(opts ...Option) (s *Servers) {

// Server defines required methods for a type that can be added to
// the Servers.
// In addition to this methods, a Server should implement TCPServer
// or UDPServer to be able to serve requests.
type Server interface {
// Serve should start server responding to requests.
// The listener is initialized and already listening.
Serve(ln net.Listener) error
// Close should stop server from serving all existing requests
// and stop accepting new ones.
// The listener provided in Serve method must stop listening.
Expand All @@ -80,11 +79,26 @@ type Server interface {
Shutdown(ctx context.Context) error
}

// TCPServer defines methods for a server that accepts requests
// over TCP listener.
type TCPServer interface {
// Serve should start server responding to requests.
// The listener is initialized and already listening.
ServeTCP(ln net.Listener) error
}

// UDPServer defines methods for a server that accepts requests
// over UDP listener.
type UDPServer interface {
ServeUDP(conn *net.UDPConn) error
}

type server struct {
Server
name string
address string
tcpAddr *net.TCPAddr
udpAddr *net.UDPAddr
}

func (s *server) label() string {
Expand All @@ -94,6 +108,16 @@ func (s *server) label() string {
return s.name + " server"
}

func (s *server) isTCP() (srv TCPServer, yes bool) {
srv, yes = s.Server.(TCPServer)
return
}

func (s *server) isUDP() (srv UDPServer, yes bool) {
srv, yes = s.Server.(UDPServer)
return
}

// Add adds a new server instance by a custom name and with
// address to listen to.
func (s *Servers) Add(name, address string, srv Server) {
Expand All @@ -110,43 +134,73 @@ func (s *Servers) Add(name, address string, srv Server) {
// New new servers must be added after this methid is called.
func (s *Servers) Serve() (err error) {
lns := make([]net.Listener, len(s.servers))
conns := make([]*net.UDPConn, len(s.servers))
for i, srv := range s.servers {
ln, err := net.Listen("tcp", srv.address)
if err != nil {
for _, l := range lns {
if l == nil {
continue
}
if err := l.Close(); err != nil {
s.logger.Errorf("%s listener %q close: %v", srv.label(), srv.address, err)
if _, yes := srv.isTCP(); yes {
ln, err := net.Listen("tcp", srv.address)
if err != nil {
for _, l := range lns {
if l == nil {
continue
}
if err := l.Close(); err != nil {
s.logger.Errorf("%s tcp listener %q close: %v", srv.label(), srv.address, err)
}
}
return fmt.Errorf("%s tcp listener %q: %v", srv.label(), srv.address, err)
}
return fmt.Errorf("%s listener %q: %v", srv.label(), srv.address, err)
lns[i] = ln
}
if _, yes := srv.isUDP(); yes {
addr, err := net.ResolveUDPAddr("udp", srv.address)
if err != nil {
return fmt.Errorf("%s resolve udp address %q: %v", srv.label(), srv.address, err)
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
return fmt.Errorf("%s udp listener %q: %v", srv.label(), srv.address, err)
}
conns[i] = conn
}
lns[i] = ln
}
for i, srv := range s.servers {
go func(srv *server, ln net.Listener) {
defer s.recover()
if tcpSrv, yes := srv.isTCP(); yes {
go func(srv *server, ln net.Listener) {
defer s.recover()

s.mu.Lock()
srv.tcpAddr = ln.Addr().(*net.TCPAddr)
s.mu.Unlock()
s.mu.Lock()
srv.tcpAddr = ln.Addr().(*net.TCPAddr)
s.mu.Unlock()

s.logger.Infof("%s listening on %q", srv.label(), srv.tcpAddr.String())
if err = srv.Serve(ln); err != nil {
s.logger.Errorf("%s serve %q: %v", srv.label(), srv.tcpAddr.String(), err)
}
}(srv, lns[i])
s.logger.Infof("%s listening on %q", srv.label(), srv.tcpAddr.String())
if err = tcpSrv.ServeTCP(ln); err != nil {
s.logger.Errorf("%s serve %q: %v", srv.label(), srv.tcpAddr.String(), err)
}
}(srv, lns[i])
}
if udpSrv, yes := srv.isUDP(); yes {
go func(srv *server, conn *net.UDPConn) {
defer s.recover()

s.mu.Lock()
srv.udpAddr = conn.LocalAddr().(*net.UDPAddr)
s.mu.Unlock()

s.logger.Infof("%s listening on %q", srv.label(), srv.tcpAddr.String())
if err = udpSrv.ServeUDP(conn); err != nil {
s.logger.Errorf("%s serve %q: %v", srv.label(), srv.tcpAddr.String(), err)
}
}(srv, conns[i])
}
}
return nil
}

// Addr returns a TCP address of the listener that a server
// TCPAddr returns a TCP address of the listener that a server
// with a specific name is using. If there are more servers
// with the same name, the address of the first started server
// is returned.
func (s *Servers) Addr(name string) (a *net.TCPAddr) {
func (s *Servers) TCPAddr(name string) (a *net.TCPAddr) {
s.mu.Lock()
defer s.mu.Unlock()

Expand All @@ -158,6 +212,22 @@ func (s *Servers) Addr(name string) (a *net.TCPAddr) {
return nil
}

// UDPAddr returns a UDP address of the listener that a server
// with a specific name is using. If there are more servers
// with the same name, the address of the first started server
// is returned.
func (s *Servers) UDPAddr(name string) (a *net.UDPAddr) {
s.mu.Lock()
defer s.mu.Unlock()

for _, srv := range s.servers {
if srv.name == name {
return srv.udpAddr
}
}
return nil
}

// Close stops all servers, by calling Close method on each of them.
func (s *Servers) Close() {
wg := &sync.WaitGroup{}
Expand Down
10 changes: 5 additions & 5 deletions servers/servers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func newMockServer() *mockServer {
}
}

func (s *mockServer) Serve(ln net.Listener) error {
func (s *mockServer) ServeTCP(ln net.Listener) error {
s.ln = ln
s.serving <- struct{}{}
if s.fail {
Expand Down Expand Up @@ -164,7 +164,7 @@ func newPanicServer() *panicServer {
}
}

func (s *panicServer) Serve(_ net.Listener) error {
func (s *panicServer) ServeTCP(_ net.Listener) error {
s.serving <- struct{}{}
panic("")
}
Expand Down Expand Up @@ -366,7 +366,7 @@ func TestServerFailure(t *testing.T) {
}
}

func TestServerAddr(t *testing.T) {
func TestServerTCPAddr(t *testing.T) {
var buf Buffer
log.SetOutput(&buf)

Expand All @@ -382,12 +382,12 @@ func TestServerAddr(t *testing.T) {

<-m.serving

a := s.Addr("mock").String()
a := s.TCPAddr("mock").String()
if a != m.ln.Addr().String() {
t.Errorf("got %q, expected %q", a, m.ln.Addr().String())
}

u := s.Addr("unknown")
u := s.TCPAddr("unknown")
if u != nil {
t.Errorf("got %v, expected %v", u, nil)
}
Expand Down

0 comments on commit 87cf624

Please sign in to comment.