Skip to content

Commit

Permalink
Add contexts to storage API
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Desrosiers committed Jun 25, 2017
1 parent d7808c0 commit 2a1dd9c
Show file tree
Hide file tree
Showing 16 changed files with 81 additions and 59 deletions.
7 changes: 4 additions & 3 deletions handlers/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package handlers

import (
"context"
"encoding/json"
"fmt"
"net/http"
Expand Down Expand Up @@ -41,7 +42,7 @@ func GetShortHandler(store storage.Storage, index Index) http.Handler {
}
index.Short = short

url, err := store.Load(short)
url, err := store.Load(context.TODO(), short)
switch err := errors.Cause(err); err {
case nil:
http.Redirect(w, r, url, http.StatusFound)
Expand Down Expand Up @@ -81,14 +82,14 @@ func SetShortHandler(store storage.Storage) http.Handler {
return
}

short, err = unnamed.Save(url)
short, err = unnamed.Save(context.TODO(), url)
} else {
if !namedOk {
http.Error(w, "Current storage layer does not support storing a named url", http.StatusBadRequest)
return
}

err = named.SaveName(short, url)
err = named.SaveName(context.TODO(), short, url)
}
if err != nil {
http.Error(w, fmt.Sprintf("Failed to save '%s' to '%s' because: %s", url, short, err), http.StatusInternalServerError)
Expand Down
7 changes: 4 additions & 3 deletions storage/filesystem.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package storage

import (
"context"
"io/ioutil"
"os"
"path/filepath"
Expand Down Expand Up @@ -30,7 +31,7 @@ func (s *Filesystem) Code(url string) string {
return strconv.FormatUint(s.c, 36)
}

func (s *Filesystem) Save(url string) (string, error) {
func (s *Filesystem) Save(ctx context.Context, url string) (string, error) {
if _, err := validateURL(url); err != nil {
return "", err
}
Expand Down Expand Up @@ -61,7 +62,7 @@ func FlattenPath(path string, separator string) string {
return strings.Replace(path, string(os.PathSeparator), separator, -1)
}

func (s *Filesystem) SaveName(rawShort, url string) error {
func (s *Filesystem) SaveName(ctx context.Context, rawShort, url string) error {
short, err := sanitizeShort(rawShort)
if err != nil {
return err
Expand All @@ -82,7 +83,7 @@ func (s *Filesystem) SaveName(rawShort, url string) error {
return err
}

func (s *Filesystem) Load(rawShort string) (string, error) {
func (s *Filesystem) Load(ctx context.Context, rawShort string) (string, error) {
short, err := sanitizeShort(rawShort)
if err != nil {
return "", err
Expand Down
9 changes: 5 additions & 4 deletions storage/inmem.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package storage

import (
"context"
"encoding/json"
"fmt"
"sync"
Expand Down Expand Up @@ -46,15 +47,15 @@ func NewInmemFromMap(randLength int, initialShorts map[string]string) (*Inmem, e
s, _ := NewInmem(randLength)

for k, v := range initialShorts {
if err := s.SaveName(k, v); err != nil {
if err := s.SaveName(context.Background(), k, v); err != nil {
return nil, errors.Wrap(err, "failed to save initial short")
}
}

return s, nil
}

func (s *Inmem) Save(url string) (string, error) {
func (s *Inmem) Save(ctx context.Context, url string) (string, error) {
if _, err := validateURL(url); err != nil {
return "", err
}
Expand All @@ -76,7 +77,7 @@ func (s *Inmem) Save(url string) (string, error) {
return "", ErrShortExhaustion
}

func (s *Inmem) SaveName(rawShort string, url string) error {
func (s *Inmem) SaveName(ctx context.Context, rawShort string, url string) error {
short, err := sanitizeShort(rawShort)
if err != nil {
return err
Expand All @@ -91,7 +92,7 @@ func (s *Inmem) SaveName(rawShort string, url string) error {
return nil
}

func (s *Inmem) Load(rawShort string) (string, error) {
func (s *Inmem) Load(ctx context.Context, rawShort string) (string, error) {
short, err := sanitizeShort(rawShort)
if err != nil {
return "", err
Expand Down
12 changes: 8 additions & 4 deletions storage/migrations/S3v3.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package migrations

import "github.com/thomaso-mirodin/go-shorten/storage"
import (
"context"

"github.com/thomaso-mirodin/go-shorten/storage"
)

func init() {
storage.SupportedStorageTypes["S3v3"] = new(interface{})
Expand All @@ -12,12 +16,12 @@ type S3v2MigrationStore struct {
*storage.S3
}

func (s *S3v2MigrationStore) Load(short string) (long string, err error) {
long, err = s.S3.Load(short)
func (s *S3v2MigrationStore) Load(ctx context.Context, short string) (long string, err error) {
long, err = s.S3.Load(ctx, short)
if err != nil {
return
}

err = s.S3.SaveName(short, long)
err = s.S3.SaveName(ctx, short, long)
return
}
12 changes: 7 additions & 5 deletions storage/multistorage/loader.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
package multistorage

import (
"context"

"github.com/pkg/errors"
"github.com/thomaso-mirodin/go-shorten/storage"
)

// Loaders are expected to process the slice of stores and return the result of Load(short) from one of them. Should return ErrEmpty if stores is empty
type Loader func(short string, stores []storage.NamedStorage) (string, error)
type Loader func(ctx context.Context, short string, stores []storage.NamedStorage) (string, error)

func loadFirstFunc(short string, stores []storage.NamedStorage) (string, error) {
func loadFirstFunc(ctx context.Context, short string, stores []storage.NamedStorage) (string, error) {
if len(stores) == 0 {
return "", ErrEmpty
}

for _, store := range stores {
long, err := store.Load(short)
long, err := store.Load(ctx, short)
if err == storage.ErrShortNotSet {
continue
}
Expand All @@ -26,14 +28,14 @@ func loadFirstFunc(short string, stores []storage.NamedStorage) (string, error)

var ErrUnexpectedMultipleAnswers = errors.New("MultiStorage: results returned were not the same")

func loadCompareAllResultsFunc(short string, stores []storage.NamedStorage) (string, error) {
func loadCompareAllResultsFunc(ctx context.Context, short string, stores []storage.NamedStorage) (string, error) {
if len(stores) == 0 {
return "", ErrEmpty
}

results := make([]loadResult, 0, len(stores))
for _, store := range stores {
s, err := store.Load(short)
s, err := store.Load(ctx, short)

results = append(results, loadResult{s, err})
}
Expand Down
5 changes: 3 additions & 2 deletions storage/multistorage/loader_internal_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package multistorage

import (
"context"
"testing"

"github.com/pkg/errors"
Expand Down Expand Up @@ -78,7 +79,7 @@ func TestLoadFirstFunc(t *testing.T) {
t.Parallel()

t.Logf("querying for %q, expecting (%q,%#v)", tt.inputShort, tt.expectedLong, tt.expectedErr)
long, err := loadFirstFunc(tt.inputShort, tt.stores)
long, err := loadFirstFunc(context.Background(), tt.inputShort, tt.stores)
t.Logf("got: (%q, %#v)", long, err)
if cause := errors.Cause(err); cause != tt.expectedErr {
t.Errorf("unexpected error: expected(%#v) != actual(%#v)", tt.expectedErr, cause)
Expand Down Expand Up @@ -165,7 +166,7 @@ func TestLoadCompareAllResultsFunc(t *testing.T) {
t.Parallel()

t.Logf("querying for %q, expecting (%q,%#v)", tt.inputShort, tt.expectedLong, tt.expectedErr)
long, err := loadCompareAllResultsFunc(tt.inputShort, tt.stores)
long, err := loadCompareAllResultsFunc(context.Background(), tt.inputShort, tt.stores)
t.Logf("got: (%q, %#v)", long, err)
if cause := errors.Cause(err); cause != tt.expectedErr {
t.Errorf("unexpected error: expected(%#v) != actual(%#v)", tt.expectedErr, cause)
Expand Down
10 changes: 6 additions & 4 deletions storage/multistorage/multistorage.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package multistorage

import (
"context"

"github.com/pkg/errors"
"github.com/thomaso-mirodin/go-shorten/storage"
)
Expand Down Expand Up @@ -51,19 +53,19 @@ func (s *MultiStorage) validateStore() error {
}

// Load with a basic MultiStorage will query the underlying storages (in order) returning when either a response or error is encountered, only returning an ErrShortNotSet when all underlying storages have been exhausted.
func (s *MultiStorage) Load(short string) (string, error) {
func (s *MultiStorage) Load(ctx context.Context, short string) (string, error) {
if err := s.validateStore(); err != nil {
return "", errors.Wrap(err, "failed to validate underlying store")
}

return s.loader(short, s.stores)
return s.loader(ctx, short, s.stores)
}

// SaveName will return the first successful insure that all
func (s *MultiStorage) SaveName(short string, long string) error {
func (s *MultiStorage) SaveName(ctx context.Context, short string, long string) error {
if err := s.validateStore(); err != nil {
return errors.Wrap(err, "failed to validate underlying store")
}

return s.saver(short, long, s.stores)
return s.saver(ctx, short, long, s.stores)
}
7 changes: 4 additions & 3 deletions storage/multistorage/multistorage_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package multistorage_test

import (
"context"
"testing"

"github.com/pkg/errors"
Expand Down Expand Up @@ -41,13 +42,13 @@ func TestSingleBackend(t *testing.T) {
}

t.Logf("Saving %q->%q", inputShort, inputLong)
if err := m.SaveName(inputShort, inputLong); err != nil {
if err := m.SaveName(context.Background(), inputShort, inputLong); err != nil {
t.Fatalf("error saving %q->%q into the store", inputShort, inputLong)
}
t.Logf("Got: %v", err)

t.Logf("Loading %q", inputShort)
long, err := m.Load(inputShort)
long, err := m.Load(context.Background(), inputShort)
t.Logf("Got: %q, %v", long, err)
if err != nil {
t.Fatalf("error loading value that should exist: %q", err)
Expand Down Expand Up @@ -77,7 +78,7 @@ func TestMultipleBackendLoad(t *testing.T) {
for _, input := range inputShorts {
for inputShort, expectedLong := range input {
t.Logf("Loading %q", inputShort)
long, err := m.Load(inputShort)
long, err := m.Load(context.Background(), inputShort)
t.Logf("Got: %q, %v", long, err)
if err != nil {
t.Errorf("loading %q returned an error: %q", inputShort, err)
Expand Down
12 changes: 7 additions & 5 deletions storage/multistorage/saver.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
package multistorage

import (
"context"

multierror "github.com/hashicorp/go-multierror"
"github.com/pkg/errors"
"github.com/thomaso-mirodin/go-shorten/storage"
)

// Saver are expected to process a slice of storages and return a result of SaveName(short, url string)
type Saver func(short string, url string, stores []storage.NamedStorage) error
type Saver func(ctx context.Context, short string, url string, stores []storage.NamedStorage) error

func saveAllFunc(short string, url string, stores []storage.NamedStorage) error {
func saveAllFunc(ctx context.Context, short string, url string, stores []storage.NamedStorage) error {
if len(stores) == 0 {
return ErrEmpty
}

errs := new(multierror.Error)
for _, store := range stores {
err := store.SaveName(short, url)
err := store.SaveName(ctx, short, url)

if err != nil {
multierror.Append(
Expand All @@ -30,14 +32,14 @@ func saveAllFunc(short string, url string, stores []storage.NamedStorage) error
}

// saveOnlyOnceFunc
func saveOnlyOnceFunc(short string, url string, stores []storage.NamedStorage) error {
func saveOnlyOnceFunc(ctx context.Context, short string, url string, stores []storage.NamedStorage) error {
if len(stores) == 0 {
return ErrEmpty
}

errs := new(multierror.Error)
for _, store := range stores {
err := store.SaveName(short, url)
err := store.SaveName(ctx, short, url)

if err == nil {
return nil
Expand Down
3 changes: 2 additions & 1 deletion storage/multistorage/saver_internal_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package multistorage

import (
"context"
"testing"

"github.com/pkg/errors"
Expand All @@ -21,7 +22,7 @@ func (stv saveTestData) testFunc(t *testing.T, testedFunc Saver) {
stores := stv.inputStores

t.Logf("saving (%q, %q, %q), expecting (%#v)", stv.inputShort, stv.inputURL, stv.inputStores, stv.expectedErr)
err := testedFunc(stv.inputShort, stv.inputURL, stores)
err := testedFunc(context.Background(), stv.inputShort, stv.inputURL, stores)
t.Logf("got: (%#v)", err)

if cause := errors.Cause(err); cause != stv.expectedErr {
Expand Down
5 changes: 3 additions & 2 deletions storage/regex.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package storage

import (
"context"
"fmt"
"regexp"
)
Expand Down Expand Up @@ -37,7 +38,7 @@ func NewRegexFromList(redirects map[string]string) (*Regex, error) {
}, nil
}

func (r Regex) Load(short string) (string, error) {
func (r Regex) Load(ctx context.Context, short string) (string, error) {
// Regex intentionally doesn't do sanitization, each regex can have whatever flexability it wants

for _, remap := range r.remaps {
Expand All @@ -49,7 +50,7 @@ func (r Regex) Load(short string) (string, error) {
return "", ErrShortNotSet
}

func (r Regex) SaveName(short string, long string) (string, error) {
func (r Regex) SaveName(ctx context.Context, short string, long string) (string, error) {
// Regex intentionally doesn't do sanitization, each regex can have whatever flexability it wants

return "", fmt.Errorf("regex doesn't yet support saving after creation")
Expand Down
3 changes: 2 additions & 1 deletion storage/regex_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package storage

import (
"context"
"testing"

"github.com/pkg/errors"
Expand Down Expand Up @@ -29,7 +30,7 @@ func TestRegexLoad(t *testing.T) {

for _, tt := range testTable {
t.Logf("Table: %#v", tt)
actual, err := r.Load(tt.in)
actual, err := r.Load(context.Background(), tt.in)
if err != tt.err {
t.Errorf("actual err (%q) != expected err (%q)", err, tt.err)
}
Expand Down
Loading

0 comments on commit 2a1dd9c

Please sign in to comment.