Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

producer: Integrate context.Context with publishing #365

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 31 additions & 12 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"compress/flate"
"context"
"crypto/tls"
"encoding/json"
"errors"
Expand Down Expand Up @@ -119,8 +120,7 @@ func NewConn(addr string, config *Config, delegate ConnDelegate) *Conn {
// The logger parameter is an interface that requires the following
// method to be implemented (such as the the stdlib log.Logger):
//
// Output(calldepth int, s string)
//
// Output(calldepth int, s string)
func (c *Conn) SetLogger(l logger, lvl LogLevel, format string) {
c.logGuard.Lock()
defer c.logGuard.Unlock()
Expand Down Expand Up @@ -171,12 +171,18 @@ func (c *Conn) getLogLevel() LogLevel {
// Connect dials and bootstraps the nsqd connection
// (including IDENTIFY) and returns the IdentifyResponse
func (c *Conn) Connect() (*IdentifyResponse, error) {
ctx := context.Background()
return c.ConnectWithContext(ctx)
}

func (c *Conn) ConnectWithContext(ctx context.Context) (*IdentifyResponse, error) {
dialer := &net.Dialer{
LocalAddr: c.config.LocalAddr,
Timeout: c.config.DialTimeout,
}

conn, err := dialer.Dial("tcp", c.addr)
// the timeout used is smallest of dialer.Timeout (config.DialTimeout) or context timeout
conn, err := dialer.DialContext(ctx, "tcp", c.addr)
if err != nil {
return nil, err
}
Expand All @@ -190,7 +196,7 @@ func (c *Conn) Connect() (*IdentifyResponse, error) {
return nil, fmt.Errorf("[%s] failed to write magic - %s", c.addr, err)
}

resp, err := c.identify()
resp, err := c.identify(ctx)
if err != nil {
return nil, err
}
Expand All @@ -200,7 +206,7 @@ func (c *Conn) Connect() (*IdentifyResponse, error) {
c.log(LogLevelError, "Auth Required")
return nil, errors.New("Auth Required")
}
err := c.auth(c.config.AuthSecret)
err := c.auth(ctx, c.config.AuthSecret)
if err != nil {
c.log(LogLevelError, "Auth Failed %s", err)
return nil, err
Expand Down Expand Up @@ -291,13 +297,26 @@ func (c *Conn) Write(p []byte) (int, error) {
// WriteCommand is a goroutine safe method to write a Command
// to this connection, and flush.
func (c *Conn) WriteCommand(cmd *Command) error {
ctx := context.Background()
return c.WriteCommandWithContext(ctx, cmd)
}

func (c *Conn) WriteCommandWithContext(ctx context.Context, cmd *Command) error {
c.mtx.Lock()

_, err := cmd.WriteTo(c)
if err != nil {
var err error
select {
case <-ctx.Done():
c.mtx.Unlock()
return ctx.Err()
default:
_, err := cmd.WriteTo(c)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we're shadowing err here and it won't be correct on line 323

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! This should be fixed now 👍

if err != nil {
goto exit
}
err = c.Flush()
goto exit
}
err = c.Flush()

exit:
c.mtx.Unlock()
Expand All @@ -320,7 +339,7 @@ func (c *Conn) Flush() error {
return nil
}

func (c *Conn) identify() (*IdentifyResponse, error) {
func (c *Conn) identify(ctx context.Context) (*IdentifyResponse, error) {
ci := make(map[string]interface{})
ci["client_id"] = c.config.ClientID
ci["hostname"] = c.config.Hostname
Expand Down Expand Up @@ -350,7 +369,7 @@ func (c *Conn) identify() (*IdentifyResponse, error) {
return nil, ErrIdentify{err.Error()}
}

err = c.WriteCommand(cmd)
err = c.WriteCommandWithContext(ctx, cmd)
if err != nil {
return nil, ErrIdentify{err.Error()}
}
Expand Down Expand Up @@ -479,13 +498,13 @@ func (c *Conn) upgradeSnappy() error {
return nil
}

func (c *Conn) auth(secret string) error {
func (c *Conn) auth(ctx context.Context, secret string) error {
cmd, err := Auth(secret)
if err != nil {
return err
}

err = c.WriteCommand(cmd)
err = c.WriteCommandWithContext(ctx, cmd)
if err != nil {
return err
}
Expand Down
86 changes: 69 additions & 17 deletions producer.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package nsq

import (
"context"
"fmt"
"log"
"os"
Expand All @@ -15,8 +16,10 @@ type producerConn interface {
SetLoggerLevel(LogLevel)
SetLoggerForLevel(logger, LogLevel, string)
Connect() (*IdentifyResponse, error)
ConnectWithContext(context.Context) (*IdentifyResponse, error)
Close() error
WriteCommand(*Command) error
WriteCommandWithContext(context.Context, *Command) error
}

// Producer is a high-level type to publish to NSQ.
Expand Down Expand Up @@ -53,6 +56,7 @@ type Producer struct {
// to retrieve metadata about the command after the
// response is received.
type ProducerTransaction struct {
ctx context.Context
cmd *Command
doneChan chan *ProducerTransaction
Error error // the error (or nil) of the publish command
Expand Down Expand Up @@ -105,23 +109,27 @@ func NewProducer(addr string, config *Config) (*Producer, error) {
// configured correctly, rather than relying on the lazy "connect on Publish"
// behavior of a Producer.
func (w *Producer) Ping() error {
ctx := context.Background()
return w.PingWithContext(ctx)
}

func (w *Producer) PingWithContext(ctx context.Context) error {
if atomic.LoadInt32(&w.state) != StateConnected {
err := w.connect()
err := w.connect(ctx)
if err != nil {
return err
}
}

return w.conn.WriteCommand(Nop())
return w.conn.WriteCommandWithContext(ctx, Nop())
}

// SetLogger assigns the logger to use as well as a level
//
// The logger parameter is an interface that requires the following
// method to be implemented (such as the the stdlib log.Logger):
//
// Output(calldepth int, s string)
//
// Output(calldepth int, s string)
func (w *Producer) SetLogger(l logger, lvl LogLevel) {
w.logGuard.Lock()
defer w.logGuard.Unlock()
Expand Down Expand Up @@ -192,7 +200,13 @@ func (w *Producer) Stop() {
// and the response error if present
func (w *Producer) PublishAsync(topic string, body []byte, doneChan chan *ProducerTransaction,
args ...interface{}) error {
return w.sendCommandAsync(Publish(topic, body), doneChan, args)
ctx := context.Background()
return w.PublishAsyncWithContext(ctx, topic, body, doneChan, args...)
}

func (w *Producer) PublishAsyncWithContext(ctx context.Context, topic string, body []byte, doneChan chan *ProducerTransaction,
args ...interface{}) error {
return w.sendCommandAsync(ctx, Publish(topic, body), doneChan, args)
}

// MultiPublishAsync publishes a slice of message bodies to the specified topic
Expand All @@ -203,35 +217,56 @@ func (w *Producer) PublishAsync(topic string, body []byte, doneChan chan *Produc
// will receive a `ProducerTransaction` instance with the supplied variadic arguments
// and the response error if present
func (w *Producer) MultiPublishAsync(topic string, body [][]byte, doneChan chan *ProducerTransaction,
args ...interface{}) error {
ctx := context.Background()
return w.MultiPublishAsyncWithContext(ctx, topic, body, doneChan, args...)
}

func (w *Producer) MultiPublishAsyncWithContext(ctx context.Context, topic string, body [][]byte, doneChan chan *ProducerTransaction,
args ...interface{}) error {
cmd, err := MultiPublish(topic, body)
if err != nil {
return err
}
return w.sendCommandAsync(cmd, doneChan, args)
return w.sendCommandAsync(ctx, cmd, doneChan, args)
}

// Publish synchronously publishes a message body to the specified topic, returning
// an error if publish failed
func (w *Producer) Publish(topic string, body []byte) error {
return w.sendCommand(Publish(topic, body))
ctx := context.Background()
return w.PublishWithContext(ctx, topic, body)
}

func (w *Producer) PublishWithContext(ctx context.Context, topic string, body []byte) error {
return w.sendCommand(ctx, Publish(topic, body))
}

// MultiPublish synchronously publishes a slice of message bodies to the specified topic, returning
// an error if publish failed
func (w *Producer) MultiPublish(topic string, body [][]byte) error {
ctx := context.Background()
return w.MultiPublishWithContext(ctx, topic, body)
}

func (w *Producer) MultiPublishWithContext(ctx context.Context, topic string, body [][]byte) error {
cmd, err := MultiPublish(topic, body)
if err != nil {
return err
}
return w.sendCommand(cmd)
return w.sendCommand(ctx, cmd)
}

// DeferredPublish synchronously publishes a message body to the specified topic
// where the message will queue at the channel level until the timeout expires, returning
// an error if publish failed
func (w *Producer) DeferredPublish(topic string, delay time.Duration, body []byte) error {
return w.sendCommand(DeferredPublish(topic, delay, body))
ctx := context.Background()
return w.DeferredPublishWithContext(ctx, topic, delay, body)
}

func (w *Producer) DeferredPublishWithContext(ctx context.Context, topic string, delay time.Duration, body []byte) error {
return w.sendCommand(ctx, DeferredPublish(topic, delay, body))
}

// DeferredPublishAsync publishes a message body to the specified topic
Expand All @@ -244,12 +279,18 @@ func (w *Producer) DeferredPublish(topic string, delay time.Duration, body []byt
// and the response error if present
func (w *Producer) DeferredPublishAsync(topic string, delay time.Duration, body []byte,
doneChan chan *ProducerTransaction, args ...interface{}) error {
return w.sendCommandAsync(DeferredPublish(topic, delay, body), doneChan, args)
ctx := context.Background()
return w.DeferredPublishAsyncWithContext(ctx, topic, delay, body, doneChan, args...)
}

func (w *Producer) DeferredPublishAsyncWithContext(ctx context.Context, topic string, delay time.Duration, body []byte,
doneChan chan *ProducerTransaction, args ...interface{}) error {
return w.sendCommandAsync(ctx, DeferredPublish(topic, delay, body), doneChan, args)
}

func (w *Producer) sendCommand(cmd *Command) error {
func (w *Producer) sendCommand(ctx context.Context, cmd *Command) error {
doneChan := make(chan *ProducerTransaction)
err := w.sendCommandAsync(cmd, doneChan, nil)
err := w.sendCommandAsync(ctx, cmd, doneChan, nil)
if err != nil {
close(doneChan)
return err
Expand All @@ -258,21 +299,22 @@ func (w *Producer) sendCommand(cmd *Command) error {
return t.Error
}

func (w *Producer) sendCommandAsync(cmd *Command, doneChan chan *ProducerTransaction,
func (w *Producer) sendCommandAsync(ctx context.Context, cmd *Command, doneChan chan *ProducerTransaction,
args []interface{}) error {
// keep track of how many outstanding producers we're dealing with
// in order to later ensure that we clean them all up...
atomic.AddInt32(&w.concurrentProducers, 1)
defer atomic.AddInt32(&w.concurrentProducers, -1)

if atomic.LoadInt32(&w.state) != StateConnected {
err := w.connect()
err := w.connect(ctx)
if err != nil {
return err
}
}

t := &ProducerTransaction{
ctx: ctx,
cmd: cmd,
doneChan: doneChan,
Args: args,
Expand All @@ -282,12 +324,14 @@ func (w *Producer) sendCommandAsync(cmd *Command, doneChan chan *ProducerTransac
case w.transactionChan <- t:
case <-w.exitChan:
return ErrStopped
case <-ctx.Done():
return ctx.Err()
}

return nil
}

func (w *Producer) connect() error {
func (w *Producer) connect(ctx context.Context) error {
w.guard.Lock()
defer w.guard.Unlock()

Expand All @@ -312,7 +356,7 @@ func (w *Producer) connect() error {
w.conn.SetLoggerForLevel(w.logger[index], LogLevel(index), format)
}

_, err := w.conn.Connect()
_, err := w.conn.ConnectWithContext(ctx)
if err != nil {
w.conn.Close()
w.log(LogLevelError, "(%s) error connecting to nsqd - %s", w.addr, err)
Expand Down Expand Up @@ -344,9 +388,17 @@ func (w *Producer) router() {
select {
case t := <-w.transactionChan:
w.transactions = append(w.transactions, t)
err := w.conn.WriteCommand(t.cmd)
err := w.conn.WriteCommandWithContext(t.ctx, t.cmd)
if err != nil {
w.log(LogLevelError, "(%s) sending command - %s", w.conn.String(), err)
if err == context.Canceled || err == context.DeadlineExceeded {
// keep the connection alive if related to context timeout
// need to do some stuff that's in Producer.popTransaction here
w.transactions = w.transactions[1:]
t.Error = err
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't love this code we've copied in to this context-specific edge case, can't we consolidate this in to popTransaction? Do we need to define a context-specific error code for the protocol?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I wasn't a huge fan of how that was being handled...

I updated the code to invoke popTransaction without modifying the function signature, meaning I seemingly need to have at least one new "frameType" in protocol.go to handle the context errors. I opted to be explicit and add two "frameTypes" (for context.Cancelled and context.DeadlineExceeded respectively) so that we would not need to infer the context error from the byte slice/handle a default case like so:

if frameType == FrameTypeContextError {
    switch string(data) {
    case context.Canceled.Error():
        t.Error = context.Canceled
    case context.DeadlineExceeded.Error():
        t.Error = context.DeadlineExceeded
    default:
        t.Error = ???
    }
}

The FrameTypeContextCanceled and FrameTypeContextDeadlineExceeded are not really "frameTypes" per se, so I also don't fully love this approach, but it feels better than what it was.

I'm curious to hear your thoughts on this approach and any suggestions you may have on how to improve upon this.

t.finish()
continue
}
w.close()
}
case data := <-w.responseChan:
Expand Down
Loading