From a6d3f8e9e5555a8e35fc7d411f65ab0a79a9afb0 Mon Sep 17 00:00:00 2001 From: David Barroso Date: Mon, 22 Apr 2024 14:41:41 +0200 Subject: [PATCH] fix: special treatment to custom claims for backwards compatibility #505 (#505) --- go/controller/custom_claims.go | 60 ++++++++++++++++++++++------- go/controller/custom_claims_test.go | 16 +++++++- 2 files changed, 61 insertions(+), 15 deletions(-) diff --git a/go/controller/custom_claims.go b/go/controller/custom_claims.go index 631c1686e..99af1c32f 100644 --- a/go/controller/custom_claims.go +++ b/go/controller/custom_claims.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "reflect" "sort" "strings" @@ -158,26 +159,57 @@ func (c *CustomClaims) GraphQLQuery() string { return c.graphqlQuery } +func (c *CustomClaims) getClaimsBackwardsCompatibility(data any, path []string) any { + if len(path) == 0 { + return data + } + + curPath := strings.TrimRight(path[0], "[]") + + value := reflect.ValueOf(data) + switch value.Kind() { //nolint:exhaustive + case reflect.Map: + for _, key := range value.MapKeys() { + if key.String() == curPath { + return c.getClaimsBackwardsCompatibility(value.MapIndex(key).Interface(), path[1:]) + } + } + case reflect.Slice: + got := make([]any, value.Len()) + for i := 0; i < value.Len(); i++ { + got[i] = c.getClaimsBackwardsCompatibility(value.Index(i).Interface(), path) + } + return got + default: + // we should not reach here + } + + return nil +} + func (c *CustomClaims) ExtractClaims(data any) (map[string]any, error) { claims := make(map[string]any) for name, j := range c.jsonPaths { - v, err := j.jpath.FindResults(data) - if err != nil { - claims[name] = nil - continue - } - var got any - if j.IsArrary() { - g := make([]any, len(v[0])) - for i, r := range v[0] { - g[i] = r.Interface() - } - got = g + if strings.HasSuffix(j.path, "[]") { + got = c.getClaimsBackwardsCompatibility(data, strings.Split(j.path, ".")) } else { - got = v[0][0].Interface() - } + v, err := j.jpath.FindResults(data) + if err != nil { + claims[name] = nil + continue + } + if j.IsArrary() { + g := make([]any, len(v[0])) + for i, r := range v[0] { + g[i] = r.Interface() + } + got = g + } else { + got = v[0][0].Interface() + } + } claims[name] = got } return claims, nil diff --git a/go/controller/custom_claims_test.go b/go/controller/custom_claims_test.go index d284ad541..c6b763fea 100644 --- a/go/controller/custom_claims_test.go +++ b/go/controller/custom_claims_test.go @@ -20,6 +20,16 @@ func TestCustomClaims(t *testing.T) { {"id": 2}, {"id": 3}, }, + "ln": []any{ + []any{ + map[string]any{"id": 1}, + map[string]any{"id": 2}, + }, + []any{ + map[string]any{"id": 3}, + map[string]any{"id": 4}, + }, + }, }, "metadata": map[string]any{ "m1": 1, @@ -40,7 +50,9 @@ func TestCustomClaims(t *testing.T) { "element": "m.l[2]", "array[]": "m.l[]", "array[*]": "m.l[*]", - "array[].ids": "m.lm[*].id", + "array[].ids": "m.lm[].id", + "array[*].ids": "m.lm[*].id", + "array.ids[]": "m.lm.id[]", "arrayOneElement[]": "m.l2[]", "metadata.m1": "metadata.m1", "nonexistent": "nonexistent.nonexistent", @@ -54,6 +66,8 @@ func TestCustomClaims(t *testing.T) { "array[]": []any{"a", "b", "c"}, "array[*]": []any{"a", "b", "c"}, "array[].ids": []any{1, 2, 3}, + "array[*].ids": []any{1, 2, 3}, + "array.ids[]": []any{1, 2, 3}, "metadata.m1": 1, "nonexistent": nil, },