-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #676 from erikwilson/go-proxy
Add go load-balancing proxy
- Loading branch information
Showing
20 changed files
with
1,659 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
package loadbalancer | ||
|
||
import ( | ||
"encoding/json" | ||
"io/ioutil" | ||
|
||
"github.com/rancher/k3s/pkg/agent/util" | ||
) | ||
|
||
func (lb *LoadBalancer) writeConfig() error { | ||
configOut, err := json.MarshalIndent(lb, "", " ") | ||
if err != nil { | ||
return err | ||
} | ||
if err := util.WriteFile(lb.configFile, string(configOut)); err != nil { | ||
return err | ||
} | ||
return nil | ||
} | ||
|
||
func (lb *LoadBalancer) updateConfig() error { | ||
writeConfig := true | ||
if configBytes, err := ioutil.ReadFile(lb.configFile); err == nil { | ||
config := &LoadBalancer{} | ||
if err := json.Unmarshal(configBytes, config); err == nil { | ||
if config.ServerURL == lb.ServerURL { | ||
writeConfig = false | ||
lb.setServers(config.ServerAddresses) | ||
} | ||
} | ||
} | ||
if writeConfig { | ||
if err := lb.writeConfig(); err != nil { | ||
return err | ||
} | ||
} | ||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
package loadbalancer | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"net" | ||
"path/filepath" | ||
"sync" | ||
|
||
"github.com/google/tcpproxy" | ||
"github.com/rancher/k3s/pkg/cli/cmds" | ||
"github.com/sirupsen/logrus" | ||
) | ||
|
||
type LoadBalancer struct { | ||
mutex sync.Mutex | ||
dialer *net.Dialer | ||
proxy *tcpproxy.Proxy | ||
|
||
configFile string | ||
localAddress string | ||
localServerURL string | ||
originalServerAddress string | ||
ServerURL string | ||
ServerAddresses []string | ||
randomServers []string | ||
currentServerAddress string | ||
nextServerIndex int | ||
} | ||
|
||
const ( | ||
serviceName = "k3s-agent-load-balancer" | ||
) | ||
|
||
func Setup(ctx context.Context, cfg cmds.Agent) (_lb *LoadBalancer, _err error) { | ||
if cfg.DisableLoadBalancer { | ||
return nil, nil | ||
} | ||
|
||
listener, err := net.Listen("tcp", "127.0.0.1:0") | ||
defer func() { | ||
if _err != nil { | ||
logrus.Warnf("Error starting load balancer: %s", _err) | ||
if listener != nil { | ||
listener.Close() | ||
} | ||
} | ||
}() | ||
if err != nil { | ||
return nil, err | ||
} | ||
localAddress := listener.Addr().String() | ||
|
||
originalServerAddress, localServerURL, err := parseURL(cfg.ServerURL, localAddress) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
lb := &LoadBalancer{ | ||
dialer: &net.Dialer{}, | ||
configFile: filepath.Join(cfg.DataDir, "etc", serviceName+".json"), | ||
localAddress: localAddress, | ||
localServerURL: localServerURL, | ||
originalServerAddress: originalServerAddress, | ||
ServerURL: cfg.ServerURL, | ||
} | ||
|
||
lb.setServers([]string{lb.originalServerAddress}) | ||
|
||
lb.proxy = &tcpproxy.Proxy{ | ||
ListenFunc: func(string, string) (net.Listener, error) { | ||
return listener, nil | ||
}, | ||
} | ||
lb.proxy.AddRoute(serviceName, &tcpproxy.DialProxy{ | ||
Addr: serviceName, | ||
DialContext: lb.dialContext, | ||
}) | ||
|
||
if err := lb.updateConfig(); err != nil { | ||
return nil, err | ||
} | ||
if err := lb.proxy.Start(); err != nil { | ||
return nil, err | ||
} | ||
logrus.Infof("Running load balancer %s -> %v", lb.localAddress, lb.randomServers) | ||
|
||
return lb, nil | ||
} | ||
|
||
func (lb *LoadBalancer) Update(serverAddresses []string) { | ||
if lb == nil { | ||
return | ||
} | ||
if !lb.setServers(serverAddresses) { | ||
return | ||
} | ||
logrus.Infof("Updating load balancer server addresses -> %v", lb.randomServers) | ||
|
||
if err := lb.writeConfig(); err != nil { | ||
logrus.Warnf("Error updating load balancer config: %s", err) | ||
} | ||
} | ||
|
||
func (lb *LoadBalancer) LoadBalancerServerURL() string { | ||
if lb == nil { | ||
return "" | ||
} | ||
return lb.localServerURL | ||
} | ||
|
||
func (lb *LoadBalancer) dialContext(ctx context.Context, network, address string) (net.Conn, error) { | ||
startIndex := lb.nextServerIndex | ||
for { | ||
targetServer := lb.currentServerAddress | ||
|
||
conn, err := lb.dialer.DialContext(ctx, network, targetServer) | ||
if err == nil { | ||
return conn, nil | ||
} | ||
logrus.Warnf("Dial error from load balancer: %s", err) | ||
|
||
newServer, err := lb.nextServer(targetServer) | ||
if err != nil { | ||
return nil, err | ||
} | ||
if targetServer != newServer { | ||
logrus.Warnf("Dial context in load balancer failed over to %s", newServer) | ||
} | ||
if ctx.Err() != nil { | ||
return nil, ctx.Err() | ||
} | ||
|
||
maxIndex := len(lb.randomServers) | ||
if startIndex > maxIndex { | ||
startIndex = maxIndex | ||
} | ||
if lb.nextServerIndex == startIndex { | ||
return nil, errors.New("all servers failed") | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
package loadbalancer | ||
|
||
import ( | ||
"bufio" | ||
"context" | ||
"errors" | ||
"fmt" | ||
"io/ioutil" | ||
"net" | ||
"net/url" | ||
"os" | ||
"strings" | ||
"testing" | ||
"time" | ||
|
||
"github.com/rancher/k3s/pkg/cli/cmds" | ||
) | ||
|
||
type server struct { | ||
listener net.Listener | ||
conns []net.Conn | ||
prefix string | ||
} | ||
|
||
func createServer(prefix string) (*server, error) { | ||
listener, err := net.Listen("tcp", "127.0.0.1:0") | ||
if err != nil { | ||
return nil, err | ||
} | ||
s := &server{ | ||
prefix: prefix, | ||
listener: listener, | ||
} | ||
go s.serve() | ||
return s, nil | ||
} | ||
|
||
func (s *server) serve() { | ||
for { | ||
conn, err := s.listener.Accept() | ||
if err != nil { | ||
return | ||
} | ||
s.conns = append(s.conns, conn) | ||
go s.echo(conn) | ||
} | ||
} | ||
|
||
func (s *server) close() { | ||
s.listener.Close() | ||
for _, conn := range s.conns { | ||
conn.Close() | ||
} | ||
} | ||
|
||
func (s *server) echo(conn net.Conn) { | ||
for { | ||
result, err := bufio.NewReader(conn).ReadString('\n') | ||
if err != nil { | ||
return | ||
} | ||
conn.Write([]byte(s.prefix + ":" + result)) | ||
} | ||
} | ||
|
||
func ping(conn net.Conn) (string, error) { | ||
fmt.Fprintf(conn, "ping\n") | ||
result, err := bufio.NewReader(conn).ReadString('\n') | ||
if err != nil { | ||
return "", err | ||
} | ||
return strings.TrimSpace(result), nil | ||
} | ||
|
||
func assertEqual(t *testing.T, a interface{}, b interface{}) { | ||
if a != b { | ||
t.Fatalf("[ %v != %v ]", a, b) | ||
} | ||
} | ||
|
||
func assertNotEqual(t *testing.T, a interface{}, b interface{}) { | ||
if a == b { | ||
t.Fatalf("[ %v == %v ]", a, b) | ||
} | ||
} | ||
|
||
func TestFailOver(t *testing.T) { | ||
tmpDir, err := ioutil.TempDir("", "lb-test") | ||
if err != nil { | ||
assertEqual(t, err, nil) | ||
} | ||
defer os.RemoveAll(tmpDir) | ||
|
||
ogServe, err := createServer("og") | ||
if err != nil { | ||
assertEqual(t, err, nil) | ||
} | ||
|
||
lbServe, err := createServer("lb") | ||
if err != nil { | ||
assertEqual(t, err, nil) | ||
} | ||
|
||
cfg := cmds.Agent{ | ||
ServerURL: fmt.Sprintf("http://%s/", ogServe.listener.Addr().String()), | ||
DataDir: tmpDir, | ||
} | ||
|
||
lb, err := Setup(context.Background(), cfg) | ||
if err != nil { | ||
assertEqual(t, err, nil) | ||
} | ||
|
||
parsedURL, err := url.Parse(lb.LoadBalancerServerURL()) | ||
if err != nil { | ||
assertEqual(t, err, nil) | ||
} | ||
localAddress := parsedURL.Host | ||
|
||
lb.Update([]string{lbServe.listener.Addr().String()}) | ||
|
||
conn1, err := net.Dial("tcp", localAddress) | ||
if err != nil { | ||
assertEqual(t, err, nil) | ||
} | ||
result1, err := ping(conn1) | ||
if err != nil { | ||
assertEqual(t, err, nil) | ||
} | ||
assertEqual(t, result1, "lb:ping") | ||
|
||
lbServe.close() | ||
|
||
_, err = ping(conn1) | ||
assertNotEqual(t, err, nil) | ||
|
||
conn2, err := net.Dial("tcp", localAddress) | ||
if err != nil { | ||
assertEqual(t, err, nil) | ||
} | ||
result2, err := ping(conn2) | ||
if err != nil { | ||
assertEqual(t, err, nil) | ||
} | ||
assertEqual(t, result2, "og:ping") | ||
} | ||
|
||
func TestFailFast(t *testing.T) { | ||
tmpDir, err := ioutil.TempDir("", "lb-test") | ||
if err != nil { | ||
assertEqual(t, err, nil) | ||
} | ||
defer os.RemoveAll(tmpDir) | ||
|
||
cfg := cmds.Agent{ | ||
ServerURL: "http://127.0.0.1:-1/", | ||
DataDir: tmpDir, | ||
} | ||
|
||
lb, err := Setup(context.Background(), cfg) | ||
if err != nil { | ||
assertEqual(t, err, nil) | ||
} | ||
|
||
conn, err := net.Dial("tcp", lb.localAddress) | ||
if err != nil { | ||
assertEqual(t, err, nil) | ||
} | ||
|
||
done := make(chan error) | ||
go func() { | ||
_, err = ping(conn) | ||
done <- err | ||
}() | ||
timeout := time.After(10 * time.Millisecond) | ||
|
||
select { | ||
case err := <-done: | ||
assertNotEqual(t, err, nil) | ||
case <-timeout: | ||
t.Fatal(errors.New("time out")) | ||
} | ||
} |
Oops, something went wrong.