diff --git a/handler/handler.go b/handler/handler.go index 5eef9d5..c7b10ab 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -4,6 +4,8 @@ import ( "encoding/json" "fmt" "log" + "sync" + "time" "github.com/nats-io/go-nats-streaming" "github.com/openfaas/faas/gateway/queue" @@ -11,11 +13,54 @@ import ( // NatsQueue queue for work type NatsQueue struct { - nc stan.Conn - ClientID string - ClusterID string - NATSURL string - Topic string + nc stan.Conn + ncMutex *sync.RWMutex + maxReconnect int + delayBetweenReconnect time.Duration + + ClientID string + ClusterID string + NATSURL string + Topic string +} + +func (q *NatsQueue) connect() error { + nc, err := stan.Connect( + q.ClusterID, + q.ClientID, + stan.NatsURL(q.NATSURL), + stan.SetConnectionLostHandler(func(conn stan.Conn, err error) { + log.Printf("Disconnected from %s\n", q.NATSURL) + + q.reconnect(0) + }), + ) + + if err != nil { + return err + } + + q.ncMutex.Lock() + q.nc = nc + q.ncMutex.Unlock() + + return nil +} + +func (q *NatsQueue) reconnect(iteration int) { + if iteration < q.maxReconnect { + time.Sleep(time.Second * time.Duration(iteration) * q.delayBetweenReconnect) + + if err := q.connect(); err != nil { + log.Printf("Reconnection (%d) to %s failed", iteration, q.NATSURL) + + q.reconnect(iteration + 1) + } else { + log.Printf("Reconnection (%d) to %s succed", iteration, q.NATSURL) + } + } else { + log.Printf("Reconnection limit (%d) reached\n", q.maxReconnect) + } } // CreateNatsQueue ready for asynchronous processing @@ -27,15 +72,16 @@ func CreateNatsQueue(address string, port int, clientConfig NatsConfig) (*NatsQu clientID := clientConfig.GetClientID() clusterID := "faas-cluster" - nc, err := stan.Connect(clusterID, clientID, stan.NatsURL(natsURL)) queue1 := NatsQueue{ - nc: nc, ClientID: clientID, ClusterID: clusterID, NATSURL: natsURL, Topic: "faas-request", + ncMutex: &sync.RWMutex{}, } + err = queue1.connect() + return &queue1, err } @@ -50,7 +96,11 @@ func (q *NatsQueue) Queue(req *queue.Request) error { log.Println(err) } - err = q.nc.Publish(q.Topic, out) + q.ncMutex.RLock() + nc := q.nc + q.ncMutex.RUnlock() + + err = nc.Publish(q.Topic, out) return err } diff --git a/handler/nats_config.go b/handler/nats_config.go index ca23329..37fab14 100644 --- a/handler/nats_config.go +++ b/handler/nats_config.go @@ -2,15 +2,27 @@ package handler import ( "os" + "time" "github.com/openfaas/nats-queue-worker/nats" ) type NatsConfig interface { GetClientID() string + GetMaxReconnect() int + GetDelayBetweenReconnect() time.Duration } type DefaultNatsConfig struct { + maxReconnect int + delayBetweenReconnect time.Duration +} + +func NewDefaultNatsConfig() DefaultNatsConfig { + return DefaultNatsConfig{ + maxReconnect: 5, + delayBetweenReconnect: time.Second * 2, + } } // GetClientID returns the ClientID assigned to this producer/consumer. @@ -19,6 +31,14 @@ func (DefaultNatsConfig) GetClientID() string { return getClientID(val) } +func (c *DefaultNatsConfig) GetMaxReconnect() int { + return c.maxReconnect +} + +func (c *DefaultNatsConfig) GetDelayBetweenReconnect() time.Duration { + return c.delayBetweenReconnect +} + func getClientID(hostname string) string { return "faas-publisher-" + nats.GetClientID(hostname) }