From 36136d6ffb81f3f26cd908141708e1031195659a Mon Sep 17 00:00:00 2001 From: Pavel Brm <5097196+pavelbrm@users.noreply.github.com> Date: Tue, 19 Nov 2024 17:25:42 +1300 Subject: [PATCH] fix: set email when creating new session for trial days (#2710) --- services/skus/controllers.go | 16 +- services/skus/datastore.go | 29 +- services/skus/instrumented_datastore.go | 14 - services/skus/mockdatastore.go | 30 -- services/skus/model/model.go | 20 +- services/skus/model/model_test.go | 137 ++++++-- services/skus/service.go | 39 ++- services/skus/service_nonint_test.go | 300 +++++++++++++++++- services/skus/storage/repository/mock.go | 9 + .../skus/storage/repository/repository.go | 18 +- .../storage/repository/repository_test.go | 105 ++++-- 11 files changed, 560 insertions(+), 157 deletions(-) diff --git a/services/skus/controllers.go b/services/skus/controllers.go index a8c313624..b8dfdb424 100644 --- a/services/skus/controllers.go +++ b/services/skus/controllers.go @@ -10,6 +10,7 @@ import ( "net/http" "os" "strconv" + "time" "github.com/asaskevich/govalidator" "github.com/go-chi/chi" @@ -310,11 +311,6 @@ func VoteRouter(service *Service, instrumentHandler middleware.InstrumentHandler return r } -type setTrialDaysRequest struct { - TrialDays int64 `json:"trialDays"` -} - -// TODO: refactor this to avoid multiple fetches of an order. func handleSetOrderTrialDays(svc *Service) handlers.AppHandler { return handlers.AppHandler(func(w http.ResponseWriter, r *http.Request) *handlers.AppError { ctx := r.Context() @@ -324,21 +320,19 @@ func handleSetOrderTrialDays(svc *Service) handlers.AppHandler { return handlers.ValidationError("request", map[string]interface{}{"orderID": err.Error()}) } - if err := svc.validateOrderMerchantAndCaveats(ctx, orderID); err != nil { - return handlers.ValidationError("merchant and caveats", map[string]interface{}{"orderMerchantAndCaveats": err.Error()}) - } - data, err := io.ReadAll(io.LimitReader(r.Body, reqBodyLimit10MB)) if err != nil { return handlers.WrapError(err, "failed to read request body", http.StatusBadRequest) } - req := &setTrialDaysRequest{} + req := &model.SetTrialDaysRequest{} if err := json.Unmarshal(data, req); err != nil { return handlers.WrapError(err, "failed to parse request", http.StatusBadRequest) } - if err := svc.SetOrderTrialDays(ctx, &orderID, req.TrialDays); err != nil { + now := time.Now().UTC() + + if err := svc.setOrderTrialDays(ctx, orderID, req, now); err != nil { return handlers.WrapError(err, "Error setting the trial days on the order", http.StatusInternalServerError) } diff --git a/services/skus/datastore.go b/services/skus/datastore.go index 75efe6cfd..929d72f44 100644 --- a/services/skus/datastore.go +++ b/services/skus/datastore.go @@ -37,8 +37,7 @@ type Datastore interface { datastore.Datastore CreateOrder(ctx context.Context, dbi sqlx.ExtContext, oreq *model.OrderNew, items []model.OrderItem) (*model.Order, error) - // SetOrderTrialDays - set the number of days of free trial for this order - SetOrderTrialDays(ctx context.Context, orderID *uuid.UUID, days int64) (*Order, error) + // GetOrder by ID GetOrder(orderID uuid.UUID) (*Order, error) // GetOrderByExternalID by the external id from the purchase vendor @@ -99,7 +98,6 @@ type orderStore interface { GetByExternalID(ctx context.Context, dbi sqlx.QueryerContext, extID string) (*model.Order, error) Create(ctx context.Context, dbi sqlx.QueryerContext, oreq *model.OrderNew) (*model.Order, error) SetLastPaidAt(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, when time.Time) error - SetTrialDays(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID, ndays int64) (*model.Order, error) SetStatus(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, status string) error GetExpiresAtAfterISOPeriod(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (time.Time, error) SetExpiresAt(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, when time.Time) error @@ -265,31 +263,6 @@ func (pg *Postgres) GetKey(id uuid.UUID, showExpired bool) (*Key, error) { return &key, nil } -// SetOrderTrialDays sets the number of days of free trial for this order and returns the updated result. -func (pg *Postgres) SetOrderTrialDays(ctx context.Context, orderID *uuid.UUID, days int64) (*Order, error) { - tx, err := pg.RawDB().BeginTxx(ctx, nil) - if err != nil { - return nil, fmt.Errorf("failed to create db tx: %w", err) - } - defer pg.RollbackTx(tx) - - result, err := pg.orderRepo.SetTrialDays(ctx, tx, *orderID, days) - if err != nil { - return nil, fmt.Errorf("failed to execute tx: %w", err) - } - - result.Items, err = pg.orderItemRepo.FindByOrderID(ctx, tx, *orderID) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return result, nil -} - // CreateOrder creates orders for Auto Contribute and Search Captcha. // // Deprecated: This method MUST NOT be used for Premium orders. diff --git a/services/skus/instrumented_datastore.go b/services/skus/instrumented_datastore.go index d6ed36a2e..4fd7c2632 100644 --- a/services/skus/instrumented_datastore.go +++ b/services/skus/instrumented_datastore.go @@ -653,20 +653,6 @@ func (_d DatastoreWithPrometheus) SendSigningRequest(ctx context.Context, signin return _d.base.SendSigningRequest(ctx, signingRequestWriter) } -// SetOrderTrialDays implements Datastore -func (_d DatastoreWithPrometheus) SetOrderTrialDays(ctx context.Context, orderID *uuid.UUID, days int64) (op1 *Order, err error) { - _since := time.Now() - defer func() { - result := "ok" - if err != nil { - result = "error" - } - - datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "SetOrderTrialDays", result).Observe(time.Since(_since).Seconds()) - }() - return _d.base.SetOrderTrialDays(ctx, orderID, days) -} - // UpdateOrder implements Datastore func (_d DatastoreWithPrometheus) UpdateOrder(orderID uuid.UUID, status string) (err error) { _since := time.Now() diff --git a/services/skus/mockdatastore.go b/services/skus/mockdatastore.go index 13a1ba768..be7ea32dc 100644 --- a/services/skus/mockdatastore.go +++ b/services/skus/mockdatastore.go @@ -692,21 +692,6 @@ func (mr *MockDatastoreMockRecorder) SendSigningRequest(ctx, signingRequestWrite return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendSigningRequest", reflect.TypeOf((*MockDatastore)(nil).SendSigningRequest), ctx, signingRequestWriter) } -// SetOrderTrialDays mocks base method. -func (m *MockDatastore) SetOrderTrialDays(ctx context.Context, orderID *go_uuid.UUID, days int64) (*Order, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetOrderTrialDays", ctx, orderID, days) - ret0, _ := ret[0].(*Order) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// SetOrderTrialDays indicates an expected call of SetOrderTrialDays. -func (mr *MockDatastoreMockRecorder) SetOrderTrialDays(ctx, orderID, days interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetOrderTrialDays", reflect.TypeOf((*MockDatastore)(nil).SetOrderTrialDays), ctx, orderID, days) -} - // UpdateOrder mocks base method. func (m *MockDatastore) UpdateOrder(orderID go_uuid.UUID, status string) error { m.ctrl.T.Helper() @@ -976,21 +961,6 @@ func (mr *MockorderStoreMockRecorder) SetStatus(ctx, dbi, id, status interface{} return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatus", reflect.TypeOf((*MockorderStore)(nil).SetStatus), ctx, dbi, id, status) } -// SetTrialDays mocks base method. -func (m *MockorderStore) SetTrialDays(ctx context.Context, dbi sqlx.QueryerContext, id go_uuid.UUID, ndays int64) (*model.Order, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetTrialDays", ctx, dbi, id, ndays) - ret0, _ := ret[0].(*model.Order) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// SetTrialDays indicates an expected call of SetTrialDays. -func (mr *MockorderStoreMockRecorder) SetTrialDays(ctx, dbi, id, ndays interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTrialDays", reflect.TypeOf((*MockorderStore)(nil).SetTrialDays), ctx, dbi, id, ndays) -} - // UpdateMetadata mocks base method. func (m *MockorderStore) UpdateMetadata(ctx context.Context, dbi sqlx.ExecerContext, id go_uuid.UUID, data datastore.Metadata) error { m.ctrl.T.Helper() diff --git a/services/skus/model/model.go b/services/skus/model/model.go index 8b89b4e22..5edc68108 100644 --- a/services/skus/model/model.go +++ b/services/skus/model/model.go @@ -118,12 +118,21 @@ func (o *Order) IsRadomPayable() bool { return Slice[string](o.AllowedPaymentMethods).Contains(RadomPaymentMethod) } -func (o *Order) ShouldSetTrialDays() bool { - return !o.IsPaid() && o.IsStripePayable() +func (o *Order) ShouldCreateTrialSessionStripe(now time.Time) bool { + return !o.IsPaidAt(now) && o.IsStripePayable() } // IsPaid returns true if the order is paid. +// +// TODO: Update all callers of the method to pass time explicitly. func (o *Order) IsPaid() bool { + return o.IsPaidAt(time.Now()) +} + +// IsPaidAt returns true if the order is paid. +// +// If canceled, it checks if expires_at is in the future. +func (o *Order) IsPaidAt(now time.Time) bool { switch o.Status { case OrderStatusPaid: // The order is paid if the status is paid. @@ -134,7 +143,7 @@ func (o *Order) IsPaid() bool { return false } - return o.ExpiresAt.After(time.Now()) + return o.ExpiresAt.After(now) default: return false } @@ -731,6 +740,11 @@ type VerifyCredentialOpaque struct { Version float64 `json:"version" validate:"-"` } +type SetTrialDaysRequest struct { + Email string `json:"email"` // TODO: Make it required. + TrialDays int64 `json:"trialDays"` +} + func addURLParam(src, name, val string) (string, error) { raw, err := url.Parse(src) if err != nil { diff --git a/services/skus/model/model_test.go b/services/skus/model/model_test.go index 8bbc7ffe1..140c20f62 100644 --- a/services/skus/model/model_test.go +++ b/services/skus/model/model_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "testing" + "time" "github.com/lib/pq" uuid "github.com/satori/go.uuid" @@ -1157,63 +1158,153 @@ func TestOrder_Vendor(t *testing.T) { } } -func TestOrder_ShouldSetTrialDays(t *testing.T) { +func TestOrder_ShouldCreateTrialSessionStripe(t *testing.T) { + type tcGiven struct { + ord *model.Order + now time.Time + } + type testCase struct { name string - given model.Order + given tcGiven exp bool } tests := []testCase{ { - name: "not_paid", - given: model.Order{Status: model.OrderStatusPending}, + name: "false_paid_not_stripe", + given: tcGiven{ + ord: &model.Order{ + Status: model.OrderStatusPaid, + AllowedPaymentMethods: pq.StringArray{"radom"}, + }, + }, }, { - name: "not_paid_not_stripe", - given: model.Order{ - Status: model.OrderStatusPending, - AllowedPaymentMethods: pq.StringArray{"something"}, + name: "false_not_paid_not_stripe", + given: tcGiven{ + ord: &model.Order{ + Status: model.OrderStatusPending, + AllowedPaymentMethods: pq.StringArray{"radom"}, + }, }, }, { - name: "paid", - given: model.Order{Status: model.OrderStatusPaid}, + name: "false_paid_stripe", + given: tcGiven{ + ord: &model.Order{ + Status: model.OrderStatusPaid, + AllowedPaymentMethods: pq.StringArray{"stripe"}, + }, + }, }, { - name: "paid_not_stripe", - given: model.Order{ - Status: model.OrderStatusPaid, - AllowedPaymentMethods: pq.StringArray{"something"}, + name: "false_canceled_not_expired_stripe", + given: tcGiven{ + ord: &model.Order{ + Status: model.OrderStatusPaid, + AllowedPaymentMethods: pq.StringArray{"stripe"}, + ExpiresAt: ptrTo(time.Date(2024, time.November, 1, 0, 0, 0, 0, time.UTC)), + }, + now: time.Date(2024, time.October, 1, 0, 0, 0, 0, time.UTC), }, }, { - name: "paid_stripe", - given: model.Order{ - Status: model.OrderStatusPaid, - AllowedPaymentMethods: pq.StringArray{"stripe"}, + name: "true_pending_stripe", + given: tcGiven{ + ord: &model.Order{ + Status: model.OrderStatusPending, + AllowedPaymentMethods: pq.StringArray{"stripe"}, + }, }, + exp: true, }, + } + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + actual := tc.given.ord.ShouldCreateTrialSessionStripe(tc.given.now) + should.Equal(t, tc.exp, actual) + }) + } +} + +func TestOrder_IsPaidAt(t *testing.T) { + type tcGiven struct { + ord *model.Order + now time.Time + } + + type testCase struct { + name string + given tcGiven + exp bool + } + + tests := []testCase{ { - name: "not_paid_stripe", - given: model.Order{ - Status: model.OrderStatusPending, - AllowedPaymentMethods: pq.StringArray{"stripe"}, + name: "true_paid", + given: tcGiven{ + ord: &model.Order{ + Status: model.OrderStatusPaid, + }, + }, + exp: true, + }, + + { + name: "false_canceled_no_expiry", + given: tcGiven{ + ord: &model.Order{ + Status: model.OrderStatusCanceled, + }, + }, + }, + + { + name: "true_canceled_expires_later", + given: tcGiven{ + ord: &model.Order{ + Status: model.OrderStatusCanceled, + ExpiresAt: ptrTo(time.Date(2024, time.November, 1, 0, 0, 0, 0, time.UTC)), + }, + now: time.Date(2024, time.October, 1, 0, 0, 0, 0, time.UTC), }, exp: true, }, + + { + name: "false_canceled_expired", + given: tcGiven{ + ord: &model.Order{ + Status: model.OrderStatusCanceled, + ExpiresAt: ptrTo(time.Date(2024, time.November, 1, 0, 0, 0, 0, time.UTC)), + }, + now: time.Date(2024, time.December, 1, 0, 0, 0, 0, time.UTC), + }, + }, + + { + name: "false_pending", + given: tcGiven{ + ord: &model.Order{ + Status: model.OrderStatusPending, + }, + }, + }, } for i := range tests { tc := tests[i] t.Run(tc.name, func(t *testing.T) { - actual := tc.given.ShouldSetTrialDays() + actual := tc.given.ord.IsPaidAt(tc.given.now) should.Equal(t, tc.exp, actual) }) } diff --git a/services/skus/service.go b/services/skus/service.go index ae39e9535..89a4f2a16 100644 --- a/services/skus/service.go +++ b/services/skus/service.go @@ -97,6 +97,7 @@ type orderStoreSvc interface { SetStatus(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, status string) error SetExpiresAt(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, when time.Time) error SetLastPaidAt(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, when time.Time) error + SetTrialDays(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, ndays int64) error AppendMetadata(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key, val string) error AppendMetadataInt(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key string, val int) error AppendMetadataInt64(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key string, val int64) error @@ -634,7 +635,7 @@ func (s *Service) updateOrderStripeSession(ctx context.Context, dbi sqlx.ExtCont var newSessID string if expSessID != "" { - nsessID, err := s.recreateStripeSession(ctx, dbi, ord, expSessID) + nsessID, err := s.recreateStripeSession(ctx, dbi, ord, expSessID, "") if err != nil { return fmt.Errorf("failed to create checkout session: %w", err) } @@ -755,13 +756,31 @@ func (s *Service) CancelOrderLegacy(orderID uuid.UUID) error { return s.Datastore.UpdateOrder(orderID, OrderStatusCanceled) } -func (s *Service) SetOrderTrialDays(ctx context.Context, orderID *uuid.UUID, days int64) error { - ord, err := s.Datastore.SetOrderTrialDays(ctx, orderID, days) +func (s *Service) setOrderTrialDays(ctx context.Context, orderID uuid.UUID, req *model.SetTrialDaysRequest, now time.Time) error { + tx, err := s.Datastore.RawDB().BeginTxx(ctx, nil) if err != nil { - return fmt.Errorf("failed to set the order's trial days: %w", err) + return err + } + defer func() { _ = tx.Rollback() }() + + if err := s.setOrderTrialDaysTx(ctx, tx, orderID, req, now); err != nil { + return err } - if !ord.ShouldSetTrialDays() { + return tx.Commit() +} + +func (s *Service) setOrderTrialDaysTx(ctx context.Context, dbi sqlx.ExtContext, orderID uuid.UUID, req *model.SetTrialDaysRequest, now time.Time) error { + if err := s.orderRepo.SetTrialDays(ctx, dbi, orderID, req.TrialDays); err != nil { + return err + } + + ord, err := s.getOrderFullTx(ctx, dbi, orderID) + if err != nil { + return err + } + + if !ord.ShouldCreateTrialSessionStripe(now) { return nil } @@ -770,7 +789,7 @@ func (s *Service) SetOrderTrialDays(ctx context.Context, orderID *uuid.UUID, day return model.ErrNoStripeCheckoutSessID } - _, err = s.recreateStripeSession(ctx, s.Datastore.RawDB(), ord, oldSessID) + _, err = s.recreateStripeSession(ctx, dbi, ord, oldSessID, req.Email) return err } @@ -2544,7 +2563,7 @@ func (s *Service) renewOrderStripe(ctx context.Context, dbi sqlx.ExecerContext, return s.orderRepo.AppendMetadata(ctx, dbi, ord.ID, "paymentProcessor", model.StripePaymentMethod) } -func (s *Service) recreateStripeSession(ctx context.Context, dbi sqlx.ExecerContext, ord *model.Order, oldSessID string) (string, error) { +func (s *Service) recreateStripeSession(ctx context.Context, dbi sqlx.ExecerContext, ord *model.Order, oldSessID, email string) (string, error) { oldSess, err := s.stripeCl.Session(ctx, oldSessID, nil) if err != nil { return "", err @@ -2552,13 +2571,17 @@ func (s *Service) recreateStripeSession(ctx context.Context, dbi sqlx.ExecerCont req := createStripeSessionRequest{ orderID: ord.ID.String(), - email: xstripe.CustomerEmailFromSession(oldSess), + email: email, successURL: oldSess.SuccessURL, cancelURL: oldSess.CancelURL, trialDays: ord.GetTrialDays(), items: buildStripeLineItems(ord.Items), } + if req.email == "" { + req.email = xstripe.CustomerEmailFromSession(oldSess) + } + sessID, err := createStripeSession(ctx, s.stripeCl, req) if err != nil { return "", err diff --git a/services/skus/service_nonint_test.go b/services/skus/service_nonint_test.go index 931bf1b1f..de3890fad 100644 --- a/services/skus/service_nonint_test.go +++ b/services/skus/service_nonint_test.go @@ -4283,6 +4283,7 @@ func TestService_recreateStripeSession(t *testing.T) { cl *xstripe.MockClient ord *model.Order oldSessID string + email string } type tcExpected struct { @@ -4387,7 +4388,7 @@ func TestService_recreateStripeSession(t *testing.T) { }, { - name: "success", + name: "success_email_from_session", given: tcGiven{ ordRepo: &repository.MockOrder{ FnAppendMetadata: func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key, val string) error { @@ -4409,6 +4410,106 @@ func TestService_recreateStripeSession(t *testing.T) { return result, nil }, + + FnCreateSession: func(ctx context.Context, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + if *params.Customer != "cus_id" { + return nil, model.Error("unexpected_customer") + } + + if params.CustomerEmail != nil { + return nil, model.Error("unexpected_customer_email") + } + + result := &stripe.CheckoutSession{ + ID: "cs_test_id", + PaymentMethodTypes: []string{"card"}, + Mode: stripe.CheckoutSessionModeSubscription, + SuccessURL: *params.SuccessURL, + CancelURL: *params.CancelURL, + ClientReferenceID: *params.ClientReferenceID, + Subscription: &stripe.Subscription{ + ID: "sub_id", + Metadata: map[string]string{ + "orderID": *params.ClientReferenceID, + }, + }, + AllowPromotionCodes: true, + } + + return result, nil + }, + }, + ord: &model.Order{ + ID: uuid.Must(uuid.FromString("facade00-0000-4000-a000-000000000000")), + Items: []model.OrderItem{ + { + Quantity: 1, + Metadata: datastore.Metadata{"stripe_item_id": "stripe_item_id"}, + }, + }, + }, + oldSessID: "cs_test_id_old", + }, + exp: tcExpected{ + val: "cs_test_id", + }, + }, + + { + name: "success_email_from_request", + given: tcGiven{ + ordRepo: &repository.MockOrder{ + FnAppendMetadata: func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key, val string) error { + if key == "stripeCheckoutSessionId" && val == "cs_test_id" { + return nil + } + + return model.Error("unexpected") + }, + }, + cl: &xstripe.MockClient{ + FnSession: func(ctx context.Context, id string, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + result := &stripe.CheckoutSession{ + ID: "cs_test_id_old", + SuccessURL: "https://example.com/success", + CancelURL: "https://example.com/cancel", + Customer: &stripe.Customer{Email: "session@example.com"}, + } + + return result, nil + }, + + FnFindCustomer: func(ctx context.Context, email string) (*stripe.Customer, bool) { + return nil, false + }, + + FnCreateSession: func(ctx context.Context, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + if params.Customer != nil { + return nil, model.Error("unexpected_customer") + } + + if *params.CustomerEmail != "request@example.com" { + return nil, model.Error("unexpected_customer_email") + } + + result := &stripe.CheckoutSession{ + ID: "cs_test_id", + PaymentMethodTypes: []string{"card"}, + Mode: stripe.CheckoutSessionModeSubscription, + SuccessURL: *params.SuccessURL, + CancelURL: *params.CancelURL, + ClientReferenceID: *params.ClientReferenceID, + Subscription: &stripe.Subscription{ + ID: "sub_id", + Metadata: map[string]string{ + "orderID": *params.ClientReferenceID, + }, + }, + AllowPromotionCodes: true, + } + + return result, nil + }, }, ord: &model.Order{ ID: uuid.Must(uuid.FromString("facade00-0000-4000-a000-000000000000")), @@ -4420,6 +4521,7 @@ func TestService_recreateStripeSession(t *testing.T) { }, }, oldSessID: "cs_test_id_old", + email: "request@example.com", }, exp: tcExpected{ val: "cs_test_id", @@ -4435,7 +4537,7 @@ func TestService_recreateStripeSession(t *testing.T) { ctx := context.Background() - actual, err := svc.recreateStripeSession(ctx, nil, tc.given.ord, tc.given.oldSessID) + actual, err := svc.recreateStripeSession(ctx, nil, tc.given.ord, tc.given.oldSessID, tc.given.email) must.Equal(t, tc.exp.err, err) should.Equal(t, tc.exp.val, actual) @@ -5357,6 +5459,200 @@ func TestService_updateNumPaymentFailed(t *testing.T) { } } +func TestService_setOrderTrialDaysTx(t *testing.T) { + type tcGiven struct { + orepo *repository.MockOrder + oirepo *repository.MockOrderItem + scl stripeClient + + id uuid.UUID + req *model.SetTrialDaysRequest + now time.Time + } + + type testCase struct { + name string + given tcGiven + exp error + } + + tests := []testCase{ + { + name: "error_set_trial_days", + given: tcGiven{ + orepo: &repository.MockOrder{ + FnSetTrialDays: func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, ndays int64) error { + return model.Error("something_went_wrong") + }, + }, + oirepo: &repository.MockOrderItem{}, + scl: &xstripe.MockClient{}, + id: uuid.FromStringOrNil("facade00-0000-4000-a000-000000000000"), + req: &model.SetTrialDaysRequest{Email: "you@example.com", TrialDays: 7}, + now: time.Date(2024, time.November, 1, 0, 0, 0, 0, time.UTC), + }, + exp: model.Error("something_went_wrong"), + }, + + { + name: "error_get_order_full", + given: tcGiven{ + orepo: &repository.MockOrder{ + FnGet: func(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) { + return nil, model.Error("something_went_wrong") + }, + }, + oirepo: &repository.MockOrderItem{}, + scl: &xstripe.MockClient{}, + id: uuid.FromStringOrNil("facade00-0000-4000-a000-000000000000"), + req: &model.SetTrialDaysRequest{Email: "you@example.com", TrialDays: 7}, + now: time.Date(2024, time.November, 1, 0, 0, 0, 0, time.UTC), + }, + exp: model.Error("something_went_wrong"), + }, + + { + name: "success_no_new_session", + given: tcGiven{ + orepo: &repository.MockOrder{ + FnGet: func(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) { + result := &model.Order{ + ID: uuid.FromStringOrNil("facade00-0000-4000-a000-000000000000"), + Status: model.OrderStatusPaid, + ExpiresAt: ptrTo(time.Date(2024, time.December, 1, 0, 0, 0, 0, time.UTC)), + AllowedPaymentMethods: pq.StringArray{"stripe"}, + } + + return result, nil + }, + }, + oirepo: &repository.MockOrderItem{}, + scl: &xstripe.MockClient{}, + id: uuid.FromStringOrNil("facade00-0000-4000-a000-000000000000"), + req: &model.SetTrialDaysRequest{Email: "you@example.com", TrialDays: 7}, + now: time.Date(2024, time.November, 1, 0, 0, 0, 0, time.UTC), + }, + }, + + { + name: "error_no_session_id", + given: tcGiven{ + orepo: &repository.MockOrder{ + FnGet: func(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) { + result := &model.Order{ + ID: uuid.FromStringOrNil("facade00-0000-4000-a000-000000000000"), + Status: model.OrderStatusPending, + AllowedPaymentMethods: pq.StringArray{"stripe"}, + } + + return result, nil + }, + }, + oirepo: &repository.MockOrderItem{}, + scl: &xstripe.MockClient{}, + id: uuid.FromStringOrNil("facade00-0000-4000-a000-000000000000"), + req: &model.SetTrialDaysRequest{Email: "you@example.com", TrialDays: 7}, + now: time.Date(2024, time.November, 1, 0, 0, 0, 0, time.UTC), + }, + exp: model.ErrNoStripeCheckoutSessID, + }, + + { + name: "error_recreate_stripe_session", + given: tcGiven{ + orepo: &repository.MockOrder{ + FnGet: func(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) { + result := &model.Order{ + ID: uuid.FromStringOrNil("facade00-0000-4000-a000-000000000000"), + Status: model.OrderStatusPending, + AllowedPaymentMethods: pq.StringArray{"stripe"}, + Metadata: datastore.Metadata{ + "stripeCheckoutSessionId": "sess_id", + }, + } + + return result, nil + }, + }, + oirepo: &repository.MockOrderItem{}, + scl: &xstripe.MockClient{ + FnSession: func(ctx context.Context, id string, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + return nil, model.Error("something_went_wrong") + }, + }, + + id: uuid.FromStringOrNil("facade00-0000-4000-a000-000000000000"), + req: &model.SetTrialDaysRequest{Email: "you@example.com", TrialDays: 7}, + now: time.Date(2024, time.November, 1, 0, 0, 0, 0, time.UTC), + }, + exp: model.Error("something_went_wrong"), + }, + + { + name: "success", + given: tcGiven{ + orepo: &repository.MockOrder{ + FnGet: func(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (*model.Order, error) { + result := &model.Order{ + ID: uuid.FromStringOrNil("facade00-0000-4000-a000-000000000000"), + Status: model.OrderStatusPending, + AllowedPaymentMethods: pq.StringArray{"stripe"}, + Metadata: datastore.Metadata{ + "stripeCheckoutSessionId": "sess_id", + }, + TrialDays: ptrTo(int64(7)), + } + + return result, nil + }, + + FnAppendMetadata: func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key, val string) error { + if key == "stripeCheckoutSessionId" && val == "cs_test_id" { + return nil + } + + return model.Error("unexpected_append_metadata") + }, + }, + oirepo: &repository.MockOrderItem{}, + scl: &xstripe.MockClient{ + FnCreateSession: func(ctx context.Context, params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) { + if *params.SubscriptionData.TrialPeriodDays != 7 { + return nil, model.Error("unexpected_trial_period_days") + } + + result := &stripe.CheckoutSession{ + ID: "cs_test_id", + } + + return result, nil + }, + }, + id: uuid.FromStringOrNil("facade00-0000-4000-a000-000000000000"), + req: &model.SetTrialDaysRequest{Email: "you@example.com", TrialDays: 7}, + now: time.Date(2024, time.November, 1, 0, 0, 0, 0, time.UTC), + }, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + svc := &Service{ + orderRepo: tc.given.orepo, + orderItemRepo: tc.given.oirepo, + stripeCl: tc.given.scl, + } + + ctx := context.Background() + + actual := svc.setOrderTrialDaysTx(ctx, nil, tc.given.id, tc.given.req, tc.given.now) + should.ErrorIs(t, actual, tc.exp) + }) + } +} + func TestHandleRedeemFnError(t *testing.T) { type tcGiven struct { kind string diff --git a/services/skus/storage/repository/mock.go b/services/skus/storage/repository/mock.go index 7b8f2851e..de42cecad 100644 --- a/services/skus/storage/repository/mock.go +++ b/services/skus/storage/repository/mock.go @@ -19,6 +19,7 @@ type MockOrder struct { FnSetStatus func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, status string) error FnSetExpiresAt func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, when time.Time) error FnSetLastPaidAt func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, when time.Time) error + FnSetTrialDays func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, ndays int64) error FnAppendMetadata func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key, val string) error FnAppendMetadataInt func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key string, val int) error FnAppendMetadataInt64 func(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key string, val int64) error @@ -93,6 +94,14 @@ func (r *MockOrder) SetLastPaidAt(ctx context.Context, dbi sqlx.ExecerContext, i return r.FnSetLastPaidAt(ctx, dbi, id, when) } +func (r *MockOrder) SetTrialDays(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, ndays int64) error { + if r.FnSetTrialDays == nil { + return nil + } + + return r.FnSetTrialDays(ctx, dbi, id, ndays) +} + func (r *MockOrder) AppendMetadata(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key, val string) error { if r.FnAppendMetadata == nil { return nil diff --git a/services/skus/storage/repository/repository.go b/services/skus/storage/repository/repository.go index 0f3cc4abf..1f55bd150 100644 --- a/services/skus/storage/repository/repository.go +++ b/services/skus/storage/repository/repository.go @@ -92,22 +92,10 @@ func (r *Order) SetLastPaidAt(ctx context.Context, dbi sqlx.ExecerContext, id uu } // SetTrialDays sets trial_days to ndays. -func (r *Order) SetTrialDays(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID, ndays int64) (*model.Order, error) { - const q = `UPDATE orders - SET trial_days = $2, updated_at = now() - WHERE id = $1 - RETURNING id, created_at, currency, updated_at, total_price, merchant_id, location, status, allowed_payment_methods, metadata, valid_for, last_paid_at, expires_at, trial_days` - - result := &model.Order{} - if err := dbi.QueryRowxContext(ctx, q, id, ndays).StructScan(result); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, model.ErrOrderNotFound - } +func (r *Order) SetTrialDays(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, ndays int64) error { + const q = `UPDATE orders SET trial_days = $2, updated_at = now() WHERE id = $1` - return nil, err - } - - return result, nil + return r.execUpdate(ctx, dbi, q, id, ndays) } // SetStatus sets status to status. diff --git a/services/skus/storage/repository/repository_test.go b/services/skus/storage/repository/repository_test.go index d21a45e61..8f406d926 100644 --- a/services/skus/storage/repository/repository_test.go +++ b/services/skus/storage/repository/repository_test.go @@ -26,37 +26,88 @@ func TestOrder_SetTrialDays(t *testing.T) { dbi, err := setupDBI() must.Equal(t, nil, err) - t.Cleanup(func() { + defer func() { _, _ = dbi.Exec("TRUNCATE TABLE orders;") - }) + }() + + type tcGiven struct { + id uuid.UUID + ndays int64 + fnBefore func(ctx context.Context, dbi sqlx.ExecerContext) error + } type tcExpected struct { - ndays int64 - err error + num int64 + updateErr error + getErr error } type testCase struct { name string - given int64 + given tcGiven exp tcExpected } tests := []testCase{ { - name: "not_found", + name: "not_set_before", + given: tcGiven{ + id: uuid.FromStringOrNil("facade00-0000-4000-a000-000000000000"), + ndays: 1, + fnBefore: func(ctx context.Context, dbi sqlx.ExecerContext) error { + const q = `INSERT INTO orders ( + id, merchant_id, status, currency, total_price, created_at, updated_at + ) + VALUES ( + 'facade00-0000-4000-a000-000000000000', + 'brave.com', + 'paid', + 'USD', + 9.99, + '2024-01-01 00:00:01', + '2024-01-01 00:00:01' + );` + + _, err := dbi.ExecContext(ctx, q) + + return err + }, + }, + exp: tcExpected{ - err: model.ErrOrderNotFound, + num: 1, }, }, { - name: "no_changes", - }, + name: "overwrites_existing", + given: tcGiven{ + id: uuid.FromStringOrNil("facade00-0000-4000-a000-000000000000"), + ndays: 7, + fnBefore: func(ctx context.Context, dbi sqlx.ExecerContext) error { + const q = `INSERT INTO orders ( + id, merchant_id, status, currency, total_price, trial_days, created_at, updated_at + ) + VALUES ( + 'facade00-0000-4000-a000-000000000000', + 'brave.com', + 'paid', + 'USD', + 9.99, + 3, + '2024-01-01 00:00:01', + '2024-01-01 00:00:01' + );` - { - name: "updated_value", - given: 4, - exp: tcExpected{ndays: 4}, + _, err := dbi.ExecContext(ctx, q) + + return err + }, + }, + + exp: tcExpected{ + num: 7, + }, }, } @@ -73,23 +124,31 @@ func TestOrder_SetTrialDays(t *testing.T) { t.Cleanup(func() { _ = tx.Rollback() }) - order, err := createOrderForTest(ctx, tx, repo) - must.Equal(t, nil, err) + if tc.given.fnBefore != nil { + err := tc.given.fnBefore(ctx, tx) + must.NoError(t, err) + } - id := order.ID - if tc.exp.err == model.ErrOrderNotFound { - // Use any id for testing the not found case. - id = uuid.NamespaceDNS + { + err := repo.SetTrialDays(ctx, tx, tc.given.id, tc.given.ndays) + if err != nil { + t.Log(err) + } + must.Equal(t, tc.exp.updateErr, err) } - actual, err := repo.SetTrialDays(ctx, tx, id, tc.given) - must.Equal(t, true, errors.Is(err, tc.exp.err)) + if tc.exp.updateErr != nil { + return + } - if tc.exp.err != nil { + actual, err := repo.Get(ctx, tx, tc.given.id) + must.Equal(t, tc.exp.getErr, err) + + if tc.exp.getErr != nil { return } - should.Equal(t, tc.exp.ndays, actual.GetTrialDays()) + should.Equal(t, tc.exp.num, actual.GetTrialDays()) }) } }