diff --git a/services/skus/controllers.go b/services/skus/controllers.go index b8dfdb424..6874a10a3 100644 --- a/services/skus/controllers.go +++ b/services/skus/controllers.go @@ -847,47 +847,51 @@ type VoteRequest struct { // MakeVote is the handler for making a vote using credentials func MakeVote(service *Service) handlers.AppHandler { return handlers.AppHandler(func(w http.ResponseWriter, r *http.Request) *handlers.AppError { - var ( - req VoteRequest - ctx = r.Context() - ) - err := requestutils.ReadJSON(ctx, r.Body, &req) - if err != nil { + ctx := r.Context() + + req := VoteRequest{} + if err := requestutils.ReadJSON(ctx, r.Body, &req); err != nil { return handlers.WrapError(err, "Error in request body", http.StatusBadRequest) } - logger := logging.Logger(ctx, "skus.MakeVote") - - _, err = govalidator.ValidateStruct(req) - if err != nil { + if _, err := govalidator.ValidateStruct(req); err != nil { return handlers.WrapValidationError(err) } - err = service.Vote(ctx, req.Credentials, req.Vote) - if err != nil { + lg := logging.Logger(ctx, "skus").With().Str("func", "MakeVote").Logger() + + if err := service.Vote(ctx, req.Credentials, req.Vote); err != nil { switch err.(type) { case govalidator.Error: - logger.Warn().Err(err).Msg("failed vote validation") + lg.Warn().Err(err).Msg("failed vote validation") return handlers.WrapValidationError(err) + case govalidator.Errors: - logger.Warn().Err(err).Msg("failed multiple vote validation") + lg.Warn().Err(err).Msg("failed multiple vote validation") return handlers.WrapValidationError(err) + default: // check for custom vote invalidations if errors.Is(err, ErrInvalidSKUToken) { verr := handlers.ValidationError("failed to validate sku token", nil) + data := []string{} if errors.Is(err, ErrInvalidSKUTokenSKU) { data = append(data, "invalid sku value") } + if errors.Is(err, ErrInvalidSKUTokenBadMerchant) { data = append(data, "invalid merchant value") } + verr.Data = data - logger.Warn().Err(err).Msg("failed sku validations") + lg.Warn().Err(err).Msg("failed sku validations") + return verr } - logger.Warn().Err(err).Msg("failed to perform vote") + + lg.Warn().Err(err).Msg("failed to perform vote") + return handlers.WrapError(err, "Error making vote", http.StatusBadRequest) } } diff --git a/services/skus/credentials.go b/services/skus/credentials.go index 28ce6a4cb..fd8e57fe5 100644 --- a/services/skus/credentials.go +++ b/services/skus/credentials.go @@ -520,6 +520,10 @@ var generateCredentialRedemptions = func(ctx context.Context, cb []CredentialBin } } + if issuer == nil { + return nil, model.ErrIssuerNotFound + } + requestCredentials[i].Issuer = issuer.Name() requestCredentials[i].TokenPreimage = cb[i].TokenPreimage requestCredentials[i].Signature = cb[i].Signature diff --git a/services/skus/model/model.go b/services/skus/model/model.go index 5edc68108..bdb6d1642 100644 --- a/services/skus/model/model.go +++ b/services/skus/model/model.go @@ -413,6 +413,7 @@ type OrderItemRequest struct { // CreateOrderRequestNew includes information needed to create an order. type CreateOrderRequestNew struct { Email string `json:"email" validate:"required,email"` + CustomerID string `json:"customer_id"` // Optional. Currency string `json:"currency" validate:"required,iso4217"` StripeMetadata *OrderStripeMetadata `json:"stripe_metadata"` RadomMetadata *OrderRadomMetadata `json:"radom_metadata"` diff --git a/services/skus/service.go b/services/skus/service.go index 89a4f2a16..96e384ab3 100644 --- a/services/skus/service.go +++ b/services/skus/service.go @@ -2021,6 +2021,7 @@ func (s *Service) createStripeSession(ctx context.Context, req *model.CreateOrde sreq := createStripeSessionRequest{ orderID: oid, email: req.Email, + customerID: req.CustomerID, successURL: surl, cancelURL: curl, trialDays: order.GetTrialDays(), @@ -2105,7 +2106,17 @@ func (s *Service) redeemBlindedCred(ctx context.Context, w http.ResponseWriter, // FIXME: we shouldn't be using the issuer as the payload, it ideally would be a unique request identifier // to allow for more flexible idempotent behavior. if err := redeemFn(ctx, cred.Issuer, cred.TokenPreimage, cred.Signature, cred.Issuer); err != nil { - return handleRedeemFnError(ctx, w, kind, cred, err) + if !shouldRetryRedeemFn(kind, cred.Issuer, err) { + return handleRedeemFnError(ctx, w, kind, cred, err) + } + + // TODO: remove this as there should be no credentials in Production signed by brave-leo-premium-year. + // + // Fix for https://github.com/brave-intl/challenge-bypass-server/pull/371. + const leoa = "brave.com?sku=brave-leo-premium-year" + if err := redeemFn(ctx, leoa, cred.TokenPreimage, cred.Signature, cred.Issuer); err != nil { + return handleRedeemFnError(ctx, w, kind, cred, err) + } } // TODO(clD11): cleanup after quick fix @@ -2572,6 +2583,7 @@ func (s *Service) recreateStripeSession(ctx context.Context, dbi sqlx.ExecerCont req := createStripeSessionRequest{ orderID: ord.ID.String(), email: email, + customerID: xstripe.CustomerIDFromSession(oldSess), successURL: oldSess.SuccessURL, cancelURL: oldSess.CancelURL, trialDays: ord.GetTrialDays(), @@ -2798,6 +2810,7 @@ func chooseStripeSessID(ord *model.Order, canBeNewSessID string) (string, bool) type createStripeSessionRequest struct { orderID string email string + customerID string successURL string cancelURL string trialDays int64 @@ -2815,9 +2828,17 @@ func createStripeSession(ctx context.Context, cl stripeClient, req createStripeS LineItems: req.items, } - // Email might not be given. - // This could happen while recreating a session, and the email was not extracted from the old one. - if req.email != "" { + // Different processes can supply different info about customer: + // - when customerID is present, it takes precedence; + // - when email is present: + // - first, search for customer; + // - fallback to using the email directly. + // Based on the rules above, if both are present, customerID wins. + switch { + case req.customerID != "": + params.Customer = &req.customerID + + case req.customerID == "" && req.email != "": if cust, ok := cl.FindCustomer(ctx, req.email); ok && cust.Email != "" { params.Customer = &cust.ID } else { @@ -2876,6 +2897,12 @@ func handleRedeemFnError(ctx context.Context, w http.ResponseWriter, kind string return handlers.WrapError(err, "Error verifying credentials", http.StatusInternalServerError) } +func shouldRetryRedeemFn(kind, issuer string, err error) bool { + const leo = "brave.com?sku=brave-leo-premium" + + return kind == timeLimitedV2 && issuer == leo && err.Error() == cbr.ErrBadRequest.Error() +} + func newRadomGateway(env string) (*radom.Gateway, error) { switch env { case "development", "staging": diff --git a/services/skus/service_nonint_test.go b/services/skus/service_nonint_test.go index de3890fad..dca5d150d 100644 --- a/services/skus/service_nonint_test.go +++ b/services/skus/service_nonint_test.go @@ -4388,7 +4388,7 @@ func TestService_recreateStripeSession(t *testing.T) { }, { - name: "success_email_from_session", + name: "success_email_cust_from_session", given: tcGiven{ ordRepo: &repository.MockOrder{ FnAppendMetadata: func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key, val string) error { @@ -4405,7 +4405,7 @@ func TestService_recreateStripeSession(t *testing.T) { ID: "cs_test_id_old", SuccessURL: "https://example.com/success", CancelURL: "https://example.com/cancel", - Customer: &stripe.Customer{Email: "you@example.com"}, + Customer: &stripe.Customer{ID: "cus_id", Email: "you@example.com"}, } return result, nil @@ -4455,6 +4455,79 @@ func TestService_recreateStripeSession(t *testing.T) { }, }, + { + name: "success_email_from_request_cust_without_email", + given: tcGiven{ + ordRepo: &repository.MockOrder{ + FnAppendMetadata: func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key, val string) error { + if key == "stripeCheckoutSessionId" && val == "cs_test_id" { + return nil + } + + return model.Error("unexpected") + }, + }, + cl: &xstripe.MockClient{ + FnSession: func(ctx context.Context, id string, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + result := &stripe.CheckoutSession{ + ID: "cs_test_id_old", + SuccessURL: "https://example.com/success", + CancelURL: "https://example.com/cancel", + Customer: &stripe.Customer{ID: "cus_id"}, + } + + return result, nil + }, + + FnFindCustomer: func(ctx context.Context, email string) (*stripe.Customer, bool) { + return nil, false + }, + + FnCreateSession: func(ctx context.Context, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + if params.Customer != nil { + return nil, model.Error("unexpected_customer") + } + + if *params.CustomerEmail != "request@example.com" { + return nil, model.Error("unexpected_customer_email") + } + + result := &stripe.CheckoutSession{ + ID: "cs_test_id", + PaymentMethodTypes: []string{"card"}, + Mode: stripe.CheckoutSessionModeSubscription, + SuccessURL: *params.SuccessURL, + CancelURL: *params.CancelURL, + ClientReferenceID: *params.ClientReferenceID, + Subscription: &stripe.Subscription{ + ID: "sub_id", + Metadata: map[string]string{ + "orderID": *params.ClientReferenceID, + }, + }, + AllowPromotionCodes: true, + } + + return result, nil + }, + }, + ord: &model.Order{ + ID: uuid.Must(uuid.FromString("facade00-0000-4000-a000-000000000000")), + Items: []model.OrderItem{ + { + Quantity: 1, + Metadata: datastore.Metadata{"stripe_item_id": "stripe_item_id"}, + }, + }, + }, + oldSessID: "cs_test_id_old", + email: "request@example.com", + }, + exp: tcExpected{ + val: "cs_test_id", + }, + }, + { name: "success_email_from_request", given: tcGiven{ @@ -4473,7 +4546,6 @@ func TestService_recreateStripeSession(t *testing.T) { ID: "cs_test_id_old", SuccessURL: "https://example.com/success", CancelURL: "https://example.com/cancel", - Customer: &stripe.Customer{Email: "session@example.com"}, } return result, nil @@ -4564,7 +4636,84 @@ func TestCreateStripeSession(t *testing.T) { tests := []testCase{ { - name: "success_found_customer", + name: "success_cust_id", + given: tcGiven{ + cl: &xstripe.MockClient{ + FnCreateSession: func(ctx context.Context, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + if params.Customer == nil || *params.Customer != "cus_id" { + return nil, model.Error("unexpected") + } + + result := &stripe.CheckoutSession{ID: "cs_test_id"} + + return result, nil + }, + + FnFindCustomer: func(ctx context.Context, email string) (*stripe.Customer, bool) { + panic("unexpected_find_customer") + }, + }, + + req: createStripeSessionRequest{ + orderID: "facade00-0000-4000-a000-000000000000", + customerID: "cus_id", + successURL: "https://example.com/success", + cancelURL: "https://example.com/cancel", + trialDays: 7, + items: []*stripe.CheckoutSessionLineItemParams{ + { + Quantity: ptrTo[int64](1), + Price: ptrTo("stripe_item_id"), + }, + }, + }, + }, + exp: tcExpected{ + val: "cs_test_id", + }, + }, + + { + name: "success_cust_id_email", + given: tcGiven{ + cl: &xstripe.MockClient{ + FnCreateSession: func(ctx context.Context, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + if params.Customer == nil || *params.Customer != "cus_id" { + return nil, model.Error("unexpected") + } + + result := &stripe.CheckoutSession{ID: "cs_test_id"} + + return result, nil + }, + + FnFindCustomer: func(ctx context.Context, email string) (*stripe.Customer, bool) { + panic("unexpected_find_customer") + }, + }, + + req: createStripeSessionRequest{ + orderID: "facade00-0000-4000-a000-000000000000", + customerID: "cus_id", + email: "you@example.com", + successURL: "https://example.com/success", + cancelURL: "https://example.com/cancel", + trialDays: 7, + items: []*stripe.CheckoutSessionLineItemParams{ + { + Quantity: ptrTo[int64](1), + Price: ptrTo("stripe_item_id"), + }, + }, + }, + }, + exp: tcExpected{ + val: "cs_test_id", + }, + }, + + { + name: "success_email_found_customer", given: tcGiven{ cl: &xstripe.MockClient{ FnCreateSession: func(ctx context.Context, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { @@ -4598,7 +4747,7 @@ func TestCreateStripeSession(t *testing.T) { }, { - name: "success_customer_not_found", + name: "success_email_customer_not_found", given: tcGiven{ cl: &xstripe.MockClient{ FnFindCustomer: func(ctx context.Context, email string) (*stripe.Customer, bool) { @@ -4636,7 +4785,7 @@ func TestCreateStripeSession(t *testing.T) { }, { - name: "success_no_customer_email", + name: "success_email_no_customer_email", given: tcGiven{ cl: &xstripe.MockClient{ FnFindCustomer: func(ctx context.Context, email string) (*stripe.Customer, bool) { @@ -4663,7 +4812,7 @@ func TestCreateStripeSession(t *testing.T) { }, { - name: "success_no_trial_days", + name: "success_email_no_trial_days", given: tcGiven{ cl: &xstripe.MockClient{}, diff --git a/services/skus/vote.go b/services/skus/vote.go index c33a7b99c..1d5a9de04 100644 --- a/services/skus/vote.go +++ b/services/skus/vote.go @@ -236,9 +236,7 @@ func (s *Service) RunNextVoteDrainJob(ctx context.Context) (bool, error) { } // Vote based on the browser's attention -func (s *Service) Vote( - ctx context.Context, credentials []CredentialBinding, voteText string) error { - +func (s *Service) Vote(ctx context.Context, credentials []CredentialBinding, voteText string) error { logger := logging.Logger(ctx, "skus.Vote") var vote Vote @@ -248,8 +246,7 @@ func (s *Service) Vote( } // generate all the cb credential redemptions - requestCredentials, err := generateCredentialRedemptions( - context.WithValue(ctx, appctx.DatastoreCTXKey, s.Datastore), credentials) + requestCredentials, err := generateCredentialRedemptions(context.WithValue(ctx, appctx.DatastoreCTXKey, s.Datastore), credentials) if err != nil { return fmt.Errorf("error generating credential redemptions: %w", err) } diff --git a/services/skus/xstripe/mock.go b/services/skus/xstripe/mock.go index 442ff3f57..803eb4943 100644 --- a/services/skus/xstripe/mock.go +++ b/services/skus/xstripe/mock.go @@ -15,7 +15,10 @@ type MockClient struct { func (c *MockClient) Session(ctx context.Context, id string, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { if c.FnSession == nil { - result := &stripe.CheckoutSession{ID: id} + result := &stripe.CheckoutSession{ + ID: id, + Customer: &stripe.Customer{ID: "cus_id", Email: "customer@example.com"}, + } return result, nil } diff --git a/services/skus/xstripe/xstripe.go b/services/skus/xstripe/xstripe.go index 688928726..e9f2a5580 100644 --- a/services/skus/xstripe/xstripe.go +++ b/services/skus/xstripe/xstripe.go @@ -56,3 +56,13 @@ func CustomerEmailFromSession(sess *stripe.CheckoutSession) string { // Default to empty, Stripe will ask the customer. return "" } + +func CustomerIDFromSession(sess *stripe.CheckoutSession) string { + // Return the customer id only if the customer is present AND it has email set. + // Without the email, the customer record is not fully formed, and does not suit the use case. + if sess.Customer != nil && sess.Customer.Email != "" { + return sess.Customer.ID + } + + return "" +} diff --git a/services/skus/xstripe/xstripe_test.go b/services/skus/xstripe/xstripe_test.go index 6af57ac4f..bcf367d49 100644 --- a/services/skus/xstripe/xstripe_test.go +++ b/services/skus/xstripe/xstripe_test.go @@ -65,3 +65,52 @@ func TestCustomerEmailFromSession(t *testing.T) { }) } } + +func TestCustomerIDFromSession(t *testing.T) { + tests := []struct { + name string + exp string + given *stripe.CheckoutSession + }{ + { + name: "nil_customer_no_email", + given: &stripe.CheckoutSession{}, + }, + + { + name: "customer_empty_email", + given: &stripe.CheckoutSession{ + Customer: &stripe.Customer{}, + }, + }, + + { + name: "customer_email_no_id", + given: &stripe.CheckoutSession{ + Customer: &stripe.Customer{ + Email: "me@example.com", + }, + }, + }, + + { + name: "customer_email_id", + given: &stripe.CheckoutSession{ + Customer: &stripe.Customer{ + ID: "cus_id", + Email: "me@example.com", + }, + }, + exp: "cus_id", + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := CustomerIDFromSession(tc.given) + should.Equal(t, tc.exp, actual) + }) + } +} diff --git a/services/wallet/controllers_v3.go b/services/wallet/controllers_v3.go index 4a7122caf..21201ff8e 100644 --- a/services/wallet/controllers_v3.go +++ b/services/wallet/controllers_v3.go @@ -396,20 +396,23 @@ func LinkUpholdDepositAccountV3(s *Service) func(w http.ResponseWriter, r *http. ) } - // read post body if err := inputs.DecodeAndValidateReader(ctx, cuw, r.Body); err != nil { return cuw.HandleErrors(err) } - // get the wallet wallet, err := s.GetWallet(ctx, *id.UUID()) if err != nil { - if strings.Contains(err.Error(), "looking up wallet") { - return handlers.WrapError(err, "unable to find wallet", http.StatusNotFound) - } + l.Err(err).Msg("failed to get wallet") + return handlers.WrapError(err, "unable to get or create wallets", http.StatusServiceUnavailable) } + if wallet == nil { + l.Err(model.ErrWalletNotFound).Msg("wallet not found") + + return handlers.WrapError(err, "unable to find wallet", http.StatusNotFound) + } + var aa uuid.UUID if cuw.AnonymousAddress != "" {