diff --git a/handlers/api.go b/handlers/api.go index c1d5756..187095d 100644 --- a/handlers/api.go +++ b/handlers/api.go @@ -2,6 +2,7 @@ package handlers import ( + "context" "encoding/json" "fmt" "net/http" @@ -41,7 +42,7 @@ func GetShortHandler(store storage.Storage, index Index) http.Handler { } index.Short = short - url, err := store.Load(short) + url, err := store.Load(context.TODO(), short) switch err := errors.Cause(err); err { case nil: http.Redirect(w, r, url, http.StatusFound) @@ -81,14 +82,14 @@ func SetShortHandler(store storage.Storage) http.Handler { return } - short, err = unnamed.Save(url) + short, err = unnamed.Save(context.TODO(), url) } else { if !namedOk { http.Error(w, "Current storage layer does not support storing a named url", http.StatusBadRequest) return } - err = named.SaveName(short, url) + err = named.SaveName(context.TODO(), short, url) } if err != nil { http.Error(w, fmt.Sprintf("Failed to save '%s' to '%s' because: %s", url, short, err), http.StatusInternalServerError) diff --git a/storage/filesystem.go b/storage/filesystem.go index bb38fe6..47c4196 100644 --- a/storage/filesystem.go +++ b/storage/filesystem.go @@ -1,6 +1,7 @@ package storage import ( + "context" "io/ioutil" "os" "path/filepath" @@ -30,7 +31,7 @@ func (s *Filesystem) Code(url string) string { return strconv.FormatUint(s.c, 36) } -func (s *Filesystem) Save(url string) (string, error) { +func (s *Filesystem) Save(ctx context.Context, url string) (string, error) { if _, err := validateURL(url); err != nil { return "", err } @@ -61,7 +62,7 @@ func FlattenPath(path string, separator string) string { return strings.Replace(path, string(os.PathSeparator), separator, -1) } -func (s *Filesystem) SaveName(rawShort, url string) error { +func (s *Filesystem) SaveName(ctx context.Context, rawShort, url string) error { short, err := sanitizeShort(rawShort) if err != nil { return err @@ -82,7 +83,7 @@ func (s *Filesystem) SaveName(rawShort, url string) error { return err } -func (s *Filesystem) Load(rawShort string) (string, error) { +func (s *Filesystem) Load(ctx context.Context, rawShort string) (string, error) { short, err := sanitizeShort(rawShort) if err != nil { return "", err diff --git a/storage/inmem.go b/storage/inmem.go index b1623af..770d511 100644 --- a/storage/inmem.go +++ b/storage/inmem.go @@ -1,6 +1,7 @@ package storage import ( + "context" "encoding/json" "fmt" "sync" @@ -46,7 +47,7 @@ func NewInmemFromMap(randLength int, initialShorts map[string]string) (*Inmem, e s, _ := NewInmem(randLength) for k, v := range initialShorts { - if err := s.SaveName(k, v); err != nil { + if err := s.SaveName(context.Background(), k, v); err != nil { return nil, errors.Wrap(err, "failed to save initial short") } } @@ -54,7 +55,7 @@ func NewInmemFromMap(randLength int, initialShorts map[string]string) (*Inmem, e return s, nil } -func (s *Inmem) Save(url string) (string, error) { +func (s *Inmem) Save(ctx context.Context, url string) (string, error) { if _, err := validateURL(url); err != nil { return "", err } @@ -76,7 +77,7 @@ func (s *Inmem) Save(url string) (string, error) { return "", ErrShortExhaustion } -func (s *Inmem) SaveName(rawShort string, url string) error { +func (s *Inmem) SaveName(ctx context.Context, rawShort string, url string) error { short, err := sanitizeShort(rawShort) if err != nil { return err @@ -91,7 +92,7 @@ func (s *Inmem) SaveName(rawShort string, url string) error { return nil } -func (s *Inmem) Load(rawShort string) (string, error) { +func (s *Inmem) Load(ctx context.Context, rawShort string) (string, error) { short, err := sanitizeShort(rawShort) if err != nil { return "", err diff --git a/storage/migrations/S3v3.go b/storage/migrations/S3v3.go index d82cd73..378416a 100644 --- a/storage/migrations/S3v3.go +++ b/storage/migrations/S3v3.go @@ -1,6 +1,10 @@ package migrations -import "github.com/thomaso-mirodin/go-shorten/storage" +import ( + "context" + + "github.com/thomaso-mirodin/go-shorten/storage" +) func init() { storage.SupportedStorageTypes["S3v3"] = new(interface{}) @@ -12,12 +16,12 @@ type S3v2MigrationStore struct { *storage.S3 } -func (s *S3v2MigrationStore) Load(short string) (long string, err error) { - long, err = s.S3.Load(short) +func (s *S3v2MigrationStore) Load(ctx context.Context, short string) (long string, err error) { + long, err = s.S3.Load(ctx, short) if err != nil { return } - err = s.S3.SaveName(short, long) + err = s.S3.SaveName(ctx, short, long) return } diff --git a/storage/multistorage/loader.go b/storage/multistorage/loader.go index 52ff8f2..b2ba732 100644 --- a/storage/multistorage/loader.go +++ b/storage/multistorage/loader.go @@ -1,20 +1,22 @@ package multistorage import ( + "context" + "github.com/pkg/errors" "github.com/thomaso-mirodin/go-shorten/storage" ) // Loaders are expected to process the slice of stores and return the result of Load(short) from one of them. Should return ErrEmpty if stores is empty -type Loader func(short string, stores []storage.NamedStorage) (string, error) +type Loader func(ctx context.Context, short string, stores []storage.NamedStorage) (string, error) -func loadFirstFunc(short string, stores []storage.NamedStorage) (string, error) { +func loadFirstFunc(ctx context.Context, short string, stores []storage.NamedStorage) (string, error) { if len(stores) == 0 { return "", ErrEmpty } for _, store := range stores { - long, err := store.Load(short) + long, err := store.Load(ctx, short) if err == storage.ErrShortNotSet { continue } @@ -26,14 +28,14 @@ func loadFirstFunc(short string, stores []storage.NamedStorage) (string, error) var ErrUnexpectedMultipleAnswers = errors.New("MultiStorage: results returned were not the same") -func loadCompareAllResultsFunc(short string, stores []storage.NamedStorage) (string, error) { +func loadCompareAllResultsFunc(ctx context.Context, short string, stores []storage.NamedStorage) (string, error) { if len(stores) == 0 { return "", ErrEmpty } results := make([]loadResult, 0, len(stores)) for _, store := range stores { - s, err := store.Load(short) + s, err := store.Load(ctx, short) results = append(results, loadResult{s, err}) } diff --git a/storage/multistorage/loader_internal_test.go b/storage/multistorage/loader_internal_test.go index fe445d6..71b517e 100644 --- a/storage/multistorage/loader_internal_test.go +++ b/storage/multistorage/loader_internal_test.go @@ -1,6 +1,7 @@ package multistorage import ( + "context" "testing" "github.com/pkg/errors" @@ -78,7 +79,7 @@ func TestLoadFirstFunc(t *testing.T) { t.Parallel() t.Logf("querying for %q, expecting (%q,%#v)", tt.inputShort, tt.expectedLong, tt.expectedErr) - long, err := loadFirstFunc(tt.inputShort, tt.stores) + long, err := loadFirstFunc(context.Background(), tt.inputShort, tt.stores) t.Logf("got: (%q, %#v)", long, err) if cause := errors.Cause(err); cause != tt.expectedErr { t.Errorf("unexpected error: expected(%#v) != actual(%#v)", tt.expectedErr, cause) @@ -165,7 +166,7 @@ func TestLoadCompareAllResultsFunc(t *testing.T) { t.Parallel() t.Logf("querying for %q, expecting (%q,%#v)", tt.inputShort, tt.expectedLong, tt.expectedErr) - long, err := loadCompareAllResultsFunc(tt.inputShort, tt.stores) + long, err := loadCompareAllResultsFunc(context.Background(), tt.inputShort, tt.stores) t.Logf("got: (%q, %#v)", long, err) if cause := errors.Cause(err); cause != tt.expectedErr { t.Errorf("unexpected error: expected(%#v) != actual(%#v)", tt.expectedErr, cause) diff --git a/storage/multistorage/multistorage.go b/storage/multistorage/multistorage.go index 60dafe1..d4a9b58 100644 --- a/storage/multistorage/multistorage.go +++ b/storage/multistorage/multistorage.go @@ -1,6 +1,8 @@ package multistorage import ( + "context" + "github.com/pkg/errors" "github.com/thomaso-mirodin/go-shorten/storage" ) @@ -51,19 +53,19 @@ func (s *MultiStorage) validateStore() error { } // Load with a basic MultiStorage will query the underlying storages (in order) returning when either a response or error is encountered, only returning an ErrShortNotSet when all underlying storages have been exhausted. -func (s *MultiStorage) Load(short string) (string, error) { +func (s *MultiStorage) Load(ctx context.Context, short string) (string, error) { if err := s.validateStore(); err != nil { return "", errors.Wrap(err, "failed to validate underlying store") } - return s.loader(short, s.stores) + return s.loader(ctx, short, s.stores) } // SaveName will return the first successful insure that all -func (s *MultiStorage) SaveName(short string, long string) error { +func (s *MultiStorage) SaveName(ctx context.Context, short string, long string) error { if err := s.validateStore(); err != nil { return errors.Wrap(err, "failed to validate underlying store") } - return s.saver(short, long, s.stores) + return s.saver(ctx, short, long, s.stores) } diff --git a/storage/multistorage/multistorage_test.go b/storage/multistorage/multistorage_test.go index d453b40..c402ffb 100644 --- a/storage/multistorage/multistorage_test.go +++ b/storage/multistorage/multistorage_test.go @@ -1,6 +1,7 @@ package multistorage_test import ( + "context" "testing" "github.com/pkg/errors" @@ -41,13 +42,13 @@ func TestSingleBackend(t *testing.T) { } t.Logf("Saving %q->%q", inputShort, inputLong) - if err := m.SaveName(inputShort, inputLong); err != nil { + if err := m.SaveName(context.Background(), inputShort, inputLong); err != nil { t.Fatalf("error saving %q->%q into the store", inputShort, inputLong) } t.Logf("Got: %v", err) t.Logf("Loading %q", inputShort) - long, err := m.Load(inputShort) + long, err := m.Load(context.Background(), inputShort) t.Logf("Got: %q, %v", long, err) if err != nil { t.Fatalf("error loading value that should exist: %q", err) @@ -77,7 +78,7 @@ func TestMultipleBackendLoad(t *testing.T) { for _, input := range inputShorts { for inputShort, expectedLong := range input { t.Logf("Loading %q", inputShort) - long, err := m.Load(inputShort) + long, err := m.Load(context.Background(), inputShort) t.Logf("Got: %q, %v", long, err) if err != nil { t.Errorf("loading %q returned an error: %q", inputShort, err) diff --git a/storage/multistorage/saver.go b/storage/multistorage/saver.go index 8f5d8d3..004f611 100644 --- a/storage/multistorage/saver.go +++ b/storage/multistorage/saver.go @@ -1,22 +1,24 @@ package multistorage import ( + "context" + multierror "github.com/hashicorp/go-multierror" "github.com/pkg/errors" "github.com/thomaso-mirodin/go-shorten/storage" ) // Saver are expected to process a slice of storages and return a result of SaveName(short, url string) -type Saver func(short string, url string, stores []storage.NamedStorage) error +type Saver func(ctx context.Context, short string, url string, stores []storage.NamedStorage) error -func saveAllFunc(short string, url string, stores []storage.NamedStorage) error { +func saveAllFunc(ctx context.Context, short string, url string, stores []storage.NamedStorage) error { if len(stores) == 0 { return ErrEmpty } errs := new(multierror.Error) for _, store := range stores { - err := store.SaveName(short, url) + err := store.SaveName(ctx, short, url) if err != nil { multierror.Append( @@ -30,14 +32,14 @@ func saveAllFunc(short string, url string, stores []storage.NamedStorage) error } // saveOnlyOnceFunc -func saveOnlyOnceFunc(short string, url string, stores []storage.NamedStorage) error { +func saveOnlyOnceFunc(ctx context.Context, short string, url string, stores []storage.NamedStorage) error { if len(stores) == 0 { return ErrEmpty } errs := new(multierror.Error) for _, store := range stores { - err := store.SaveName(short, url) + err := store.SaveName(ctx, short, url) if err == nil { return nil diff --git a/storage/multistorage/saver_internal_test.go b/storage/multistorage/saver_internal_test.go index f1572e9..858bcd7 100644 --- a/storage/multistorage/saver_internal_test.go +++ b/storage/multistorage/saver_internal_test.go @@ -1,6 +1,7 @@ package multistorage import ( + "context" "testing" "github.com/pkg/errors" @@ -21,7 +22,7 @@ func (stv saveTestData) testFunc(t *testing.T, testedFunc Saver) { stores := stv.inputStores t.Logf("saving (%q, %q, %q), expecting (%#v)", stv.inputShort, stv.inputURL, stv.inputStores, stv.expectedErr) - err := testedFunc(stv.inputShort, stv.inputURL, stores) + err := testedFunc(context.Background(), stv.inputShort, stv.inputURL, stores) t.Logf("got: (%#v)", err) if cause := errors.Cause(err); cause != stv.expectedErr { diff --git a/storage/regex.go b/storage/regex.go index ec43285..45ee793 100644 --- a/storage/regex.go +++ b/storage/regex.go @@ -1,6 +1,7 @@ package storage import ( + "context" "fmt" "regexp" ) @@ -37,7 +38,7 @@ func NewRegexFromList(redirects map[string]string) (*Regex, error) { }, nil } -func (r Regex) Load(short string) (string, error) { +func (r Regex) Load(ctx context.Context, short string) (string, error) { // Regex intentionally doesn't do sanitization, each regex can have whatever flexability it wants for _, remap := range r.remaps { @@ -49,7 +50,7 @@ func (r Regex) Load(short string) (string, error) { return "", ErrShortNotSet } -func (r Regex) SaveName(short string, long string) (string, error) { +func (r Regex) SaveName(ctx context.Context, short string, long string) (string, error) { // Regex intentionally doesn't do sanitization, each regex can have whatever flexability it wants return "", fmt.Errorf("regex doesn't yet support saving after creation") diff --git a/storage/regex_test.go b/storage/regex_test.go index 23ef2e4..67378c9 100644 --- a/storage/regex_test.go +++ b/storage/regex_test.go @@ -1,6 +1,7 @@ package storage import ( + "context" "testing" "github.com/pkg/errors" @@ -29,7 +30,7 @@ func TestRegexLoad(t *testing.T) { for _, tt := range testTable { t.Logf("Table: %#v", tt) - actual, err := r.Load(tt.in) + actual, err := r.Load(context.Background(), tt.in) if err != tt.err { t.Errorf("actual err (%q) != expected err (%q)", err, tt.err) } diff --git a/storage/s3.go b/storage/s3.go index 17cf8a2..64b70cf 100644 --- a/storage/s3.go +++ b/storage/s3.go @@ -2,6 +2,7 @@ package storage import ( "bytes" + "context" "crypto/sha256" "encoding/hex" "encoding/json" @@ -118,7 +119,7 @@ func (s *S3) saveKey(short, url string) (err error) { return nil } -func (s *S3) Save(url string) (string, error) { +func (s *S3) Save(ctx context.Context, url string) (string, error) { if _, err := validateURL(url); err != nil { return "", err } @@ -139,7 +140,7 @@ func (s *S3) Save(url string) (string, error) { return "", ErrShortExhaustion } -func (s *S3) SaveName(rawShort string, url string) error { +func (s *S3) SaveName(ctx context.Context, rawShort string, url string) error { short, err := sanitizeShort(rawShort) if err != nil { return err @@ -151,7 +152,7 @@ func (s *S3) SaveName(rawShort string, url string) error { return s.saveKey(short, url) } -func (s *S3) Load(rawShort string) (string, error) { +func (s *S3) Load(ctx context.Context, rawShort string) (string, error) { short, err := sanitizeShort(rawShort) if err != nil { return "", err diff --git a/storage/s3_test.go b/storage/s3_test.go index 3c08e76..5d3e827 100644 --- a/storage/s3_test.go +++ b/storage/s3_test.go @@ -1,6 +1,7 @@ package storage_test import ( + "context" "testing" "github.com/aws/aws-sdk-go/aws" @@ -73,7 +74,7 @@ func BenchmarkS3Save(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - named.SaveName("short", "long") + named.SaveName(context.Background(), "short", "long") } } @@ -82,10 +83,10 @@ func BenchmarkS3Load(b *testing.B) { named, ok := s.(storage.NamedStorage) require.True(b, ok) - named.SaveName("short", "long") + named.SaveName(context.Background(), "short", "long") b.ResetTimer() for i := 0; i < b.N; i++ { - named.Load("short") + named.Load(context.Background(), "short") } } diff --git a/storage/storage.go b/storage/storage.go index 4e31aa5..60460b3 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -6,26 +6,27 @@ package storage import ( + "context" "errors" "net/url" "strings" ) type Storage interface { - // Load(string) takes a short URL and returns the original full URL by retrieving it from storage - Load(short string) (string, error) + // Load(ctx, string) takes a short URL and returns the original full URL by retrieving it from storage + Load(ctx context.Context, short string) (string, error) } type UnnamedStorage interface { Storage - // Save(string) takes a full URL and returns the short URL after saving it to storage - Save(url string) (string, error) + // Save(ctx, string) takes a full URL and returns the short URL after saving it to storage + Save(ctx context.Context, url string) (string, error) } type NamedStorage interface { Storage // SaveName takes a short and a url and saves the name to use for saving a url - SaveName(short string, url string) error + SaveName(ctx context.Context, short string, url string) error } var SupportedStorageTypes = make(map[string]interface{}) diff --git a/storage/storage_test.go b/storage/storage_test.go index f88fa86..801f053 100644 --- a/storage/storage_test.go +++ b/storage/storage_test.go @@ -1,6 +1,7 @@ package storage_test import ( + "context" "fmt" "log" "math/rand" @@ -32,10 +33,10 @@ func saveSomething(s storage.Storage) (short string, long string, err error) { long = "http://" + randString(20) + ".com" if namedOk { - err := named.SaveName(short, long) + err := named.SaveName(context.Background(), short, long) return short, long, err } else if unnamedOk { - short, err := unnamed.Save(long) + short, err := unnamed.Save(context.Background(), long) return short, long, err } else { return "", "", fmt.Errorf("Storage isn't named or unnamed, can't save anything") @@ -85,7 +86,7 @@ func TestUnnamedStorageSave(t *testing.T) { unnamedStorage, ok := setupStorage(t).(storage.UnnamedStorage) if assert.True(t, ok, name) { - code, err := unnamedStorage.Save(testURL) + code, err := unnamedStorage.Save(context.Background(), testURL) t.Logf("[%s] unnamedStorage.Save(\"%s\") -> %#v", name, testURL, code) assert.Nil(t, err, name) } @@ -104,7 +105,7 @@ func TestNamedStorageSave(t *testing.T) { namedStorage, ok := setupStorage(t).(storage.NamedStorage) if assert.True(t, ok, name) { - err := namedStorage.SaveName(testCode, testURL) + err := namedStorage.SaveName(context.Background(), testCode, testURL) t.Logf("[%s] namedStorage.SaveName(\"%s\", \"%s\") -> %#v", name, testCode, testURL, err) assert.Nil(t, err, name) } @@ -124,13 +125,13 @@ func TestNamedStorageNormalization(t *testing.T) { namedStorage, ok := setupStorage(t).(storage.NamedStorage) if assert.True(t, ok, name) { - err := namedStorage.SaveName(testCode, testURL) + err := namedStorage.SaveName(context.Background(), testCode, testURL) t.Logf("[%s] namedStorage.SaveName(\"%s\", \"%s\") -> %#v", name, testCode, testURL, err) assert.Nil(t, err, name) - a, err := namedStorage.Load(testCode) + a, err := namedStorage.Load(context.Background(), testCode) assert.Nil(t, err, name) - b, err := namedStorage.Load(testNormalizedCode) + b, err := namedStorage.Load(context.Background(), testNormalizedCode) assert.Nil(t, err, name) assert.Equal(t, a, b) @@ -146,7 +147,7 @@ func TestMissingLoad(t *testing.T) { setupStorage := setupStorage t.Run(name, func(t *testing.T) { - long, err := setupStorage(t).Load(testCode) + long, err := setupStorage(t).Load(context.Background(), testCode) t.Logf("[%s] storage.Load(\"%s\") -> %#v, %#v", name, testCode, long, err) assert.NotNil(t, err, name) assert.Equal(t, err, storage.ErrShortNotSet, name) @@ -165,7 +166,7 @@ func TestLoad(t *testing.T) { t.Logf("[%s] saveSomething(s) -> %#v, %#v, %#v", name, short, long, err) assert.Nil(t, err, name) - newLong, err := s.Load(short) + newLong, err := s.Load(context.Background(), short) t.Logf("[%s] storage.Load(\"%s\") -> %#v, %#v", name, short, long, err) assert.Nil(t, err, name) @@ -202,12 +203,12 @@ func TestNamedStorageNames(t *testing.T) { for short, e := range shortNames { t.Logf("[%s] Saving URL '%s' should result in '%s'", storageName, short, e) - err := namedStorage.SaveName(short, testURL) + err := namedStorage.SaveName(context.Background(), short, testURL) assert.Equal(t, err, e, fmt.Sprintf("[%s] Saving URL '%s' should've resulted in '%s'", storageName, short, e)) if err == nil { t.Logf("[%s] Loading URL '%s' should result in '%s'", storageName, short, e) - url, err := namedStorage.Load(short) + url, err := namedStorage.Load(context.Background(), short) assert.Equal(t, err, e, fmt.Sprintf("[%s] Loading URL '%s' should've resulted in '%s'", storageName, short, e)) assert.Equal(t, url, testURL, "Saved URL shoud've matched")