diff --git a/btd/issuer.go b/btd/issuer.go index 51b1d225..a4760838 100644 --- a/btd/issuer.go +++ b/btd/issuer.go @@ -113,32 +113,39 @@ func ApproveTokens(blindedTokens []*crypto.BlindedToken, key *crypto.SigningKey) } // VerifyTokenRedemption checks a redemption request against the observed request data -// and MAC according a set of keys. keys keeps a set of private keys that -// are ever used to sign the token so we can rotate private key easily -// Returns nil on success and an error on failure. +// and MAC according a set of keys. +// Keys keeps a set of private keys that are ever used to sign the token so we can rotate private key easily. func VerifyTokenRedemption(preimage *crypto.TokenPreimage, signature *crypto.VerificationSignature, payload string, keys []*crypto.SigningKey) error { var valid bool var err error - for _, key := range keys { + + for i := range keys { verifyTokenRedemptionCounter.Add(1) - // server derives the unblinded token using its key and the clients token preimage - unblindedToken := key.RederiveUnblindedToken(preimage) - // server derives the shared key from the unblinded token - timer := prometheus.NewTimer(verifyTokenDeriveKeyDuration) + // Derive the unblinded token using a server's key and the client's preimage. + unblindedToken := keys[i].RederiveUnblindedToken(preimage) + + timerUT := prometheus.NewTimer(verifyTokenDeriveKeyDuration) + + // Derive the shared key from the unblinded token. sharedKey := unblindedToken.DeriveVerificationKey() - timer.ObserveDuration() + _ = timerUT.ObserveDuration() + + timerVrf := prometheus.NewTimer(verifyTokenSignatureDuration) - // server signs the same message using the shared key and compares the client signature to its own - timer = prometheus.NewTimer(verifyTokenSignatureDuration) + // Sign the same message using the shared key and compare the client's signature with the server's. valid, err = sharedKey.Verify(signature, payload) if err != nil { + _ = timerVrf.ObserveDuration() + return err } + + _ = timerVrf.ObserveDuration() + if valid { break } - timer.ObserveDuration() } if !valid { diff --git a/kafka/signed_token_redeem_handler.go b/kafka/signed_token_redeem_handler.go index 88c26d0d..3709d4b9 100644 --- a/kafka/signed_token_redeem_handler.go +++ b/kafka/signed_token_redeem_handler.go @@ -8,15 +8,16 @@ import ( "strings" "time" - "github.com/brave-intl/challenge-bypass-server/model" + "github.com/rs/zerolog" + "github.com/segmentio/kafka-go" crypto "github.com/brave-intl/challenge-bypass-ristretto-ffi" + avroSchema "github.com/brave-intl/challenge-bypass-server/avro/generated" "github.com/brave-intl/challenge-bypass-server/btd" + "github.com/brave-intl/challenge-bypass-server/model" cbpServer "github.com/brave-intl/challenge-bypass-server/server" "github.com/brave-intl/challenge-bypass-server/utils" - "github.com/rs/zerolog" - "github.com/segmentio/kafka-go" ) /* @@ -94,10 +95,12 @@ func SignedTokenRedeemHandler( ) } - // Create a lookup for issuers & signing keys based on public key + // Create a lookup for issuers & signing keys based on public key. signedTokens := make(map[string]SignedIssuerToken) + now := time.Now() + for _, issuer := range issuers { - if !issuer.ExpiresAtTime().IsZero() && issuer.ExpiresAtTime().Before(time.Now()) { + if issuer.HasExpired(now) { continue } diff --git a/model/issuer.go b/model/issuer.go index 0f338f73..bf3371ee 100644 --- a/model/issuer.go +++ b/model/issuer.go @@ -1,12 +1,22 @@ package model import ( + "errors" + "time" + "github.com/google/uuid" "github.com/lib/pq" - "time" + + crypto "github.com/brave-intl/challenge-bypass-ristretto-ffi" ) -// Issuer of tokens +var ( + ErrInvalidIssuerType = errors.New("model: invalid issuer type") + ErrInvalidIV3Key = errors.New("model: issuer_v3: invalid key") + ErrIssuerV3NoCryptoKey = errors.New("model: issuer_v3: no crypto signing key for period") +) + +// Issuer represents an issuer of tokens. type Issuer struct { ID *uuid.UUID `json:"id" db:"issuer_id"` IssuerType string `json:"issuer_type" db:"issuer_type"` @@ -26,11 +36,79 @@ type Issuer struct { Keys []IssuerKeys `json:"keys" db:"-"` } -func (iss *Issuer) ExpiresAtTime() time.Time { - var t time.Time - if !iss.ExpiresAt.Valid { - return t +func (x *Issuer) ExpiresAtTime() time.Time { + if !x.ExpiresAt.Valid { + return time.Time{} + } + + return x.ExpiresAt.Time +} + +func (x *Issuer) HasExpired(now time.Time) bool { + expt := x.ExpiresAtTime() + + return !expt.IsZero() && expt.Before(now) +} + +func (x *Issuer) FindSigningKeys(now time.Time) ([]*crypto.SigningKey, error) { + if x.Version != 3 { + return nil, ErrInvalidIssuerType + } + + const leeway = 1 * time.Hour + + keys, err := x.findActiveKeys(now, leeway) + if err != nil { + return nil, err + } + + if len(keys) == 0 { + return nil, nil + } + + return parseSigningKeys(keys), nil +} + +// findActiveKeys finds active keys in x.Keys that are active for time now. +// +// It searches for strictly matching key first, and places it at the first position of the result. +// Then it searches for keys that match with leeway lw. +// The strictly matching key is excluded from search with lw. +func (x *Issuer) findActiveKeys(now time.Time, lw time.Duration) ([]*IssuerKeys, error) { + var result []*IssuerKeys + + for i := range x.Keys { + active, err := x.Keys[i].isActiveV3(now, 0) + if err != nil { + return nil, err + } + + if active { + result = append([]*IssuerKeys{&x.Keys[i]}, result...) + continue + } + + activeLw, err := x.Keys[i].isActiveV3(now, lw) + if err != nil { + return nil, err + } + + if activeLw { + result = append(result, &x.Keys[i]) + } + } + + return result, nil +} + +func parseSigningKeys(keys []*IssuerKeys) []*crypto.SigningKey { + result := make([]*crypto.SigningKey, 0, len(keys)) + + for i := range keys { + if key := keys[i].CryptoSigningKey(); key != nil { + result = append(result, key) + } } - return iss.ExpiresAt.Time + return result } diff --git a/model/issuer_keys.go b/model/issuer_keys.go index b88e7e5d..70a596e1 100644 --- a/model/issuer_keys.go +++ b/model/issuer_keys.go @@ -1,12 +1,14 @@ package model import ( - crypto "github.com/brave-intl/challenge-bypass-ristretto-ffi" - "github.com/google/uuid" "time" + + "github.com/google/uuid" + + crypto "github.com/brave-intl/challenge-bypass-ristretto-ffi" ) -// IssuerKeys - an issuer that uses time based keys +// IssuerKeys represents time-based keys. type IssuerKeys struct { ID *uuid.UUID `json:"id" db:"key_id"` SigningKey []byte `json:"-" db:"signing_key"` @@ -18,12 +20,33 @@ type IssuerKeys struct { EndAt *time.Time `json:"end_at" db:"end_at"` } -func (key *IssuerKeys) CryptoSigningKey() *crypto.SigningKey { - cryptoSigningKey := crypto.SigningKey{} - err := cryptoSigningKey.UnmarshalText(key.SigningKey) - if err != nil { +func (x *IssuerKeys) CryptoSigningKey() *crypto.SigningKey { + result := &crypto.SigningKey{} + if err := result.UnmarshalText(x.SigningKey); err != nil { return nil } - return &cryptoSigningKey + return result +} + +func (x *IssuerKeys) isActiveV3(now time.Time, lw time.Duration) (bool, error) { + if !x.isValidV3() { + return false, ErrInvalidIV3Key + } + + start, end := *x.StartAt, *x.EndAt + if lw == 0 { + return isTimeWithin(start, end, now), nil + } + + // Shift start/end earlier/later by lw, respectively. + return isTimeWithin(start.Add(-1*lw), end.Add(lw), now), nil +} + +func (x *IssuerKeys) isValidV3() bool { + return x.StartAt != nil && x.EndAt != nil +} + +func isTimeWithin(start, end, now time.Time) bool { + return now.After(start) && now.Before(end) } diff --git a/model/issuer_keys_test.go b/model/issuer_keys_test.go new file mode 100644 index 00000000..47a22bd7 --- /dev/null +++ b/model/issuer_keys_test.go @@ -0,0 +1,201 @@ +package model + +import ( + "testing" + "time" + + should "github.com/stretchr/testify/assert" + must "github.com/stretchr/testify/require" +) + +func TestIssuerKeys_isActiveV3(t *testing.T) { + type tcGiven struct { + key *IssuerKeys + now time.Time + lw time.Duration + } + + type tcExpected struct { + val bool + err error + } + + type testCase struct { + name string + given tcGiven + exp tcExpected + } + + tests := []testCase{ + { + name: "invalid_v3", + given: tcGiven{ + key: &IssuerKeys{}, + now: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + lw: 1 * time.Hour, + }, + exp: tcExpected{err: ErrInvalidIV3Key}, + }, + + { + name: "zero_leeway", + given: tcGiven{ + key: &IssuerKeys{ + StartAt: ptrTo(time.Date(2023, time.December, 31, 0, 0, 1, 0, time.UTC)), + EndAt: ptrTo(time.Date(2024, time.January, 2, 0, 0, 1, 0, time.UTC)), + }, + now: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + }, + exp: tcExpected{val: true}, + }, + + { + name: "leeway_1hour", + given: tcGiven{ + key: &IssuerKeys{ + StartAt: ptrTo(time.Date(2023, time.December, 31, 0, 0, 1, 0, time.UTC)), + EndAt: ptrTo(time.Date(2024, time.January, 2, 0, 0, 1, 0, time.UTC)), + }, + now: time.Date(2024, time.January, 2, 0, 0, 1, 0, time.UTC), + lw: 1 * time.Hour, + }, + exp: tcExpected{val: true}, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual, err := tc.given.key.isActiveV3(tc.given.now, tc.given.lw) + must.Equal(t, tc.exp.err, err) + + should.Equal(t, tc.exp.val, actual) + }) + } +} + +func TestIssuerKeys_isValidV3(t *testing.T) { + type testCase struct { + name string + given *IssuerKeys + exp bool + } + + tests := []testCase{ + { + name: "invalid_both", + given: &IssuerKeys{}, + }, + + { + name: "invalid_end", + given: &IssuerKeys{ + StartAt: ptrTo(time.Date(2023, time.December, 31, 0, 0, 1, 0, time.UTC)), + }, + }, + + { + name: "invalid_start", + given: &IssuerKeys{ + EndAt: ptrTo(time.Date(2024, time.January, 2, 0, 0, 1, 0, time.UTC)), + }, + }, + + { + name: "valid", + given: &IssuerKeys{ + StartAt: ptrTo(time.Date(2023, time.December, 31, 0, 0, 1, 0, time.UTC)), + EndAt: ptrTo(time.Date(2024, time.January, 2, 0, 0, 1, 0, time.UTC)), + }, + exp: true, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := tc.given.isValidV3() + should.Equal(t, tc.exp, actual) + }) + } +} + +func TestIsTimeWithin(t *testing.T) { + type tcGiven struct { + start time.Time + end time.Time + now time.Time + } + + type testCase struct { + name string + given tcGiven + exp bool + } + + tests := []testCase{ + { + name: "zero_all", + given: tcGiven{}, + }, + + { + name: "zero_start_end", + given: tcGiven{ + now: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + }, + }, + + { + name: "zero_start_now", + given: tcGiven{ + end: time.Date(2024, time.January, 2, 0, 0, 1, 0, time.UTC), + }, + }, + + { + name: "zero_now_end", + given: tcGiven{ + start: time.Date(2023, time.December, 31, 0, 0, 1, 0, time.UTC), + }, + }, + + { + name: "zero_now", + given: tcGiven{ + start: time.Date(2023, time.December, 31, 0, 0, 1, 0, time.UTC), + end: time.Date(2024, time.January, 2, 0, 0, 1, 0, time.UTC), + }, + }, + + { + name: "invalid_inverse", + given: tcGiven{ + start: time.Date(2024, time.January, 2, 0, 0, 1, 0, time.UTC), + end: time.Date(2023, time.December, 31, 0, 0, 1, 0, time.UTC), + now: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + }, + }, + + { + name: "valid", + given: tcGiven{ + start: time.Date(2023, time.December, 31, 0, 0, 1, 0, time.UTC), + end: time.Date(2024, time.January, 2, 0, 0, 1, 0, time.UTC), + now: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + }, + exp: true, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := isTimeWithin(tc.given.start, tc.given.end, tc.given.now) + should.Equal(t, tc.exp, actual) + }) + } +} diff --git a/model/issuer_test.go b/model/issuer_test.go new file mode 100644 index 00000000..f60bfa59 --- /dev/null +++ b/model/issuer_test.go @@ -0,0 +1,409 @@ +package model + +import ( + "testing" + "time" + + "github.com/google/uuid" + "github.com/lib/pq" + should "github.com/stretchr/testify/assert" + must "github.com/stretchr/testify/require" + + crypto "github.com/brave-intl/challenge-bypass-ristretto-ffi" +) + +func TestIssuer_HasExpired(t *testing.T) { + type tcGiven struct { + issuer *Issuer + now time.Time + } + + type testCase struct { + name string + given tcGiven + exp bool + } + + tests := []testCase{ + { + name: "expires_at_zero", + given: tcGiven{ + issuer: &Issuer{ + ID: ptrTo(uuid.MustParse("f100ded0-0000-4000-a000-000000000000")), + }, + now: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + }, + }, + + { + name: "expires_at_same", + given: tcGiven{ + issuer: &Issuer{ + ID: ptrTo(uuid.MustParse("f100ded0-0000-4000-a000-000000000000")), + ExpiresAt: pq.NullTime{ + Time: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + Valid: true, + }, + }, + now: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + }, + }, + + { + name: "expires_at_after", + given: tcGiven{ + issuer: &Issuer{ + ID: ptrTo(uuid.MustParse("f100ded0-0000-4000-a000-000000000000")), + ExpiresAt: pq.NullTime{ + Time: time.Date(2024, time.January, 2, 0, 0, 1, 0, time.UTC), + Valid: true, + }, + }, + now: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + }, + }, + + { + name: "expires_at_before", + given: tcGiven{ + issuer: &Issuer{ + ID: ptrTo(uuid.MustParse("f100ded0-0000-4000-a000-000000000000")), + ExpiresAt: pq.NullTime{ + Time: time.Date(2023, time.December, 31, 23, 59, 59, 0, time.UTC), + Valid: true, + }, + }, + now: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + }, + exp: true, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := tc.given.issuer.HasExpired(tc.given.now) + should.Equal(t, tc.exp, actual) + }) + } +} + +func TestFindSigningKeys(t *testing.T) { + type tcGiven struct { + issuer *Issuer + now time.Time + } + + type tcExpected struct { + num int + err error + } + + type testCase struct { + name string + given tcGiven + exp tcExpected + } + + tests := []testCase{ + { + name: "not_v3", + given: tcGiven{ + issuer: &Issuer{Version: 2}, + now: time.Date(2024, time.January, 1, 1, 0, 1, 0, time.UTC), + }, + exp: tcExpected{err: ErrInvalidIssuerType}, + }, + + { + name: "invalid_key_both_times", + given: tcGiven{ + issuer: &Issuer{ + Version: 3, + Keys: []IssuerKeys{{}}, + }, + now: time.Date(2024, time.January, 1, 1, 0, 1, 0, time.UTC), + }, + exp: tcExpected{err: ErrInvalidIV3Key}, + }, + + { + name: "valid_key_inactive", + given: tcGiven{ + issuer: &Issuer{ + Version: 3, + Keys: []IssuerKeys{ + { + StartAt: ptrTo(time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC)), + EndAt: ptrTo(time.Date(2024, time.January, 2, 0, 0, 1, 0, time.UTC)), + }, + }, + }, + now: time.Date(2023, time.December, 31, 0, 0, 1, 0, time.UTC), + }, + }, + + { + name: "valid_key_active", + given: tcGiven{ + issuer: &Issuer{ + Version: 3, + Keys: []IssuerKeys{ + { + SigningKey: mustRandomSigningKey(), + StartAt: ptrTo(time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC)), + EndAt: ptrTo(time.Date(2024, time.January, 2, 0, 0, 1, 0, time.UTC)), + }, + }, + }, + now: time.Date(2024, time.January, 1, 1, 0, 1, 0, time.UTC), + }, + exp: tcExpected{num: 1}, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual, err := tc.given.issuer.FindSigningKeys(tc.given.now) + must.Equal(t, tc.exp.err, err) + + if tc.exp.err != nil { + return + } + + should.Equal(t, tc.exp.num, len(actual)) + }) + } +} + +func TestIssuer_findActiveKeys(t *testing.T) { + type tcGiven struct { + issuer *Issuer + now time.Time + lw time.Duration + } + + type tcExpected struct { + result []*IssuerKeys + err error + } + + type testCase struct { + name string + given tcGiven + exp tcExpected + } + + tests := []testCase{ + { + name: "empty", + given: tcGiven{ + issuer: &Issuer{}, + now: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + }, + }, + + { + name: "invalid_key", + given: tcGiven{ + issuer: &Issuer{ + Version: 3, + Keys: []IssuerKeys{{}}, + }, + now: time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC), + }, + exp: tcExpected{err: ErrInvalidIV3Key}, + }, + + { + name: "valid_key_inactive", + given: tcGiven{ + issuer: &Issuer{ + Version: 3, + Keys: []IssuerKeys{ + { + StartAt: ptrTo(time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC)), + EndAt: ptrTo(time.Date(2024, time.January, 2, 0, 0, 1, 0, time.UTC)), + }, + }, + }, + now: time.Date(2023, time.December, 31, 0, 0, 1, 0, time.UTC), + }, + }, + + { + name: "valid_key_active", + given: tcGiven{ + issuer: &Issuer{ + Version: 3, + Keys: []IssuerKeys{ + { + StartAt: ptrTo(time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC)), + EndAt: ptrTo(time.Date(2024, time.January, 2, 0, 0, 1, 0, time.UTC)), + }, + }, + }, + now: time.Date(2024, time.January, 1, 1, 0, 1, 0, time.UTC), + }, + exp: tcExpected{ + result: []*IssuerKeys{ + { + StartAt: ptrTo(time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC)), + EndAt: ptrTo(time.Date(2024, time.January, 2, 0, 0, 1, 0, time.UTC)), + }, + }, + }, + }, + + { + name: "valid_key_inactive_leeway", + given: tcGiven{ + issuer: &Issuer{ + Version: 3, + Keys: []IssuerKeys{ + { + StartAt: ptrTo(time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC)), + EndAt: ptrTo(time.Date(2024, time.January, 2, 0, 0, 1, 0, time.UTC)), + }, + }, + }, + now: time.Date(2023, time.December, 31, 23, 30, 1, 0, time.UTC), + lw: 1 * time.Hour, + }, + exp: tcExpected{ + result: []*IssuerKeys{ + { + StartAt: ptrTo(time.Date(2024, time.January, 1, 0, 0, 1, 0, time.UTC)), + EndAt: ptrTo(time.Date(2024, time.January, 2, 0, 0, 1, 0, time.UTC)), + }, + }, + }, + }, + + { + name: "evq_strict_b_leeway_a", + given: tcGiven{ + issuer: &Issuer{ + Version: 3, + Keys: []IssuerKeys{ + { + SigningKey: []byte(`key_a`), + StartAt: ptrTo(time.Date(2024, time.May, 23, 0, 0, 0, 0, time.UTC)), + EndAt: ptrTo(time.Date(2024, time.May, 24, 0, 0, 0, 0, time.UTC)), + }, + + { + SigningKey: []byte(`key_b`), + StartAt: ptrTo(time.Date(2024, time.May, 24, 0, 0, 0, 0, time.UTC)), + EndAt: ptrTo(time.Date(2024, time.May, 25, 0, 0, 0, 0, time.UTC)), + }, + + { + SigningKey: []byte(`key_c`), + StartAt: ptrTo(time.Date(2024, time.May, 25, 0, 0, 0, 0, time.UTC)), + EndAt: ptrTo(time.Date(2024, time.May, 26, 0, 0, 0, 0, time.UTC)), + }, + }, + }, + now: time.Date(2024, time.May, 24, 0, 52, 25, 0, time.UTC), + lw: 1 * time.Hour, + }, + exp: tcExpected{ + result: []*IssuerKeys{ + { + SigningKey: []byte(`key_b`), + StartAt: ptrTo(time.Date(2024, time.May, 24, 0, 0, 0, 0, time.UTC)), + EndAt: ptrTo(time.Date(2024, time.May, 25, 0, 0, 0, 0, time.UTC)), + }, + + { + SigningKey: []byte(`key_a`), + StartAt: ptrTo(time.Date(2024, time.May, 23, 0, 0, 0, 0, time.UTC)), + EndAt: ptrTo(time.Date(2024, time.May, 24, 0, 0, 0, 0, time.UTC)), + }, + }, + }, + }, + + { + name: "evq_strict_b_leeway_c", + given: tcGiven{ + issuer: &Issuer{ + Version: 3, + Keys: []IssuerKeys{ + { + SigningKey: []byte(`key_a`), + StartAt: ptrTo(time.Date(2024, time.May, 23, 0, 0, 0, 0, time.UTC)), + EndAt: ptrTo(time.Date(2024, time.May, 24, 0, 0, 0, 0, time.UTC)), + }, + + { + SigningKey: []byte(`key_b`), + StartAt: ptrTo(time.Date(2024, time.May, 24, 0, 0, 0, 0, time.UTC)), + EndAt: ptrTo(time.Date(2024, time.May, 25, 0, 0, 0, 0, time.UTC)), + }, + + { + SigningKey: []byte(`key_c`), + StartAt: ptrTo(time.Date(2024, time.May, 25, 0, 0, 0, 0, time.UTC)), + EndAt: ptrTo(time.Date(2024, time.May, 26, 0, 0, 0, 0, time.UTC)), + }, + }, + }, + now: time.Date(2024, time.May, 24, 23, 52, 25, 0, time.UTC), + lw: 1 * time.Hour, + }, + exp: tcExpected{ + result: []*IssuerKeys{ + { + SigningKey: []byte(`key_b`), + StartAt: ptrTo(time.Date(2024, time.May, 24, 0, 0, 0, 0, time.UTC)), + EndAt: ptrTo(time.Date(2024, time.May, 25, 0, 0, 0, 0, time.UTC)), + }, + + { + SigningKey: []byte(`key_c`), + StartAt: ptrTo(time.Date(2024, time.May, 25, 0, 0, 0, 0, time.UTC)), + EndAt: ptrTo(time.Date(2024, time.May, 26, 0, 0, 0, 0, time.UTC)), + }, + }, + }, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual, err := tc.given.issuer.findActiveKeys(tc.given.now, tc.given.lw) + must.Equal(t, tc.exp.err, err) + + if tc.exp.err != nil { + return + } + + should.Equal(t, tc.exp.result, actual) + }) + } +} + +func ptrTo[T any](v T) *T { + return &v +} + +func mustRandomSigningKey() []byte { + key, err := crypto.RandomSigningKey() + if err != nil { + panic(err) + } + + data, err := key.MarshalText() + if err != nil { + panic(err) + } + + return data +} diff --git a/server/issuers.go b/server/issuers.go index 67bf7cf0..dec703a0 100644 --- a/server/issuers.go +++ b/server/issuers.go @@ -4,19 +4,21 @@ import ( "context" "encoding/json" "errors" - "github.com/brave-intl/challenge-bypass-server/model" "net/http" "os" "time" - "github.com/brave-intl/bat-go/libs/closers" - "github.com/brave-intl/bat-go/libs/handlers" - "github.com/brave-intl/bat-go/libs/middleware" - crypto "github.com/brave-intl/challenge-bypass-ristretto-ffi" "github.com/go-chi/chi" "github.com/lib/pq" "github.com/pressly/lg" "github.com/sirupsen/logrus" + + "github.com/brave-intl/bat-go/libs/closers" + "github.com/brave-intl/bat-go/libs/handlers" + "github.com/brave-intl/bat-go/libs/middleware" + crypto "github.com/brave-intl/challenge-bypass-ristretto-ffi" + + "github.com/brave-intl/challenge-bypass-server/model" ) type issuerResponse struct { @@ -369,8 +371,8 @@ func (c *Server) issuerCreateHandlerV1(w http.ResponseWriter, r *http.Request) * func makeIssuerResponse(iss *model.Issuer) issuerResponse { expiresAt := "" - if !iss.ExpiresAtTime().IsZero() { - expiresAt = iss.ExpiresAtTime().Format(time.RFC3339) + if expt := iss.ExpiresAtTime(); !expt.IsZero() { + expiresAt = expt.Format(time.RFC3339) } // Last key in array is the valid one diff --git a/server/tokens.go b/server/tokens.go index afbce46b..b2325103 100644 --- a/server/tokens.go +++ b/server/tokens.go @@ -4,20 +4,21 @@ import ( "database/sql" "encoding/json" "errors" - "fmt" - "github.com/brave-intl/challenge-bypass-server/model" "net/http" "net/url" "os" "time" + "github.com/go-chi/chi" + "github.com/google/uuid" + "github.com/sirupsen/logrus" + "github.com/brave-intl/bat-go/libs/handlers" "github.com/brave-intl/bat-go/libs/middleware" crypto "github.com/brave-intl/challenge-bypass-ristretto-ffi" + "github.com/brave-intl/challenge-bypass-server/btd" - "github.com/go-chi/chi" - "github.com/google/uuid" - "github.com/sirupsen/logrus" + "github.com/brave-intl/challenge-bypass-server/model" ) const ( @@ -47,6 +48,10 @@ type blindedTokenRedeemRequest struct { Signature *crypto.VerificationSignature `json:"signature"` } +func (r *blindedTokenRedeemRequest) isEmpty() bool { + return r.TokenPreimage == nil || r.Signature == nil +} + type blindedTokenRedeemResponse struct { Cohort int16 `json:"cohort"` } @@ -174,112 +179,119 @@ func (c *Server) blindedTokenIssuerHandler(w http.ResponseWriter, r *http.Reques } func (c *Server) blindedTokenRedeemHandlerV3(w http.ResponseWriter, r *http.Request) *handlers.AppError { - var response blindedTokenRedeemResponse - if issuerType := chi.URLParam(r, "type"); issuerType != "" { - issuer, err := c.fetchIssuerByType(r.Context(), issuerType) - if err != nil { - switch { - case errors.Is(err, sql.ErrNoRows): - return &handlers.AppError{ - Message: "Issuer not found", - Code: 404, - } - default: - c.Logger.WithError(err).Error("error fetching issuer") - return &handlers.AppError{ - Cause: errors.New("internal server error"), - Message: "Internal server error could not retrieve issuer", - Code: 500, - } - } - } + ctx := r.Context() - c.Logger.WithField("issuer", issuer). - Debug("retrieved issuer") + issuerType := chi.URLParamFromCtx(ctx, "type") + if issuerType == "" { + return handlers.RenderContent(ctx, blindedTokenRedeemResponse{}, w, http.StatusOK) + } - if issuer.Version != 3 { + issuer, err := c.fetchIssuerByType(ctx, issuerType) + if err != nil { + switch { + case errors.Is(err, sql.ErrNoRows): return &handlers.AppError{ - Message: "Issuer must be version 3", - Code: http.StatusBadRequest, + Message: "Issuer not found", + Code: http.StatusNotFound, } - } + default: + c.Logger.WithError(err).Error("error fetching issuer") - if issuer.ExpiresAtTime().IsZero() && issuer.ExpiresAtTime().Before(time.Now()) { return &handlers.AppError{ - Message: "Issuer has expired", - Code: http.StatusBadRequest, + Cause: errors.New("internal server error"), + Message: "Internal server error could not retrieve issuer", + Code: http.StatusInternalServerError, } } + } - var request blindedTokenRedeemRequest - if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, maxRequestSize)).Decode(&request); err != nil { - c.Logger.Debug("Could not parse the request body") - return handlers.WrapError(err, "Could not parse the request body", 400) + if issuer.Version != 3 { + return &handlers.AppError{ + Message: "Issuer must be version 3", + Code: http.StatusBadRequest, } + } - if request.TokenPreimage == nil || request.Signature == nil { - c.Logger.Debug("Empty request") - return &handlers.AppError{ - Message: "Empty request", - Code: http.StatusBadRequest, - } + now := time.Now() + + if issuer.HasExpired(now) { + return &handlers.AppError{ + Message: "Issuer has expired", + Code: http.StatusBadRequest, } + } - var signingKey *crypto.SigningKey - for i, k := range issuer.Keys { - if k.StartAt == nil || k.EndAt == nil { - return &handlers.AppError{ - Message: "Issuer has invalid keys for v3", - Code: http.StatusBadRequest, - } - } + var request blindedTokenRedeemRequest + if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, maxRequestSize)).Decode(&request); err != nil { + c.Logger.Debug("Could not parse the request body") + return handlers.WrapError(err, "Could not parse the request body", http.StatusBadRequest) + } - if k.StartAt.Before(time.Now()) && k.EndAt.After(time.Now()) { - pubKeyTxt, _ := k.CryptoSigningKey().PublicKey().MarshalText() - c.Logger.WithFields(logrus.Fields{ - "now": time.Now(), - "start_at": k.StartAt, - "end_at": k.EndAt, - "key": string(pubKeyTxt), - "i": fmt.Sprintf("%d", i), - }).Error("found appropriate key") - signingKey = k.CryptoSigningKey() - break - } + if request.isEmpty() { + return &handlers.AppError{ + Message: "Empty request", + Code: http.StatusBadRequest, } - if signingKey == nil { + } + + skeys, err := issuer.FindSigningKeys(now) + if err != nil { + switch { + case errors.Is(err, model.ErrInvalidIssuerType): return &handlers.AppError{ - Message: "Issuer has no key that corresponds to start < now < end", + Message: "Issuer must be version 3", Code: http.StatusBadRequest, } - } - if err := btd.VerifyTokenRedemption(request.TokenPreimage, request.Signature, request.Payload, - []*crypto.SigningKey{signingKey}); err != nil { + case errors.Is(err, model.ErrInvalidIV3Key): return &handlers.AppError{ - Message: "Could not verify that token redemption is valid", + Message: "Issuer has invalid keys for v3", Code: http.StatusBadRequest, } - } - if err := c.RedeemToken(issuer, request.TokenPreimage, request.Payload, 0); err != nil { - c.Logger.Error("error redeeming token") - if errors.Is(err, errDuplicateRedemption) { - return &handlers.AppError{ - Message: err.Error(), - Code: http.StatusConflict, - } + default: + return &handlers.AppError{ + Message: "Something went wrong", + Code: http.StatusBadRequest, } + } + } + + if len(skeys) == 0 { + c.Logger.WithFields(logrus.Fields{"now": now}).Error("failed to find appropriate key") + + return &handlers.AppError{ + Message: "Issuer has no key that corresponds to start < now < end", + Code: http.StatusBadRequest, + } + } + + if err := btd.VerifyTokenRedemption(request.TokenPreimage, request.Signature, request.Payload, skeys); err != nil { + return &handlers.AppError{ + Message: "Could not verify that token redemption is valid", + Code: http.StatusBadRequest, + } + } + + if err := c.RedeemToken(issuer, request.TokenPreimage, request.Payload, 0); err != nil { + c.Logger.Error("error redeeming token") + if errors.Is(err, errDuplicateRedemption) { return &handlers.AppError{ - Cause: err, - Message: "Could not mark token redemption", - Code: http.StatusInternalServerError, + Message: err.Error(), + Code: http.StatusConflict, } } - response = blindedTokenRedeemResponse{issuer.IssuerCohort} + + return &handlers.AppError{ + Cause: err, + Message: "Could not mark token redemption", + Code: http.StatusInternalServerError, + } } - return handlers.RenderContent(r.Context(), response, w, http.StatusOK) + result := blindedTokenRedeemResponse{issuer.IssuerCohort} + + return handlers.RenderContent(ctx, result, w, http.StatusOK) } func (c *Server) blindedTokenRedeemHandler(w http.ResponseWriter, r *http.Request) *handlers.AppError { @@ -305,11 +317,15 @@ func (c *Server) blindedTokenRedeemHandler(w http.ResponseWriter, r *http.Reques } } - var verified = false - var verifiedIssuer = &model.Issuer{} - var verifiedCohort = int16(0) + var ( + verified bool + verifiedIssuer = &model.Issuer{} + verifiedCohort = int16(0) + now = time.Now() + ) + for _, issuer := range issuers { - if !issuer.ExpiresAtTime().IsZero() && issuer.ExpiresAtTime().Before(time.Now()) { + if issuer.HasExpired(now) { continue } diff --git a/server/tokens_test.go b/server/tokens_test.go new file mode 100644 index 00000000..d9364461 --- /dev/null +++ b/server/tokens_test.go @@ -0,0 +1,56 @@ +package server + +import ( + "testing" + + should "github.com/stretchr/testify/assert" + + crypto "github.com/brave-intl/challenge-bypass-ristretto-ffi" +) + +func TestBlindedTokenRedeemRequest_isEmpty(t *testing.T) { + tests := []struct { + name string + given *blindedTokenRedeemRequest + exp bool + }{ + { + name: "no_token_preimage", + given: &blindedTokenRedeemRequest{ + Signature: &crypto.VerificationSignature{}, + }, + exp: true, + }, + + { + name: "no_signature", + given: &blindedTokenRedeemRequest{ + TokenPreimage: &crypto.TokenPreimage{}, + }, + exp: true, + }, + + { + name: "no_token_preimage_no_signature", + given: &blindedTokenRedeemRequest{}, + exp: true, + }, + + { + name: "valid", + given: &blindedTokenRedeemRequest{ + TokenPreimage: &crypto.TokenPreimage{}, + Signature: &crypto.VerificationSignature{}, + }, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := tc.given.isEmpty() + should.Equal(t, tc.exp, actual) + }) + } +}