From 86e21e33ed359662c3f76317f59465e656d331f0 Mon Sep 17 00:00:00 2001 From: Wessie Date: Tue, 2 Apr 2024 19:40:36 +0100 Subject: [PATCH] util: refactor daypass to be a more generic secret key This should allow us to reuse this code to also implement song downloading. --- util/daypass/daypass.go | 78 ------------------------------------ util/daypass/daypass_test.go | 19 --------- util/secret/secret.go | 59 +++++++++++++++++++++++++++ util/secret/secret_test.go | 43 ++++++++++++++++++++ 4 files changed, 102 insertions(+), 97 deletions(-) delete mode 100644 util/daypass/daypass.go delete mode 100644 util/daypass/daypass_test.go create mode 100644 util/secret/secret.go create mode 100644 util/secret/secret_test.go diff --git a/util/daypass/daypass.go b/util/daypass/daypass.go deleted file mode 100644 index a6d5f0bb..00000000 --- a/util/daypass/daypass.go +++ /dev/null @@ -1,78 +0,0 @@ -package daypass - -import ( - "context" - "crypto/rand" - "crypto/sha256" - "encoding/base64" - "sync" - "time" - - "github.com/rs/zerolog" -) - -func New(ctx context.Context) *Daypass { - return &Daypass{ - logger: zerolog.Ctx(ctx), - } -} - -type Daypass struct { - logger *zerolog.Logger - // mu protects update and daypass - mu sync.Mutex - update time.Time - daypass string -} - -type DaypassInfo struct { - // ValidUntil is the time this daypass will expire - ValidUntil time.Time - // Value is the current daypass - Value string -} - -// Info returns info about the daypass -func (di *Daypass) Info() DaypassInfo { - var info DaypassInfo - info.Value = di.get() - di.mu.Lock() - info.ValidUntil = di.update.Add(time.Hour * 24) - di.mu.Unlock() - return info -} - -// Is checks if the daypass given is equal to the current daypass -func (di *Daypass) Is(daypass string) bool { - return di.get() == daypass -} - -// get returns the current daypass and optionally generates a new one -// if it has expired -func (di *Daypass) get() string { - di.mu.Lock() - defer di.mu.Unlock() - - if time.Since(di.update) >= time.Hour*24 { - di.update = time.Now() - di.daypass = di.generate() - } - - return di.daypass -} - -// generate a new random daypass, this is a random sequence of -// bytes, sha256'd and base64 encoded before trimming it down to 16 characters -func (di *Daypass) generate() string { - var b [32]byte - if _, err := rand.Read(b[:]); err != nil { - di.logger.WithLevel(zerolog.PanicLevel).Err(err).Msg("daypass failure") - // keep using the old daypass if we error - return di.daypass - } - - b = sha256.Sum256(b[:]) - new := base64.RawURLEncoding.EncodeToString(b[:])[:16] - di.logger.Info().Str("value", new).Msg("daypass update") - return new -} diff --git a/util/daypass/daypass_test.go b/util/daypass/daypass_test.go deleted file mode 100644 index 1029cfc7..00000000 --- a/util/daypass/daypass_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package daypass_test - -import ( - "context" - "testing" - "time" - - "github.com/R-a-dio/valkyrie/util/daypass" - "github.com/stretchr/testify/assert" -) - -func TestDaypassIs(t *testing.T) { - d := daypass.New(context.Background()) - - // should be equal right away - assert.True(t, d.Is(d.Info().Value)) - // and should be valid for a while - assert.True(t, d.Info().ValidUntil.After(time.Now())) -} diff --git a/util/secret/secret.go b/util/secret/secret.go new file mode 100644 index 00000000..842719ed --- /dev/null +++ b/util/secret/secret.go @@ -0,0 +1,59 @@ +package secret + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "time" +) + +const keySize = 256 + +func NewSecretWithKey(length int, key []byte) Secret { + return secret{length, key} +} + +func NewSecret(length int) (Secret, error) { + key := make([]byte, keySize) + _, err := rand.Read(key[:]) + if err != nil { + return nil, err + } + + return NewSecretWithKey(length, key), nil +} + +const DaypassLength = 16 + +type Secret interface { + Equal(secret string, salt []byte) bool + Get(salt []byte) (secret string) +} + +type secret struct { + maxLen int + key []byte +} + +func (s secret) Get(salt []byte) (secret string) { + sc := append(date(), s.key...) + if salt != nil { + sc = append(sc, salt...) + } + b := sha256.Sum256(sc) + res := base64.RawURLEncoding.EncodeToString(b[:]) + if len(res) > s.maxLen { + res = res[:s.maxLen] + } + return res +} + +func (s secret) Equal(secret string, salt []byte) bool { + return secret == s.Get(salt) +} + +var date = dateNow + +func dateNow() []byte { + return []byte(time.Now().Format(time.DateOnly)) +} diff --git a/util/secret/secret_test.go b/util/secret/secret_test.go new file mode 100644 index 00000000..24ca5d3d --- /dev/null +++ b/util/secret/secret_test.go @@ -0,0 +1,43 @@ +package secret_test + +import ( + "crypto/sha256" + "strconv" + "testing" + + "github.com/R-a-dio/valkyrie/util/daypass" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSecretKeyGeneration(t *testing.T) { + s1, err := daypass.NewSecret(16) + require.NoError(t, err) + require.True(t, s1.Equal(s1.Get(nil), nil), "s1 should equal itself") + + s2, err := daypass.NewSecret(16) + require.NoError(t, err) + require.True(t, s2.Equal(s2.Get(nil), nil), "s2 should equal itself") + + // compare to each other. should never be true + assert.False(t, s1.Equal(s2.Get(nil), nil), "s2 should not equal s1") + assert.False(t, s2.Equal(s1.Get(nil), nil), "s1 should not equal s2") +} + +func TestSecretSaltComparison(t *testing.T) { + for i := 1; i < sha256.Size*2; i++ { + t.Run(strconv.Itoa(i), func(t *testing.T) { + s, err := daypass.NewSecret(i) + require.NoError(t, err) + + salt := []byte("testing world") + differentSalt := []byte("hello world") + + assert.True(t, s.Equal(s.Get(salt), salt), "same salt should equal") + assert.False(t, s.Equal(s.Get(salt), nil), "salt and no salt should not equal") + assert.False(t, s.Equal(s.Get(nil), salt), "no salt and salt should not equal") + assert.False(t, s.Equal(s.Get(salt), differentSalt), "salt and differentSalt should not equal") + assert.False(t, s.Equal(s.Get(differentSalt), salt), "differentSalt and salt should not equal") + }) + } +}