From 3c679293587ab2845b234cc3f0271e8e58203176 Mon Sep 17 00:00:00 2001 From: Bart Smykla Date: Mon, 14 Jan 2019 15:43:20 +0000 Subject: [PATCH] Added reconnection logic when NATS is disconnected How was it tested: I deployed this version of nats-queue-worker and then simulated NATS disconnection Signed-off-by: Bart Smykla --- handler/handler.go | 70 +++++++++++++++++++++++++++++++++++------- handler/nats_config.go | 20 ++++++++++++ 2 files changed, 79 insertions(+), 11 deletions(-) diff --git a/handler/handler.go b/handler/handler.go index 5eef9d5..c7e1f06 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -4,6 +4,8 @@ import ( "encoding/json" "fmt" "log" + "sync" + "time" "github.com/nats-io/go-nats-streaming" "github.com/openfaas/faas/gateway/queue" @@ -11,13 +13,56 @@ import ( // NatsQueue queue for work type NatsQueue struct { - nc stan.Conn + nc stan.Conn + ncMutex *sync.RWMutex + maxReconnect int + reconnectDelay time.Duration + ClientID string ClusterID string NATSURL string Topic string } +func (q *NatsQueue) connect() error { + nc, err := stan.Connect( + q.ClusterID, + q.ClientID, + stan.NatsURL(q.NATSURL), + stan.SetConnectionLostHandler(func(conn stan.Conn, err error) { + log.Printf("Disconnected from %s\n", q.NATSURL) + + q.reconnect() + }), + ) + + if err != nil { + return err + } + + q.ncMutex.Lock() + q.nc = nc + q.ncMutex.Unlock() + + return nil +} + +func (q *NatsQueue) reconnect() { + for i := 0; i < q.maxReconnect; i++ { + time.Sleep(time.Second * time.Duration(i) * q.reconnectDelay) + + if err := q.connect(); err == nil { + log.Printf("Reconnection (%d/%d) to %s succeeded", i + 1, q.maxReconnect, q.NATSURL) + + return + } + + log.Printf("Reconnection (%d/%d) to %s failed", i + 1, q.maxReconnect, q.NATSURL) + } + + log.Printf("Reconnection limit (%d) reached\n", q.maxReconnect) +} + // CreateNatsQueue ready for asynchronous processing func CreateNatsQueue(address string, port int, clientConfig NatsConfig) (*NatsQueue, error) { var err error @@ -27,22 +72,23 @@ func CreateNatsQueue(address string, port int, clientConfig NatsConfig) (*NatsQu clientID := clientConfig.GetClientID() clusterID := "faas-cluster" - nc, err := stan.Connect(clusterID, clientID, stan.NatsURL(natsURL)) queue1 := NatsQueue{ - nc: nc, - ClientID: clientID, - ClusterID: clusterID, - NATSURL: natsURL, - Topic: "faas-request", + ClientID: clientID, + ClusterID: clusterID, + NATSURL: natsURL, + Topic: "faas-request", + maxReconnect: clientConfig.GetMaxReconnect(), + reconnectDelay: clientConfig.GetReconnectDelay(), + ncMutex: &sync.RWMutex{}, } + err = queue1.connect() + return &queue1, err } // Queue request for processing func (q *NatsQueue) Queue(req *queue.Request) error { - var err error - fmt.Printf("NatsQueue - submitting request: %s.\n", req.Function) out, err := json.Marshal(req) @@ -50,7 +96,9 @@ func (q *NatsQueue) Queue(req *queue.Request) error { log.Println(err) } - err = q.nc.Publish(q.Topic, out) + q.ncMutex.RLock() + nc := q.nc + q.ncMutex.RUnlock() - return err + return nc.Publish(q.Topic, out) } diff --git a/handler/nats_config.go b/handler/nats_config.go index ca23329..1261a00 100644 --- a/handler/nats_config.go +++ b/handler/nats_config.go @@ -2,15 +2,27 @@ package handler import ( "os" + "time" "github.com/openfaas/nats-queue-worker/nats" ) type NatsConfig interface { GetClientID() string + GetMaxReconnect() int + GetReconnectDelay() time.Duration } type DefaultNatsConfig struct { + maxReconnect int + delayBetweenReconnect time.Duration +} + +func NewDefaultNatsConfig() DefaultNatsConfig { + return DefaultNatsConfig{ + maxReconnect: 5, + delayBetweenReconnect: time.Second * 2, + } } // GetClientID returns the ClientID assigned to this producer/consumer. @@ -19,6 +31,14 @@ func (DefaultNatsConfig) GetClientID() string { return getClientID(val) } +func (c DefaultNatsConfig) GetMaxReconnect() int { + return c.maxReconnect +} + +func (c DefaultNatsConfig) GetReconnectDelay() time.Duration { + return c.delayBetweenReconnect +} + func getClientID(hostname string) string { return "faas-publisher-" + nats.GetClientID(hostname) }