Skip to content

Commit

Permalink
fix: send correct verification status in post-recovery hook
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Nov 26, 2024
1 parent e6d2d4d commit b1dfe9d
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 38 deletions.
30 changes: 28 additions & 2 deletions internal/testhelpers/selfservice_verification.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func NewVerifyAfterHookWebHookTarget(ctx context.Context, t *testing.T, conf *co

assert(t, msg)
}))

before := conf.GetProvider(ctx).Get(config.ViperKeySelfServiceVerificationAfter + ".hooks")
// A hook to ensure that the verification hook is called with the correct data
conf.MustSet(ctx, config.ViperKeySelfServiceVerificationAfter+".hooks", []map[string]interface{}{
{
Expand All @@ -52,7 +52,33 @@ func NewVerifyAfterHookWebHookTarget(ctx context.Context, t *testing.T, conf *co

t.Cleanup(ts.Close)
t.Cleanup(func() {
conf.MustSet(ctx, config.ViperKeySelfServiceVerificationAfter+".hooks", []map[string]interface{}{})
conf.MustSet(ctx, config.ViperKeySelfServiceVerificationAfter+".hooks", before)
})
}

func NewRecoveryAfterHookWebHookTarget(ctx context.Context, t *testing.T, conf *config.Config, assert func(t *testing.T, body []byte)) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
msg, err := io.ReadAll(r.Body)
require.NoError(t, err)

assert(t, msg)
}))

// A hook to ensure that the recovery hook is called with the correct data
conf.MustSet(ctx, config.ViperKeySelfServiceRecoveryAfter+".hooks", []map[string]interface{}{
{
"hook": "web_hook",
"config": map[string]interface{}{
"url": ts.URL,
"method": "POST",
"body": "base64://ZnVuY3Rpb24oY3R4KSB7CiAgICBpZGVudGl0eTogY3R4LmlkZW50aXR5Cn0=",
},
},
})

t.Cleanup(ts.Close)
t.Cleanup(func() {
conf.MustSet(ctx, config.ViperKeySelfServiceRecoveryAfter+".hooks", []map[string]interface{}{})
})
}

Expand Down
26 changes: 10 additions & 16 deletions selfservice/strategy/code/strategy_recovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"net/url"
"time"

"github.com/ory/x/pointerx"

"github.com/gofrs/uuid"
"github.com/pkg/errors"
"go.opentelemetry.io/otel/attribute"
Expand Down Expand Up @@ -430,22 +432,14 @@ func (s *Strategy) recoveryHandleFormSubmission(w http.ResponseWriter, r *http.R
}

func (s *Strategy) markRecoveryAddressVerified(w http.ResponseWriter, r *http.Request, f *recovery.Flow, id *identity.Identity, recoveryAddress *identity.RecoveryAddress) error {
var address *identity.VerifiableAddress
for idx := range id.VerifiableAddresses {
va := id.VerifiableAddresses[idx]
if va.Value == recoveryAddress.Value {
address = &va
break
}
}

if address != nil && !address.Verified { // can it be that the address is nil?
address.Verified = true
verifiedAt := sqlxx.NullTime(time.Now().UTC())
address.VerifiedAt = &verifiedAt
address.Status = identity.VerifiableAddressStatusCompleted
if err := s.deps.PrivilegedIdentityPool().UpdateVerifiableAddress(r.Context(), address); err != nil {
return s.HandleRecoveryError(w, r, f, nil, err)
for k, v := range id.VerifiableAddresses {
if v.Value == recoveryAddress.Value {
id.VerifiableAddresses[k].Verified = true
id.VerifiableAddresses[k].VerifiedAt = pointerx.Ptr(sqlxx.NullTime(time.Now().UTC()))
id.VerifiableAddresses[k].Status = identity.VerifiableAddressStatusCompleted
if err := s.deps.PrivilegedIdentityPool().UpdateVerifiableAddress(r.Context(), &id.VerifiableAddresses[k]); err != nil {
return s.HandleRecoveryError(w, r, f, nil, err)
}
}
}

Expand Down
14 changes: 13 additions & 1 deletion selfservice/strategy/code/strategy_recovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -253,6 +254,15 @@ func TestRecovery(t *testing.T) {
}

t.Run("type=browser", func(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
testhelpers.NewRecoveryAfterHookWebHookTarget(ctx, t, conf, func(t *testing.T, msg []byte) {
defer wg.Done()
assert.EqualValues(t, "[email protected]", gjson.GetBytes(msg, "identity.verifiable_addresses.0.value").String(), string(msg))
assert.EqualValues(t, true, gjson.GetBytes(msg, "identity.verifiable_addresses.0.verified").Bool(), string(msg))
assert.EqualValues(t, "completed", gjson.GetBytes(msg, "identity.verifiable_addresses.0.status").String(), string(msg))
})

client := testhelpers.NewClientWithCookies(t)
email := "[email protected]"
createIdentityToRecover(t, reg, email)
Expand All @@ -270,6 +280,8 @@ func TestRecovery(t *testing.T) {
require.NoError(t, res.Body.Close())
assert.Equal(t, "code_recovery", gjson.Get(body, "authentication_methods.0.method").String(), "%s", body)
assert.Equal(t, "aal1", gjson.Get(body, "authenticator_assurance_level").String(), "%s", body)

wg.Wait()
})

t.Run("type=spa", func(t *testing.T) {
Expand Down Expand Up @@ -990,7 +1002,7 @@ func TestRecovery(t *testing.T) {
body = submitRecoveryCode(t, cl, body, RecoveryClientTypeBrowser, recoveryCode, http.StatusSeeOther)
assert.NotEqual(t, gjson.Get(body, "id"), initialFlowId)

require.Len(t, cl.Jar.Cookies(urlx.ParseOrPanic(public.URL)), 1)
require.Len(t, cl.Jar.Cookies(urlx.ParseOrPanic(public.URL)), 1) // No session
cookies := spew.Sdump(cl.Jar.Cookies(urlx.ParseOrPanic(public.URL)))
assert.NotContains(t, cookies, "ory_kratos_session")
})
Expand Down
32 changes: 13 additions & 19 deletions selfservice/strategy/link/strategy_recovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"net/url"
"time"

"github.com/ory/x/pointerx"

"github.com/gofrs/uuid"
"github.com/julienschmidt/httprouter"
"github.com/pkg/errors"
Expand Down Expand Up @@ -313,16 +315,16 @@ func (s *Strategy) recoveryIssueSession(ctx context.Context, w http.ResponseWrit
return s.retryRecoveryFlowWithError(w, r, flow.TypeBrowser, err)
}

if err := s.d.RecoveryExecutor().PostRecoveryHook(w, r, f, sess); err != nil {
// Force load.
if err := s.d.PrivilegedIdentityPool().HydrateIdentityAssociations(ctx, sess.Identity, identity.ExpandEverything); err != nil {
return s.retryRecoveryFlowWithError(w, r, flow.TypeBrowser, err)
}

if err := s.d.SessionManager().UpsertAndIssueCookie(r.Context(), w, r, sess); err != nil {
return s.retryRecoveryFlowWithError(w, r, flow.TypeBrowser, err)
}

// Force load.
if err := s.d.PrivilegedIdentityPool().HydrateIdentityAssociations(ctx, sess.Identity, identity.ExpandEverything); err != nil {
if err := s.d.RecoveryExecutor().PostRecoveryHook(w, r, f, sess); err != nil {
return s.retryRecoveryFlowWithError(w, r, flow.TypeBrowser, err)
}

Expand Down Expand Up @@ -489,22 +491,14 @@ func (s *Strategy) recoveryHandleFormSubmission(w http.ResponseWriter, r *http.R
}

func (s *Strategy) markRecoveryAddressVerified(w http.ResponseWriter, r *http.Request, f *recovery.Flow, id *identity.Identity, recoveryAddress *identity.RecoveryAddress) error {
var address *identity.VerifiableAddress
for idx := range id.VerifiableAddresses {
va := id.VerifiableAddresses[idx]
if va.Value == recoveryAddress.Value {
address = &va
break
}
}

if address != nil && !address.Verified { // can it be that the address is nil?
address.Verified = true
verifiedAt := sqlxx.NullTime(time.Now().UTC())
address.VerifiedAt = &verifiedAt
address.Status = identity.VerifiableAddressStatusCompleted
if err := s.d.PrivilegedIdentityPool().UpdateVerifiableAddress(r.Context(), address); err != nil {
return s.HandleRecoveryError(w, r, f, nil, err)
for k, v := range id.VerifiableAddresses {
if v.Value == recoveryAddress.Value {
id.VerifiableAddresses[k].Verified = true
id.VerifiableAddresses[k].VerifiedAt = pointerx.Ptr(sqlxx.NullTime(time.Now().UTC()))
id.VerifiableAddresses[k].Status = identity.VerifiableAddressStatusCompleted
if err := s.d.PrivilegedIdentityPool().UpdateVerifiableAddress(r.Context(), &id.VerifiableAddresses[k]); err != nil {
return s.HandleRecoveryError(w, r, f, nil, err)
}
}
}

Expand Down
12 changes: 12 additions & 0 deletions selfservice/strategy/link/strategy_recovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -541,11 +542,22 @@ func TestRecovery(t *testing.T) {
}

t.Run("type=browser", func(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
testhelpers.NewVerifyAfterHookWebHookTarget(ctx, t, conf, func(t *testing.T, msg []byte) {
defer wg.Done()
assert.EqualValues(t, "[email protected]", gjson.GetBytes(msg, "identity.verifiable_addresses.0.value").String(), string(msg))
assert.EqualValues(t, true, gjson.GetBytes(msg, "identity.verifiable_addresses.0.verified").Bool(), string(msg))
assert.EqualValues(t, "completed", gjson.GetBytes(msg, "identity.verifiable_addresses.0.status").String(), string(msg))
})

email := "[email protected]"
createIdentityToRecover(t, reg, email)
check(t, expectSuccess(t, nil, false, false, func(v url.Values) {
v.Set("email", email)
}), email, "")

wg.Wait()
})

t.Run("description=should return browser to return url", func(t *testing.T) {
Expand Down

0 comments on commit b1dfe9d

Please sign in to comment.