Skip to content

Commit

Permalink
server: add upstream connection rebalancing
Browse files Browse the repository at this point in the history
Adds support for rebalancing upstream connections across the cluster.

Every second each node calculates its connection "balance", which is
the ratio of the number of connections the node has locally to the
cluster average. If a node finds its balance exceeds the configured
threshold, it will shed connections.

Since the Piko cluster is typically hosted behind a load balancer, when
connections are shed, clients will automatically reconnect to a random
node in the cluster, so eventually the nodes' connections will be
balanced.

This does mean rebalancing isn't perfect, as an upstream may be
disconnected and reconnect to the node that's already shedding
connections, but it's the best we can do. It will cause minor disruption
while upstreams reconnect, though the alternative is to overload Piko
nodes with no way to rebalance load.

By default rebalancing is disabled.

Fixes #176.
  • Loading branch information
andydunstall committed Feb 8, 2025
1 parent ccc728f commit b8c07b4
Show file tree
Hide file tree
Showing 7 changed files with 291 additions and 8 deletions.
15 changes: 15 additions & 0 deletions server/cluster/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,21 @@ func (s *State) RemoveRemoteEndpoint(id string, endpointID string) bool {
return true
}

// AvgConns returns the mean number of connections per node in the cluster,
// rounded down to the nearest integer.
//
// The total is the sum of every endpoint's connection count across all nodes
// held in the state. It returns 0 when the state holds no nodes, rather than
// panicking with an integer divide by zero.
func (s *State) AvgConns() int {
	s.mu.RLock()
	defer s.mu.RUnlock()

	// Guard the division below: an empty node set has no meaningful average.
	if len(s.nodes) == 0 {
		return 0
	}

	var totalConns int
	for _, node := range s.nodes {
		for _, conns := range node.Endpoints {
			totalConns += conns
		}
	}

	return totalConns / len(s.nodes)
}

// Metrics returns the metrics stored in the state's metrics field.
//
// The field is read without acquiring s.mu.
func (s *State) Metrics() *Metrics {
return s.metrics
}
Expand Down
106 changes: 106 additions & 0 deletions server/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,99 @@ import (
"github.com/andydunstall/piko/pkg/log"
)

// RebalanceConfig configures when a node sheds connections to rebalance, and
// how quickly it sheds them.
type RebalanceConfig struct {
// Threshold is the threshold indicating when to rebalance (0-1).
//
// Each node will rebalance if its number of local connections exceeds
// the cluster average by the given threshold.
//
// For example, if the threshold is 0.2, the node will rebalance if it has
// over 20% more connections than the cluster average.
//
// Set the threshold to 0 to disable rebalancing.
//
// NOTE(review): Validate only rejects negative values, so thresholds above
// 1 are accepted — confirm whether the 0-1 range is meant as a hard limit.
Threshold float64 `json:"threshold" yaml:"threshold"`

// ShedRate is the fraction of connections to drop every second when
// rebalancing (0-1).
//
// For example, if 0.005, the node will drop 0.5% of connections per second
// until it is balanced.
//
// Note the rate is taken as a fraction of the average number of connections
// per node in the cluster, rather than the number of connections on the
// local node. This ensures all nodes shed at the same rate.
ShedRate float64 `json:"shed_rate" yaml:"shed_rate"`

// MinConns is the minimum number of local connections the node must have
// before considering rebalancing.
//
// This prevents excess rebalancing when the number of connections is
// too small to matter.
MinConns uint `json:"min_conns" yaml:"min_conns"`
}

// Validate checks the rebalance configuration.
//
// It returns an error if Threshold is negative, or if ShedRate is negative or
// greater than 1; otherwise it returns nil. Threshold is checked first.
// MinConns is not checked.
func (c *RebalanceConfig) Validate() error {
	switch {
	case c.Threshold < 0:
		return fmt.Errorf("threshold cannot be negative")
	case c.ShedRate < 0:
		return fmt.Errorf("shed-rate cannot be negative")
	case c.ShedRate > 1:
		return fmt.Errorf("shed-rate cannot exceed 1")
	default:
		return nil
	}
}

// RegisterFlags registers the rebalance flags on fs, binding each flag to the
// matching field of c and using the field's current value as the default.
//
// Flag names are "<prefix>.rebalance.<name>", or "rebalance.<name>" when
// prefix is empty.
func (c *RebalanceConfig) RegisterFlags(fs *pflag.FlagSet, prefix string) {
	const thresholdUsage = `
The threshold indicating when to rebalance (0-1).
Each node will rebalance if its number of local connections exceeds the cluster
average by the given threshold.
Such as if the threshold is 0.2, the node will rebalance if it has over 20%
more connections than the cluster average.
Set the threshold to 0 to disable rebalancing.`

	const shedRateUsage = `
The percent of connections to drop every second when rebalancing (0-1).
Such as if 0.005, the node will drop 0.5% of connections per second until it
is balanced.
Note the rate is taken as a percent of the average number of connections per
node in the cluster, rather than the number of connections on the local node.
This ensures all nodes shed at the same rate.`

	const minConnsUsage = `
The minimum number of local connections the node must have before considering
rebalancing.
This prevents excess rebalancing when the number of connections is too small to
matter.`

	// Only add a separator after the caller's prefix when there is one.
	if prefix != "" {
		prefix += "."
	}
	prefix += "rebalance."

	fs.Float64Var(&c.Threshold, prefix+"threshold", c.Threshold, thresholdUsage)
	fs.Float64Var(&c.ShedRate, prefix+"shed-rate", c.ShedRate, shedRateUsage)
	fs.UintVar(&c.MinConns, prefix+"min-conns", c.MinConns, minConnsUsage)
}

// HTTPConfig contains generic configuration for the HTTP servers.
type HTTPConfig struct {
// ReadTimeout is the maximum duration for reading the entire
Expand Down Expand Up @@ -176,13 +269,18 @@ type UpstreamConfig struct {

Auth auth.Config `json:"auth" yaml:"auth"`

Rebalance RebalanceConfig `json:"rebalance" yaml:"rebalance"`

TLS TLSConfig `json:"tls" yaml:"tls"`
}

func (c *UpstreamConfig) Validate() error {
if c.BindAddr == "" {
return fmt.Errorf("missing bind addr")
}
if err := c.Rebalance.Validate(); err != nil {
return fmt.Errorf("rebalance: %w", err)
}
if err := c.TLS.Validate(); err != nil {
return fmt.Errorf("tls: %w", err)
}
Expand Down Expand Up @@ -219,6 +317,8 @@ advertise address of '10.26.104.14:8000'.`,

c.Auth.RegisterFlags(fs, "upstream")

c.Rebalance.RegisterFlags(fs, "upstream")

c.TLS.RegisterFlags(fs, "upstream")
}

Expand Down Expand Up @@ -432,6 +532,12 @@ func Default() *Config {
},
Upstream: UpstreamConfig{
BindAddr: ":8001",
Rebalance: RebalanceConfig{
// Disable by default.
Threshold: 0,
ShedRate: 0.005,
MinConns: 50,
},
},
Admin: AdminConfig{
BindAddr: ":8002",
Expand Down
18 changes: 18 additions & 0 deletions server/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ upstream:
audience: my-audience
issuer: my-issuer
rebalance:
threshold: 0.2
shed_rate: 0.005
min_conns: 100
tls:
cert: /piko/cert.pem
key: /piko/key.pem
Expand Down Expand Up @@ -150,6 +155,11 @@ grace_period: 2m
Audience: "my-audience",
Issuer: "my-issuer",
},
Rebalance: RebalanceConfig{
Threshold: 0.2,
ShedRate: 0.005,
MinConns: 100,
},
TLS: TLSConfig{
Cert: "/piko/cert.pem",
Key: "/piko/key.pem",
Expand Down Expand Up @@ -222,6 +232,9 @@ func TestConfig_LoadFlags(t *testing.T) {
"--proxy.tls.key", "/piko/key.pem",
"--upstream.bind-addr", "10.15.104.25:8001",
"--upstream.advertise-addr", "1.2.3.4:8001",
"--upstream.rebalance.threshold", "0.2",
"--upstream.rebalance.shed-rate", "0.005",
"--upstream.rebalance.min-conns", "100",
"--upstream.auth.hmac-secret-key", "hmac-secret-key",
"--upstream.auth.rsa-public-key", "rsa-public-key",
"--upstream.auth.ecdsa-public-key", "ecdsa-public-key",
Expand Down Expand Up @@ -294,6 +307,11 @@ func TestConfig_LoadFlags(t *testing.T) {
Audience: "my-audience",
Issuer: "my-issuer",
},
Rebalance: RebalanceConfig{
Threshold: 0.2,
ShedRate: 0.005,
MinConns: 100,
},
TLS: TLSConfig{
Cert: "/piko/cert.pem",
Key: "/piko/key.pem",
Expand Down
32 changes: 32 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net"
"strings"
"sync"
"time"

"github.com/hashicorp/go-sockaddr"
"github.com/prometheus/client_golang/prometheus"
Expand Down Expand Up @@ -35,6 +36,9 @@ type Server struct {
upstreamLn net.Listener
upstreamServer *upstream.Server

rebalanceCtx context.Context
rebalanceCancel context.CancelFunc

adminLn net.Listener
adminServer *admin.Server

Expand Down Expand Up @@ -94,6 +98,10 @@ func NewServer(conf *config.Config, logger log.Logger) (*Server, error) {
}
s.upstreamLn = upstreamLn

rebalanceCtx, rebalanceCancel := context.WithCancel(context.Background())
s.rebalanceCtx = rebalanceCtx
s.rebalanceCancel = rebalanceCancel

// Admin listener.

adminLn, err := s.adminListen()
Expand Down Expand Up @@ -155,6 +163,8 @@ func NewServer(conf *config.Config, logger log.Logger) (*Server, error) {
upstreams,
upstreamVerifier,
upstreamTLSConfig,
s.clusterState,
conf.Upstream,
logger,
)

Expand Down Expand Up @@ -391,6 +401,11 @@ func (s *Server) startUpstreamServer() {
s.logger.Error("failed to run upstream server", zap.Error(err))
}
})
if s.conf.Upstream.Rebalance.Threshold != 0 {
s.runGoroutine(func() {
s.upstreamRebalance()
})
}
}

func (s *Server) startAdminServer() {
Expand Down Expand Up @@ -419,6 +434,7 @@ func (s *Server) shutdownUsageReporting() {
}

func (s *Server) shutdownUpstreamServer(ctx context.Context) {
s.rebalanceCancel()
if err := s.upstreamServer.Shutdown(ctx); err != nil {
s.logger.Error("failed to shutdown upstream server", zap.Error(err))
}
Expand Down Expand Up @@ -492,6 +508,22 @@ func (s *Server) adminListen() (net.Listener, error) {
return ln, nil
}

// upstreamRebalance asks the upstream server to rebalance its connections
// once per second.
//
// It blocks until s.rebalanceCtx is cancelled, then returns.
func (s *Server) upstreamRebalance() {
	tick := time.NewTicker(time.Second)
	defer tick.Stop()

	// Cancelling the rebalance context is the signal to stop the loop.
	done := s.rebalanceCtx.Done()
	for {
		select {
		case <-done:
			return
		case <-tick.C:
			s.upstreamServer.Rebalance()
		}
	}
}

// runGoroutine runs the given function as a background goroutine. If the
// function returns before the server is shutdown, it is considered a fatal
// error and the server is forcefully shutdown.
Expand Down
2 changes: 2 additions & 0 deletions server/upstream/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ func NewMetrics() *Metrics {
ConnectedUpstreams: prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "piko",
// TODO upstream not upstreams
Subsystem: "upstreams",
Name: "connected_upstreams",
Help: "Number of upstreams connected to this node",
Expand Down Expand Up @@ -53,6 +54,7 @@ func NewMetrics() *Metrics {
},
[]string{"node_id"},
),
// TODO rebalancing: prometheus.NewGauge
}
}

Expand Down
Loading

0 comments on commit b8c07b4

Please sign in to comment.