Skip to content

Commit

Permalink
More linters, some cleanup and possible fix for pubsub hang (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
maroux authored Jun 24, 2021
1 parent ccd389b commit d96a4fb
Show file tree
Hide file tree
Showing 15 changed files with 180 additions and 171 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/golangci-lint.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Go Checks
name: Golint

on:
push:
Expand Down Expand Up @@ -33,4 +33,4 @@ jobs:
- name: golangci-lint
uses: golangci/golangci-lint-action@v2
with:
version: v1.29
version: v1.40
6 changes: 5 additions & 1 deletion .golangci.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ locale = "US"
check-shadowing = true
disable = ["composites"]

[linters-settings.goimports]

local-prefixes = "github.com/cloudchacho/hedwig-go"

[linters]

enable = ["misspell"]
enable = ["misspell", "gofmt", "goimports", "revive"]
34 changes: 17 additions & 17 deletions aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ import (
"github.com/cloudchacho/hedwig-go"
)

type awsBackend struct {
type backend struct {
settings *hedwig.Settings

sqs sqsiface.SQSAPI
sns snsiface.SNSAPI
}

// AWSMetadata is additional metadata associated with a message
type AWSMetadata struct {
// Metadata is additional metadata associated with a message
type Metadata struct {
// AWS receipt identifier
ReceiptHandle string

Expand All @@ -48,15 +48,15 @@ type AWSMetadata struct {

const sqsWaitTimeoutSeconds int64 = 20

func (a *awsBackend) getSQSQueueName() string {
func (a *backend) getSQSQueueName() string {
return fmt.Sprintf("HEDWIG-%s", a.settings.QueueName)
}

func (a *awsBackend) getSNSTopic(messageTopic string) string {
func (a *backend) getSNSTopic(messageTopic string) string {
return fmt.Sprintf("arn:aws:sns:%s:%s:hedwig-%s", a.settings.AWSRegion, a.settings.AWSAccountID, messageTopic)
}

func (a *awsBackend) getSQSQueueURL(ctx context.Context) (*string, error) {
func (a *backend) getSQSQueueURL(ctx context.Context) (*string, error) {
out, err := a.sqs.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{
QueueName: aws.String(a.getSQSQueueName()),
})
Expand All @@ -68,7 +68,7 @@ func (a *awsBackend) getSQSQueueURL(ctx context.Context) (*string, error) {

// isValidForSQS checks that the payload is allowed in SQS message body since only some UTF8 characters are allowed
// ref: https://docs.amazonaws.cn/en_us/AWSSimpleQueueService/latest/APIReference/API_SendMessage.html
func (a *awsBackend) isValidForSQS(payload []byte) bool {
func (a *backend) isValidForSQS(payload []byte) bool {
if !utf8.Valid(payload) {
return false
}
Expand All @@ -79,7 +79,7 @@ func (a *awsBackend) isValidForSQS(payload []byte) bool {
}

// Publish a message represented by the payload, with specified attributes to the specific topic
func (a *awsBackend) Publish(ctx context.Context, message *hedwig.Message, payload []byte, attributes map[string]string, topic string) (string, error) {
func (a *backend) Publish(ctx context.Context, message *hedwig.Message, payload []byte, attributes map[string]string, topic string) (string, error) {
snsTopic := a.getSNSTopic(topic)
var payloadStr string

Expand Down Expand Up @@ -116,7 +116,7 @@ func (a *awsBackend) Publish(ctx context.Context, message *hedwig.Message, paylo

// Receive messages from configured queue(s) and provide it through the callback. This should run indefinitely
// until the context is canceled. Provider metadata should include all info necessary to ack/nack a message.
func (a *awsBackend) Receive(ctx context.Context, numMessages uint32, visibilityTimeout time.Duration, callback hedwig.ConsumerCallback) error {
func (a *backend) Receive(ctx context.Context, numMessages uint32, visibilityTimeout time.Duration, callback hedwig.ConsumerCallback) error {
queueURL, err := a.getSQSQueueURL(ctx)
if err != nil {
return errors.Wrap(err, "failed to get SQS Queue URL")
Expand Down Expand Up @@ -180,7 +180,7 @@ func (a *awsBackend) Receive(ctx context.Context, numMessages uint32, visibility
if err != nil {
receiveCount = -1
}
metadata := AWSMetadata{
metadata := Metadata{
*queueMessage.ReceiptHandle,
firstReceiveTime,
sentTime,
Expand Down Expand Up @@ -208,14 +208,14 @@ func (a *awsBackend) Receive(ctx context.Context, numMessages uint32, visibility
}

// NackMessage nacks a message on the queue
func (a *awsBackend) NackMessage(ctx context.Context, providerMetadata interface{}) error {
func (a *backend) NackMessage(ctx context.Context, providerMetadata interface{}) error {
// not supported by AWS
return nil
}

// AckMessage acknowledges a message on the queue
func (a *awsBackend) AckMessage(ctx context.Context, providerMetadata interface{}) error {
receipt := providerMetadata.(AWSMetadata).ReceiptHandle
func (a *backend) AckMessage(ctx context.Context, providerMetadata interface{}) error {
receipt := providerMetadata.(Metadata).ReceiptHandle
queueURL, err := a.getSQSQueueURL(ctx)
if err != nil {
return errors.Wrap(err, "failed to get SQS Queue URL")
Expand All @@ -227,13 +227,13 @@ func (a *awsBackend) AckMessage(ctx context.Context, providerMetadata interface{
return err
}

// NewAWSBackend creates a backend for publishing and consuming from AWS
// The provider metadata produced by this backend will have concrete type: aws.AWSMetadata
func NewAWSBackend(settings *hedwig.Settings, sessionCache *AWSSessionsCache) hedwig.IBackend {
// NewBackend creates a backend for publishing and consuming from AWS
// The provider metadata produced by this backend will have concrete type: aws.Metadata
func NewBackend(settings *hedwig.Settings, sessionCache *SessionsCache) hedwig.IBackend {

awsSession := sessionCache.GetSession(settings)

return &awsBackend{
return &backend{
settings,
sqs.New(awsSession),
sns.New(awsSession),
Expand Down
12 changes: 6 additions & 6 deletions aws/aws_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,19 @@ type sessionKey struct {
awsSessionToken string
}

// AWSSessionsCache is a cache that holds sessions
type AWSSessionsCache struct {
// SessionsCache is a cache that holds sessions
type SessionsCache struct {
sessionMap sync.Map
}

// NewAWSSessionsCache creates a new session cache
func NewAWSSessionsCache() *AWSSessionsCache {
return &AWSSessionsCache{
func NewAWSSessionsCache() *SessionsCache {
return &SessionsCache{
sessionMap: sync.Map{},
}
}

func (c *AWSSessionsCache) getOrCreateSession(settings *hedwig.Settings) *session.Session {
func (c *SessionsCache) getOrCreateSession(settings *hedwig.Settings) *session.Session {
key := sessionKey{awsRegion: settings.AWSRegion, awsAccessKeyID: settings.AWSAccessKey, awsSessionToken: settings.AWSSessionToken}
s, ok := c.sessionMap.Load(key)
if !ok {
Expand All @@ -61,6 +61,6 @@ func (c *AWSSessionsCache) getOrCreateSession(settings *hedwig.Settings) *sessio
}

// GetSession retrieves a session if it is cached, otherwise creates one
func (c *AWSSessionsCache) GetSession(settings *hedwig.Settings) *session.Session {
func (c *SessionsCache) GetSession(settings *hedwig.Settings) *session.Session {
return c.getOrCreateSession(settings)
}
55 changes: 29 additions & 26 deletions aws/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,14 @@ func (fs *fakeSQS) SendMessageWithContext(ctx aws.Context, in *sqs.SendMessageIn
return args.Get(0).(*sqs.SendMessageOutput), args.Error(1)
}

// revive:disable:var-naming
func (fs *fakeSQS) GetQueueUrlWithContext(ctx aws.Context, in *sqs.GetQueueUrlInput, opts ...request.Option) (*sqs.GetQueueUrlOutput, error) {
args := fs.Called(ctx, in, opts)
return args.Get(0).(*sqs.GetQueueUrlOutput), args.Error(1)
}

// revive:enable:var-naming

func (fs *fakeSQS) ReceiveMessageWithContext(ctx aws.Context, in *sqs.ReceiveMessageInput, opts ...request.Option) (*sqs.ReceiveMessageOutput, error) {
args := fs.Called(ctx, in, opts)
return args.Get(0).(*sqs.ReceiveMessageOutput), args.Error(1)
Expand Down Expand Up @@ -151,35 +154,35 @@ func (s *BackendTestSuite) TestReceive() {
s.Require().NoError(err)
receiveCount := 1
body := `{"vehicle_id": "C_123"}`
messageId := "123"
messageID := "123"
sqsMessage := sqs.Message{
ReceiptHandle: aws.String(receiptHandle),
MessageAttributes: map[string]*sqs.MessageAttributeValue{
"foo": &sqs.MessageAttributeValue{StringValue: aws.String("bar")},
"foo": {StringValue: aws.String("bar")},
},
Attributes: map[string]*string{
sqs.MessageSystemAttributeNameApproximateFirstReceiveTimestamp: aws.String("1295500510456"),
sqs.MessageSystemAttributeNameSentTimestamp: aws.String("1295500510123"),
sqs.MessageSystemAttributeNameApproximateReceiveCount: aws.String(strconv.Itoa(int(receiveCount))),
},
Body: aws.String(body),
MessageId: aws.String(messageId),
MessageId: aws.String(messageID),
}
body2 := `vbI9vCDijJg=`
messageId2 := "456"
messageID2 := "456"
sqsMessage2 := sqs.Message{
ReceiptHandle: aws.String(receiptHandle),
MessageAttributes: map[string]*sqs.MessageAttributeValue{
"foo": &sqs.MessageAttributeValue{StringValue: aws.String("bar")},
"hedwig_encoding": &sqs.MessageAttributeValue{StringValue: aws.String("base64")},
"foo": {StringValue: aws.String("bar")},
"hedwig_encoding": {StringValue: aws.String("base64")},
},
Attributes: map[string]*string{
sqs.MessageSystemAttributeNameApproximateFirstReceiveTimestamp: aws.String("1295500510456"),
sqs.MessageSystemAttributeNameSentTimestamp: aws.String("1295500510123"),
sqs.MessageSystemAttributeNameApproximateReceiveCount: aws.String(strconv.Itoa(int(receiveCount))),
},
Body: aws.String(body2),
MessageId: aws.String(messageId2),
MessageId: aws.String(messageID2),
}
receiveOutput := &sqs.ReceiveMessageOutput{
Messages: []*sqs.Message{&sqsMessage, &sqsMessage2},
Expand All @@ -194,7 +197,7 @@ func (s *BackendTestSuite) TestReceive() {
attributes := map[string]string{
"foo": "bar",
}
providerMetadata := AWSMetadata{
providerMetadata := Metadata{
ReceiptHandle: receiptHandle,
FirstReceiveTime: firstReceiveTime.UTC(),
SentTime: sentTime.UTC(),
Expand Down Expand Up @@ -256,20 +259,20 @@ func (s *BackendTestSuite) TestReceiveFailedNonUTF8Decoding() {
}
receiveCount := 1
body := `foobar`
messageId := "123"
messageID := "123"
sqsMessage := sqs.Message{
ReceiptHandle: aws.String(receiptHandle),
MessageAttributes: map[string]*sqs.MessageAttributeValue{
"foo": &sqs.MessageAttributeValue{StringValue: aws.String("bar")},
"hedwig_encoding": &sqs.MessageAttributeValue{StringValue: aws.String("base64")},
"foo": {StringValue: aws.String("bar")},
"hedwig_encoding": {StringValue: aws.String("base64")},
},
Attributes: map[string]*string{
sqs.MessageSystemAttributeNameApproximateFirstReceiveTimestamp: aws.String("1295500510456"),
sqs.MessageSystemAttributeNameSentTimestamp: aws.String("1295500510123"),
sqs.MessageSystemAttributeNameApproximateReceiveCount: aws.String(strconv.Itoa(int(receiveCount))),
},
Body: aws.String(body),
MessageId: aws.String(messageId),
MessageId: aws.String(messageID),
}
receiveOutput := &sqs.ReceiveMessageOutput{Messages: []*sqs.Message{&sqsMessage}}
s.fakeSQS.On("ReceiveMessageWithContext", ctx, receiveInput, []request.Option(nil)).
Expand Down Expand Up @@ -452,19 +455,19 @@ func (s *BackendTestSuite) TestReceiveMissingAttributes() {
}
receiptHandle := "123"
body := `{"vehicle_id": "C_123"}`
messageId := "123"
messageID := "123"
sqsMessage := sqs.Message{
ReceiptHandle: aws.String(receiptHandle),
MessageAttributes: map[string]*sqs.MessageAttributeValue{
"foo": &sqs.MessageAttributeValue{StringValue: aws.String("bar")},
"foo": {StringValue: aws.String("bar")},
},
Attributes: map[string]*string{
sqs.MessageSystemAttributeNameApproximateFirstReceiveTimestamp: aws.String(""),
sqs.MessageSystemAttributeNameSentTimestamp: aws.String(""),
sqs.MessageSystemAttributeNameApproximateReceiveCount: aws.String(""),
},
Body: aws.String(body),
MessageId: aws.String(messageId),
MessageId: aws.String(messageID),
}
receiveOutput := &sqs.ReceiveMessageOutput{
Messages: []*sqs.Message{&sqsMessage},
Expand All @@ -479,7 +482,7 @@ func (s *BackendTestSuite) TestReceiveMissingAttributes() {
attributes := map[string]string{
"foo": "bar",
}
providerMetadata := AWSMetadata{
providerMetadata := Metadata{
ReceiptHandle: receiptHandle,
FirstReceiveTime: time.Time{},
SentTime: time.Time{},
Expand Down Expand Up @@ -531,9 +534,9 @@ func (s *BackendTestSuite) TestPublish() {
s.fakeSNS.On("PublishWithContext", ctx, expectedSnsInput, mock.Anything).
Return(output, nil)

messageId, err := s.backend.Publish(ctx, s.message, s.payload, s.attributes, msgTopic)
messageID, err := s.backend.Publish(ctx, s.message, s.payload, s.attributes, msgTopic)
s.NoError(err)
s.Equal(messageId, "123")
s.Equal(messageID, "123")

s.fakeSNS.AssertExpectations(s.T())
}
Expand Down Expand Up @@ -569,9 +572,9 @@ func (s *BackendTestSuite) TestPublishInvalidCharacters() {
s.fakeSNS.On("PublishWithContext", ctx, expectedSnsInput, mock.Anything).
Return(output, nil)

messageId, err := s.backend.Publish(ctx, s.message, invalidPayload, s.attributes, msgTopic)
messageID, err := s.backend.Publish(ctx, s.message, invalidPayload, s.attributes, msgTopic)
s.NoError(err)
s.Equal(messageId, "123")
s.Equal(messageID, "123")

s.fakeSNS.AssertExpectations(s.T())
}
Expand Down Expand Up @@ -628,7 +631,7 @@ func (s *BackendTestSuite) TestAck() {
s.fakeSQS.On("DeleteMessageWithContext", ctx, deleteInput, mock.Anything).
Return(deleteOutput, nil)

err := s.backend.AckMessage(ctx, AWSMetadata{ReceiptHandle: receiptHandle})
err := s.backend.AckMessage(ctx, Metadata{ReceiptHandle: receiptHandle})
s.NoError(err)

s.fakeSQS.AssertExpectations(s.T())
Expand Down Expand Up @@ -657,7 +660,7 @@ func (s *BackendTestSuite) TestAckError() {
s.fakeSQS.On("DeleteMessageWithContext", ctx, deleteInput, mock.Anything).
Return((*sqs.DeleteMessageOutput)(nil), errors.New("failed to ack"))

err := s.backend.AckMessage(ctx, AWSMetadata{ReceiptHandle: receiptHandle})
err := s.backend.AckMessage(ctx, Metadata{ReceiptHandle: receiptHandle})
s.EqualError(err, "failed to ack")

s.fakeSQS.AssertExpectations(s.T())
Expand All @@ -675,7 +678,7 @@ func (s *BackendTestSuite) TestAckGetQueueError() {

receiptHandle := "foobar"

err := s.backend.AckMessage(ctx, AWSMetadata{ReceiptHandle: receiptHandle})
err := s.backend.AckMessage(ctx, Metadata{ReceiptHandle: receiptHandle})
s.EqualError(err, "failed to get SQS Queue URL: no internet")

s.fakeSQS.AssertExpectations(s.T())
Expand All @@ -686,7 +689,7 @@ func (s *BackendTestSuite) TestNack() {

receiptHandle := "foobar"

err := s.backend.NackMessage(ctx, AWSMetadata{ReceiptHandle: receiptHandle})
err := s.backend.NackMessage(ctx, Metadata{ReceiptHandle: receiptHandle})
s.NoError(err)

// no calls expected
Expand All @@ -699,7 +702,7 @@ func (s *BackendTestSuite) TestNew() {

type BackendTestSuite struct {
suite.Suite
backend *awsBackend
backend *backend
settings *hedwig.Settings
fakeSQS *fakeSQS
fakeSNS *fakeSNS
Expand Down Expand Up @@ -739,7 +742,7 @@ func (s *BackendTestSuite) SetupTest() {
payload := []byte(`{"vehicle_id": "C_123"}`)
attributes := map[string]string{"foo": "bar"}

s.backend = NewAWSBackend(settings, NewAWSSessionsCache()).(*awsBackend)
s.backend = NewBackend(settings, NewAWSSessionsCache()).(*backend)
s.backend.sqs = fakeSQS
s.backend.sns = fakeSNS
s.settings = settings
Expand Down
2 changes: 1 addition & 1 deletion consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (c *queueConsumer) processMessage(ctx context.Context, payload []byte, attr
return
}

loggingFields = LoggingFields{"message_id": message.ID}
loggingFields = LoggingFields{"message_id": message.ID, "type": message.Type, "version": message.DataSchemaVersion}

callbackKey := MessageTypeMajorVersion{message.Type, uint(message.DataSchemaVersion.Major())}
var callback CallbackFunction
Expand Down
Loading

0 comments on commit d96a4fb

Please sign in to comment.