From a88b850daf7d732452d1c82bb7266323acf417b1 Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Mon, 18 Sep 2023 17:17:17 +0800 Subject: [PATCH 01/29] draft Signed-off-by: Lixia (Sylvia) Lei --- registry/remote/auth/scope.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/registry/remote/auth/scope.go b/registry/remote/auth/scope.go index 24a0f898..4d417d3e 100644 --- a/registry/remote/auth/scope.go +++ b/registry/remote/auth/scope.go @@ -57,6 +57,8 @@ func ScopeRepository(repository string, actions ...string) string { // scopesContextKey is the context key for scopes. type scopesContextKey struct{} +type perRegistryScopesContextKey struct{} + // WithScopes returns a context with scopes added. Scopes are de-duplicated. // Scopes are used as hints for the auth client to fetch bearer tokens with // larger scopes. @@ -80,6 +82,13 @@ func WithScopes(ctx context.Context, scopes ...string) context.Context { return context.WithValue(ctx, scopesContextKey{}, scopes) } +func WithPerRegistryScopes(ctx context.Context, registry string, scopes ...string) context.Context { + scopes = CleanScopes(scopes) + regMap := make(map[string][]string, 0) + regMap[registry] = scopes + return context.WithValue(ctx, perRegistryScopesContextKey{}, regMap) +} + // AppendScopes appends additional scopes to the existing scopes in the context // and returns a new context. The resulted scopes are de-duplicated. // The append operation does modify the existing scope in the context passed in. @@ -98,6 +107,13 @@ func GetScopes(ctx context.Context) []string { return nil } +func GetPerRegistryScopes(ctx context.Context, registry string) []string { + if regMap, ok := ctx.Value(perRegistryScopesContextKey{}).(map[string][]string); ok { + return append([]string(nil), regMap[registry]...) + } + return nil +} + // CleanScopes merges and sort the actions in ascending order if the scopes have // the same resource type and name. The final scopes are sorted in ascending // order. In other words, the scopes passed in are de-duplicated and sorted. From 65c4acb13448160419dfd81d10c5541f998ce499 Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Mon, 18 Sep 2023 19:09:59 +0800 Subject: [PATCH 02/29] fix + test Signed-off-by: Lixia (Sylvia) Lei --- registry/remote/auth/scope.go | 17 ++++--- registry/remote/auth/scope_test.go | 79 ++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 6 deletions(-) diff --git a/registry/remote/auth/scope.go b/registry/remote/auth/scope.go index 4d417d3e..5f312ede 100644 --- a/registry/remote/auth/scope.go +++ b/registry/remote/auth/scope.go @@ -57,7 +57,7 @@ func ScopeRepository(repository string, actions ...string) string { // scopesContextKey is the context key for scopes. type scopesContextKey struct{} -type perRegistryScopesContextKey struct{} +type ScopesPerRegistryContextKey struct{} // WithScopes returns a context with scopes added. Scopes are de-duplicated. // Scopes are used as hints for the auth client to fetch bearer tokens with @@ -82,11 +82,16 @@ func WithScopes(ctx context.Context, scopes ...string) context.Context { return context.WithValue(ctx, scopesContextKey{}, scopes) } -func WithPerRegistryScopes(ctx context.Context, registry string, scopes ...string) context.Context { +func WithScopesPerRegistry(ctx context.Context, registry string, scopes ...string) context.Context { + var regMap map[string][]string + var ok bool + regMap, ok = ctx.Value(ScopesPerRegistryContextKey{}).(map[string][]string) + if !ok { + regMap = make(map[string][]string, 0) + } scopes = CleanScopes(scopes) - regMap := make(map[string][]string, 0) regMap[registry] = scopes - return context.WithValue(ctx, perRegistryScopesContextKey{}, regMap) + return context.WithValue(ctx, ScopesPerRegistryContextKey{}, regMap) } // AppendScopes appends additional scopes to the existing scopes in the context @@ -107,8 +112,8 @@ func GetScopes(ctx context.Context) []string { return nil } -func GetPerRegistryScopes(ctx context.Context, registry string) []string { - if regMap, ok := ctx.Value(perRegistryScopesContextKey{}).(map[string][]string); ok { +func GetScopesPerRegistry(ctx context.Context, registry string) []string { + if regMap, ok := ctx.Value(ScopesPerRegistryContextKey{}).(map[string][]string); ok { return append([]string(nil), regMap[registry]...) } return nil diff --git a/registry/remote/auth/scope_test.go b/registry/remote/auth/scope_test.go index ac41ad7b..b86fa56d 100644 --- a/registry/remote/auth/scope_test.go +++ b/registry/remote/auth/scope_test.go @@ -148,6 +148,85 @@ func TestWithScopes(t *testing.T) { } } +func TestWithScopesPerRegistry(t *testing.T) { + ctx := context.Background() + reg1 := "registry1.example.com" + reg2 := "registry2.example.com" + + // with single scope + want1 := []string{ + "repository:foo:pull", + } + want2 := []string{ + "repository:foo:push", + } + ctx = WithScopesPerRegistry(ctx, reg1, want1...) + ctx = WithScopesPerRegistry(ctx, reg2, want2...) + if got := GetScopesPerRegistry(ctx, reg1); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopes(WithScopes()) = %v, want %v", got, want1) + } + if got := GetScopesPerRegistry(ctx, reg2); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopes(WithScopes()) = %v, want %v", got, want2) + } + + // overwrite scopes + want1 = []string{ + "repository:bar:push", + } + want2 = []string{ + "repository:bar:pull", + } + ctx = WithScopesPerRegistry(ctx, reg1, want1...) + ctx = WithScopesPerRegistry(ctx, reg2, want2...) + if got := GetScopesPerRegistry(ctx, reg1); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopes(WithScopes()) = %v, want %v", got, want1) + } + if got := GetScopesPerRegistry(ctx, reg2); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopes(WithScopes()) = %v, want %v", got, want2) + } + + // overwrite scopes with de-duplication + scopes1 := []string{ + "repository:hello-world:push", + "repository:alpine:delete", + "repository:hello-world:pull", + "repository:alpine:delete", + } + want1 = []string{ + "repository:alpine:delete", + "repository:hello-world:pull,push", + } + scopes2 := []string{ + "repository:goodbye-world:push", + "repository:nginx:delete", + "repository:goodbye-world:pull", + "repository:nginx:delete", + } + want2 = []string{ + "repository:goodbye-world:pull,push", + "repository:nginx:delete", + } + ctx = WithScopesPerRegistry(ctx, reg1, scopes1...) + ctx = WithScopesPerRegistry(ctx, reg2, scopes2...) + if got := GetScopesPerRegistry(ctx, reg1); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopes(WithScopes()) = %v, want %v", got, want1) + } + if got := GetScopesPerRegistry(ctx, reg2); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopes(WithScopes()) = %v, want %v", got, want2) + } + + // clean scopes + var want []string + ctx = WithScopesPerRegistry(ctx, reg1, want...) + ctx = WithScopesPerRegistry(ctx, reg2, want...) + if got := GetScopesPerRegistry(ctx, reg1); !reflect.DeepEqual(got, want) { + t.Errorf("GetScopes(WithScopes()) = %v, want %v", got, want) + } + if got := GetScopesPerRegistry(ctx, reg2); !reflect.DeepEqual(got, want) { + t.Errorf("GetScopes(WithScopes()) = %v, want %v", got, want) + } +} + func TestAppendScopes(t *testing.T) { ctx := context.Background() From 1657b055a82e0ee7932af171e5ff49ba18457d23 Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Tue, 19 Sep 2023 16:08:20 +0800 Subject: [PATCH 03/29] unit tests Signed-off-by: Lixia (Sylvia) Lei --- registry/remote/auth/scope.go | 33 +++++---- registry/remote/auth/scope_test.go | 108 +++++++++++++++++++++++------ 2 files changed, 107 insertions(+), 34 deletions(-) diff --git a/registry/remote/auth/scope.go b/registry/remote/auth/scope.go index 5f312ede..7f4ce2bc 100644 --- a/registry/remote/auth/scope.go +++ b/registry/remote/auth/scope.go @@ -82,18 +82,6 @@ func WithScopes(ctx context.Context, scopes ...string) context.Context { return context.WithValue(ctx, scopesContextKey{}, scopes) } -func WithScopesPerRegistry(ctx context.Context, registry string, scopes ...string) context.Context { - var regMap map[string][]string - var ok bool - regMap, ok = ctx.Value(ScopesPerRegistryContextKey{}).(map[string][]string) - if !ok { - regMap = make(map[string][]string, 0) - } - scopes = CleanScopes(scopes) - regMap[registry] = scopes - return context.WithValue(ctx, ScopesPerRegistryContextKey{}, regMap) -} - // AppendScopes appends additional scopes to the existing scopes in the context // and returns a new context. The resulted scopes are de-duplicated. // The append operation does modify the existing scope in the context passed in. @@ -112,6 +100,27 @@ func GetScopes(ctx context.Context) []string { return nil } +func WithScopesPerRegistry(ctx context.Context, registry string, scopes ...string) context.Context { + var regMap map[string][]string + var ok bool + regMap, ok = ctx.Value(ScopesPerRegistryContextKey{}).(map[string][]string) + if !ok { + regMap = make(map[string][]string, 0) + } + scopes = CleanScopes(scopes) + regMap[registry] = scopes + return context.WithValue(ctx, ScopesPerRegistryContextKey{}, regMap) +} + +func AppendScopesPerRegistry(ctx context.Context, registry string, scopes ...string) context.Context { + if len(scopes) == 0 { + return ctx + } + + oldScopes := GetScopesPerRegistry(ctx, registry) + return WithScopesPerRegistry(ctx, registry, append(oldScopes, scopes...)...) +} + func GetScopesPerRegistry(ctx context.Context, registry string) []string { if regMap, ok := ctx.Value(ScopesPerRegistryContextKey{}).(map[string][]string); ok { return append([]string(nil), regMap[registry]...) diff --git a/registry/remote/auth/scope_test.go b/registry/remote/auth/scope_test.go index b86fa56d..2b79cd89 100644 --- a/registry/remote/auth/scope_test.go +++ b/registry/remote/auth/scope_test.go @@ -148,6 +148,42 @@ func TestWithScopes(t *testing.T) { } } +func TestAppendScopes(t *testing.T) { + ctx := context.Background() + + // append single scope + want := []string{ + "repository:foo:pull", + } + ctx = AppendScopes(ctx, want...) + if got := GetScopes(ctx); !reflect.DeepEqual(got, want) { + t.Errorf("GetScopes(AppendScopes()) = %v, want %v", got, want) + } + + // append scopes with de-duplication + scopes := []string{ + "repository:hello-world:push", + "repository:alpine:delete", + "repository:hello-world:pull", + "repository:alpine:delete", + } + want = []string{ + "repository:alpine:delete", + "repository:foo:pull", + "repository:hello-world:pull,push", + } + ctx = AppendScopes(ctx, scopes...) + if got := GetScopes(ctx); !reflect.DeepEqual(got, want) { + t.Errorf("GetScopes(AppendScopes()) = %v, want %v", got, want) + } + + // append empty scopes + ctx = AppendScopes(ctx) + if got := GetScopes(ctx); !reflect.DeepEqual(got, want) { + t.Errorf("GetScopes(AppendScopes()) = %v, want %v", got, want) + } +} + func TestWithScopesPerRegistry(t *testing.T) { ctx := context.Background() reg1 := "registry1.example.com" @@ -163,10 +199,10 @@ func TestWithScopesPerRegistry(t *testing.T) { ctx = WithScopesPerRegistry(ctx, reg1, want1...) ctx = WithScopesPerRegistry(ctx, reg2, want2...) if got := GetScopesPerRegistry(ctx, reg1); !reflect.DeepEqual(got, want1) { - t.Errorf("GetScopes(WithScopes()) = %v, want %v", got, want1) + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want1) } if got := GetScopesPerRegistry(ctx, reg2); !reflect.DeepEqual(got, want2) { - t.Errorf("GetScopes(WithScopes()) = %v, want %v", got, want2) + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want2) } // overwrite scopes @@ -179,10 +215,10 @@ func TestWithScopesPerRegistry(t *testing.T) { ctx = WithScopesPerRegistry(ctx, reg1, want1...) ctx = WithScopesPerRegistry(ctx, reg2, want2...) if got := GetScopesPerRegistry(ctx, reg1); !reflect.DeepEqual(got, want1) { - t.Errorf("GetScopes(WithScopes()) = %v, want %v", got, want1) + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want1) } if got := GetScopesPerRegistry(ctx, reg2); !reflect.DeepEqual(got, want2) { - t.Errorf("GetScopes(WithScopes()) = %v, want %v", got, want2) + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want2) } // overwrite scopes with de-duplication @@ -209,10 +245,10 @@ func TestWithScopesPerRegistry(t *testing.T) { ctx = WithScopesPerRegistry(ctx, reg1, scopes1...) ctx = WithScopesPerRegistry(ctx, reg2, scopes2...) if got := GetScopesPerRegistry(ctx, reg1); !reflect.DeepEqual(got, want1) { - t.Errorf("GetScopes(WithScopes()) = %v, want %v", got, want1) + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want1) } if got := GetScopesPerRegistry(ctx, reg2); !reflect.DeepEqual(got, want2) { - t.Errorf("GetScopes(WithScopes()) = %v, want %v", got, want2) + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want2) } // clean scopes @@ -220,46 +256,74 @@ func TestWithScopesPerRegistry(t *testing.T) { ctx = WithScopesPerRegistry(ctx, reg1, want...) ctx = WithScopesPerRegistry(ctx, reg2, want...) if got := GetScopesPerRegistry(ctx, reg1); !reflect.DeepEqual(got, want) { - t.Errorf("GetScopes(WithScopes()) = %v, want %v", got, want) + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want) } if got := GetScopesPerRegistry(ctx, reg2); !reflect.DeepEqual(got, want) { - t.Errorf("GetScopes(WithScopes()) = %v, want %v", got, want) + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want) } } -func TestAppendScopes(t *testing.T) { +func TestAppendScopesPerRegistry(t *testing.T) { ctx := context.Background() + reg1 := "registry1.example.com" + reg2 := "registry2.example.com" - // append single scope - want := []string{ + // with single scope + want1 := []string{ "repository:foo:pull", } - ctx = AppendScopes(ctx, want...) - if got := GetScopes(ctx); !reflect.DeepEqual(got, want) { - t.Errorf("GetScopes(AppendScopes()) = %v, want %v", got, want) + want2 := []string{ + "repository:foo:push", + } + ctx = AppendScopesPerRegistry(ctx, reg1, want1...) + ctx = AppendScopesPerRegistry(ctx, reg2, want2...) + if got := GetScopesPerRegistry(ctx, reg1); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want1) + } + if got := GetScopesPerRegistry(ctx, reg2); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want2) } // append scopes with de-duplication - scopes := []string{ + scopes1 := []string{ "repository:hello-world:push", "repository:alpine:delete", "repository:hello-world:pull", "repository:alpine:delete", } - want = []string{ + want1 = []string{ "repository:alpine:delete", "repository:foo:pull", "repository:hello-world:pull,push", } - ctx = AppendScopes(ctx, scopes...) - if got := GetScopes(ctx); !reflect.DeepEqual(got, want) { - t.Errorf("GetScopes(AppendScopes()) = %v, want %v", got, want) + scopes2 := []string{ + "repository:goodbye-world:push", + "repository:nginx:delete", + "repository:goodbye-world:pull", + "repository:nginx:delete", + } + want2 = []string{ + "repository:foo:push", + "repository:goodbye-world:pull,push", + "repository:nginx:delete", + } + ctx = AppendScopesPerRegistry(ctx, reg1, scopes1...) + ctx = AppendScopesPerRegistry(ctx, reg2, scopes2...) + if got := GetScopesPerRegistry(ctx, reg1); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want1) + } + if got := GetScopesPerRegistry(ctx, reg2); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want2) } // append empty scopes - ctx = AppendScopes(ctx) - if got := GetScopes(ctx); !reflect.DeepEqual(got, want) { - t.Errorf("GetScopes(AppendScopes()) = %v, want %v", got, want) + ctx = AppendScopesPerRegistry(ctx, reg1) + ctx = AppendScopesPerRegistry(ctx, reg2) + if got := GetScopesPerRegistry(ctx, reg1); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want1) + } + if got := GetScopesPerRegistry(ctx, reg2); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want2) } } From e260ef9d096f96a00cd7c41836df78a9f857f5e5 Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Tue, 19 Sep 2023 16:16:08 +0800 Subject: [PATCH 04/29] update internal Signed-off-by: Lixia (Sylvia) Lei --- internal/registryutil/auth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/registryutil/auth.go b/internal/registryutil/auth.go index 4a601f0c..04aa0b09 100644 --- a/internal/registryutil/auth.go +++ b/internal/registryutil/auth.go @@ -25,5 +25,5 @@ import ( // WithScopeHint adds a hinted scope to the context. func WithScopeHint(ctx context.Context, ref registry.Reference, actions ...string) context.Context { scope := auth.ScopeRepository(ref.Repository, actions...) - return auth.AppendScopes(ctx, scope) + return auth.AppendScopesPerRegistry(ctx, ref.Registry, scope) } From 6737ba20b4ffc4fb0c00ea973e7498d7182c2a7c Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Tue, 19 Sep 2023 17:23:38 +0800 Subject: [PATCH 05/29] update client Signed-off-by: Lixia (Sylvia) Lei --- registry/remote/auth/client.go | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/registry/remote/auth/client.go b/registry/remote/auth/client.go index b4b0261a..c84d411c 100644 --- a/registry/remote/auth/client.go +++ b/registry/remote/auth/client.go @@ -177,6 +177,7 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { // attempt cached auth token var attemptedKey string cache := c.cache() + // TODO: handle docker.io? registry := originalReq.Host scheme, err := cache.GetScheme(ctx, registry) if err == nil { @@ -187,7 +188,11 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { req.Header.Set("Authorization", "Basic "+token) } case SchemeBearer: - scopes := GetScopes(ctx) + scopes := GetScopesPerRegistry(ctx, registry) + if len(scopes) == 0 { + // fallback to get scopes + scopes = GetScopes(ctx) + } attemptedKey = strings.Join(scopes, " ") token, err := cache.GetToken(ctx, registry, SchemeBearer, attemptedKey) if err == nil { @@ -224,7 +229,11 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { resp.Body.Close() // merge hinted scopes with challenged scopes - scopes := GetScopes(ctx) + scopes := GetScopesPerRegistry(ctx, registry) + if len(scopes) == 0 { + // fallback to get scopes + scopes = GetScopes(ctx) + } if scope := params["scope"]; scope != "" { scopes = append(scopes, strings.Split(scope, " ")...) scopes = CleanScopes(scopes) From 4f80680df9180d8fd769aa8a8935f5bcbdde3e06 Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Wed, 20 Sep 2023 14:51:32 +0800 Subject: [PATCH 06/29] TODOs Signed-off-by: Lixia (Sylvia) Lei --- internal/registryutil/auth.go | 1 + registry/remote/auth/scope.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/registryutil/auth.go b/internal/registryutil/auth.go index 04aa0b09..0984da74 100644 --- a/internal/registryutil/auth.go +++ b/internal/registryutil/auth.go @@ -22,6 +22,7 @@ import ( "oras.land/oras-go/v2/registry/remote/auth" ) +// TODO: where to put this? // WithScopeHint adds a hinted scope to the context. func WithScopeHint(ctx context.Context, ref registry.Reference, actions ...string) context.Context { scope := auth.ScopeRepository(ref.Repository, actions...) diff --git a/registry/remote/auth/scope.go b/registry/remote/auth/scope.go index 7f4ce2bc..493fe114 100644 --- a/registry/remote/auth/scope.go +++ b/registry/remote/auth/scope.go @@ -105,7 +105,7 @@ func WithScopesPerRegistry(ctx context.Context, registry string, scopes ...strin var ok bool regMap, ok = ctx.Value(ScopesPerRegistryContextKey{}).(map[string][]string) if !ok { - regMap = make(map[string][]string, 0) + regMap = make(map[string][]string) } scopes = CleanScopes(scopes) regMap[registry] = scopes From 14547a7a8600790051d836e024b5293d50f16205 Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Wed, 20 Sep 2023 15:12:40 +0800 Subject: [PATCH 07/29] rename Signed-off-by: Lixia (Sylvia) Lei --- internal/registryutil/auth.go | 2 +- registry/remote/auth/client.go | 4 +-- registry/remote/auth/scope.go | 23 ++++++++---- registry/remote/auth/scope_test.go | 56 +++++++++++++++--------------- 4 files changed, 47 insertions(+), 38 deletions(-) diff --git a/internal/registryutil/auth.go b/internal/registryutil/auth.go index 0984da74..0e318aef 100644 --- a/internal/registryutil/auth.go +++ b/internal/registryutil/auth.go @@ -26,5 +26,5 @@ import ( // WithScopeHint adds a hinted scope to the context. func WithScopeHint(ctx context.Context, ref registry.Reference, actions ...string) context.Context { scope := auth.ScopeRepository(ref.Repository, actions...) - return auth.AppendScopesPerRegistry(ctx, ref.Registry, scope) + return auth.AppendScopesPerHost(ctx, ref.Registry, scope) } diff --git a/registry/remote/auth/client.go b/registry/remote/auth/client.go index c84d411c..0b9b8e85 100644 --- a/registry/remote/auth/client.go +++ b/registry/remote/auth/client.go @@ -188,7 +188,7 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { req.Header.Set("Authorization", "Basic "+token) } case SchemeBearer: - scopes := GetScopesPerRegistry(ctx, registry) + scopes := GetScopesPerHost(ctx, registry) if len(scopes) == 0 { // fallback to get scopes scopes = GetScopes(ctx) @@ -229,7 +229,7 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { resp.Body.Close() // merge hinted scopes with challenged scopes - scopes := GetScopesPerRegistry(ctx, registry) + scopes := GetScopesPerHost(ctx, registry) if len(scopes) == 0 { // fallback to get scopes scopes = GetScopes(ctx) diff --git a/registry/remote/auth/scope.go b/registry/remote/auth/scope.go index 493fe114..dd0ef82d 100644 --- a/registry/remote/auth/scope.go +++ b/registry/remote/auth/scope.go @@ -19,6 +19,8 @@ import ( "context" "sort" "strings" + + "oras.land/oras-go/v2/registry" ) // Actions used in scopes. @@ -100,7 +102,7 @@ func GetScopes(ctx context.Context) []string { return nil } -func WithScopesPerRegistry(ctx context.Context, registry string, scopes ...string) context.Context { +func WithScopesPerHost(ctx context.Context, host string, scopes ...string) context.Context { var regMap map[string][]string var ok bool regMap, ok = ctx.Value(ScopesPerRegistryContextKey{}).(map[string][]string) @@ -108,26 +110,33 @@ func WithScopesPerRegistry(ctx context.Context, registry string, scopes ...strin regMap = make(map[string][]string) } scopes = CleanScopes(scopes) - regMap[registry] = scopes + regMap[host] = scopes return context.WithValue(ctx, ScopesPerRegistryContextKey{}, regMap) } -func AppendScopesPerRegistry(ctx context.Context, registry string, scopes ...string) context.Context { +func AppendScopesPerHost(ctx context.Context, host string, scopes ...string) context.Context { if len(scopes) == 0 { return ctx } - oldScopes := GetScopesPerRegistry(ctx, registry) - return WithScopesPerRegistry(ctx, registry, append(oldScopes, scopes...)...) + oldScopes := GetScopesPerHost(ctx, host) + return WithScopesPerHost(ctx, host, append(oldScopes, scopes...)...) } -func GetScopesPerRegistry(ctx context.Context, registry string) []string { +func GetScopesPerHost(ctx context.Context, host string) []string { if regMap, ok := ctx.Value(ScopesPerRegistryContextKey{}).(map[string][]string); ok { - return append([]string(nil), regMap[registry]...) + return append([]string(nil), regMap[host]...) } return nil } +// TODO: where to put this? +// WithScopeHints adds a hinted scope to the context. +func WithScopeHints(ctx context.Context, ref registry.Reference, actions ...string) context.Context { + scope := ScopeRepository(ref.Repository, actions...) + return AppendScopesPerHost(ctx, ref.Host(), scope) +} + // CleanScopes merges and sort the actions in ascending order if the scopes have // the same resource type and name. The final scopes are sorted in ascending // order. In other words, the scopes passed in are de-duplicated and sorted. diff --git a/registry/remote/auth/scope_test.go b/registry/remote/auth/scope_test.go index 2b79cd89..087e7ec5 100644 --- a/registry/remote/auth/scope_test.go +++ b/registry/remote/auth/scope_test.go @@ -196,12 +196,12 @@ func TestWithScopesPerRegistry(t *testing.T) { want2 := []string{ "repository:foo:push", } - ctx = WithScopesPerRegistry(ctx, reg1, want1...) - ctx = WithScopesPerRegistry(ctx, reg2, want2...) - if got := GetScopesPerRegistry(ctx, reg1); !reflect.DeepEqual(got, want1) { + ctx = WithScopesPerHost(ctx, reg1, want1...) + ctx = WithScopesPerHost(ctx, reg2, want2...) + if got := GetScopesPerHost(ctx, reg1); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want1) } - if got := GetScopesPerRegistry(ctx, reg2); !reflect.DeepEqual(got, want2) { + if got := GetScopesPerHost(ctx, reg2); !reflect.DeepEqual(got, want2) { t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want2) } @@ -212,12 +212,12 @@ func TestWithScopesPerRegistry(t *testing.T) { want2 = []string{ "repository:bar:pull", } - ctx = WithScopesPerRegistry(ctx, reg1, want1...) - ctx = WithScopesPerRegistry(ctx, reg2, want2...) - if got := GetScopesPerRegistry(ctx, reg1); !reflect.DeepEqual(got, want1) { + ctx = WithScopesPerHost(ctx, reg1, want1...) + ctx = WithScopesPerHost(ctx, reg2, want2...) + if got := GetScopesPerHost(ctx, reg1); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want1) } - if got := GetScopesPerRegistry(ctx, reg2); !reflect.DeepEqual(got, want2) { + if got := GetScopesPerHost(ctx, reg2); !reflect.DeepEqual(got, want2) { t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want2) } @@ -242,23 +242,23 @@ func TestWithScopesPerRegistry(t *testing.T) { "repository:goodbye-world:pull,push", "repository:nginx:delete", } - ctx = WithScopesPerRegistry(ctx, reg1, scopes1...) - ctx = WithScopesPerRegistry(ctx, reg2, scopes2...) - if got := GetScopesPerRegistry(ctx, reg1); !reflect.DeepEqual(got, want1) { + ctx = WithScopesPerHost(ctx, reg1, scopes1...) + ctx = WithScopesPerHost(ctx, reg2, scopes2...) + if got := GetScopesPerHost(ctx, reg1); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want1) } - if got := GetScopesPerRegistry(ctx, reg2); !reflect.DeepEqual(got, want2) { + if got := GetScopesPerHost(ctx, reg2); !reflect.DeepEqual(got, want2) { t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want2) } // clean scopes var want []string - ctx = WithScopesPerRegistry(ctx, reg1, want...) - ctx = WithScopesPerRegistry(ctx, reg2, want...) - if got := GetScopesPerRegistry(ctx, reg1); !reflect.DeepEqual(got, want) { + ctx = WithScopesPerHost(ctx, reg1, want...) + ctx = WithScopesPerHost(ctx, reg2, want...) + if got := GetScopesPerHost(ctx, reg1); !reflect.DeepEqual(got, want) { t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want) } - if got := GetScopesPerRegistry(ctx, reg2); !reflect.DeepEqual(got, want) { + if got := GetScopesPerHost(ctx, reg2); !reflect.DeepEqual(got, want) { t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want) } } @@ -275,12 +275,12 @@ func TestAppendScopesPerRegistry(t *testing.T) { want2 := []string{ "repository:foo:push", } - ctx = AppendScopesPerRegistry(ctx, reg1, want1...) - ctx = AppendScopesPerRegistry(ctx, reg2, want2...) - if got := GetScopesPerRegistry(ctx, reg1); !reflect.DeepEqual(got, want1) { + ctx = AppendScopesPerHost(ctx, reg1, want1...) + ctx = AppendScopesPerHost(ctx, reg2, want2...) + if got := GetScopesPerHost(ctx, reg1); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want1) } - if got := GetScopesPerRegistry(ctx, reg2); !reflect.DeepEqual(got, want2) { + if got := GetScopesPerHost(ctx, reg2); !reflect.DeepEqual(got, want2) { t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want2) } @@ -307,22 +307,22 @@ func TestAppendScopesPerRegistry(t *testing.T) { "repository:goodbye-world:pull,push", "repository:nginx:delete", } - ctx = AppendScopesPerRegistry(ctx, reg1, scopes1...) - ctx = AppendScopesPerRegistry(ctx, reg2, scopes2...) - if got := GetScopesPerRegistry(ctx, reg1); !reflect.DeepEqual(got, want1) { + ctx = AppendScopesPerHost(ctx, reg1, scopes1...) + ctx = AppendScopesPerHost(ctx, reg2, scopes2...) + if got := GetScopesPerHost(ctx, reg1); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want1) } - if got := GetScopesPerRegistry(ctx, reg2); !reflect.DeepEqual(got, want2) { + if got := GetScopesPerHost(ctx, reg2); !reflect.DeepEqual(got, want2) { t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want2) } // append empty scopes - ctx = AppendScopesPerRegistry(ctx, reg1) - ctx = AppendScopesPerRegistry(ctx, reg2) - if got := GetScopesPerRegistry(ctx, reg1); !reflect.DeepEqual(got, want1) { + ctx = AppendScopesPerHost(ctx, reg1) + ctx = AppendScopesPerHost(ctx, reg2) + if got := GetScopesPerHost(ctx, reg1); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want1) } - if got := GetScopesPerRegistry(ctx, reg2); !reflect.DeepEqual(got, want2) { + if got := GetScopesPerHost(ctx, reg2); !reflect.DeepEqual(got, want2) { t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want2) } } From 57d018188e634801e0e59740a83859d4ece20200 Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Wed, 20 Sep 2023 15:38:36 +0800 Subject: [PATCH 08/29] rename Signed-off-by: Lixia (Sylvia) Lei --- content.go | 5 ++--- internal/registryutil/auth.go | 30 --------------------------- registry/remote/auth/client.go | 37 ++++++++++++++-------------------- registry/remote/auth/scope.go | 8 ++++---- registry/remote/registry.go | 2 +- registry/remote/repository.go | 31 ++++++++++++++-------------- 6 files changed, 37 insertions(+), 76 deletions(-) delete mode 100644 internal/registryutil/auth.go diff --git a/content.go b/content.go index 53eb6c75..2d1ff3f7 100644 --- a/content.go +++ b/content.go @@ -29,7 +29,6 @@ import ( "oras.land/oras-go/v2/internal/docker" "oras.land/oras-go/v2/internal/interfaces" "oras.land/oras-go/v2/internal/platform" - "oras.land/oras-go/v2/internal/registryutil" "oras.land/oras-go/v2/internal/syncutil" "oras.land/oras-go/v2/registry" "oras.land/oras-go/v2/registry/remote/auth" @@ -91,7 +90,7 @@ func TagN(ctx context.Context, target Target, srcReference string, dstReferences if err != nil { return ocispec.Descriptor{}, err } - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull, auth.ActionPush) } desc, contentBytes, err := FetchBytes(ctx, target, srcReference, FetchBytesOptions{ @@ -149,7 +148,7 @@ func Tag(ctx context.Context, target Target, src, dst string) (ocispec.Descripto if err != nil { return ocispec.Descriptor{}, err } - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull, auth.ActionPush) } desc, rc, err := refFetcher.FetchReference(ctx, src) if err != nil { diff --git a/internal/registryutil/auth.go b/internal/registryutil/auth.go deleted file mode 100644 index 0e318aef..00000000 --- a/internal/registryutil/auth.go +++ /dev/null @@ -1,30 +0,0 @@ -/* -Copyright The ORAS Authors. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package registryutil - -import ( - "context" - - "oras.land/oras-go/v2/registry" - "oras.land/oras-go/v2/registry/remote/auth" -) - -// TODO: where to put this? -// WithScopeHint adds a hinted scope to the context. -func WithScopeHint(ctx context.Context, ref registry.Reference, actions ...string) context.Context { - scope := auth.ScopeRepository(ref.Repository, actions...) - return auth.AppendScopesPerHost(ctx, ref.Registry, scope) -} diff --git a/registry/remote/auth/client.go b/registry/remote/auth/client.go index 0b9b8e85..22e7b8aa 100644 --- a/registry/remote/auth/client.go +++ b/registry/remote/auth/client.go @@ -177,24 +177,20 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { // attempt cached auth token var attemptedKey string cache := c.cache() - // TODO: handle docker.io? - registry := originalReq.Host - scheme, err := cache.GetScheme(ctx, registry) + host := originalReq.Host + scheme, err := cache.GetScheme(ctx, host) if err == nil { switch scheme { case SchemeBasic: - token, err := cache.GetToken(ctx, registry, SchemeBasic, "") + token, err := cache.GetToken(ctx, host, SchemeBasic, "") if err == nil { req.Header.Set("Authorization", "Basic "+token) } case SchemeBearer: - scopes := GetScopesPerHost(ctx, registry) - if len(scopes) == 0 { - // fallback to get scopes - scopes = GetScopes(ctx) - } + // merge per-host scopes with generic scopes + scopes := append(GetScopesPerHost(ctx, host), GetScopes(ctx)...) attemptedKey = strings.Join(scopes, " ") - token, err := cache.GetToken(ctx, registry, SchemeBearer, attemptedKey) + token, err := cache.GetToken(ctx, host, SchemeBearer, attemptedKey) if err == nil { req.Header.Set("Authorization", "Bearer "+token) } @@ -216,8 +212,8 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { case SchemeBasic: resp.Body.Close() - token, err := cache.Set(ctx, registry, SchemeBasic, "", func(ctx context.Context) (string, error) { - return c.fetchBasicAuth(ctx, registry) + token, err := cache.Set(ctx, host, SchemeBasic, "", func(ctx context.Context) (string, error) { + return c.fetchBasicAuth(ctx, host) }) if err != nil { return nil, fmt.Errorf("%s %q: %w", resp.Request.Method, resp.Request.URL, err) @@ -228,21 +224,18 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { case SchemeBearer: resp.Body.Close() + // merge per-host scopes with generic scopes + scopes := append(GetScopesPerHost(ctx, host), GetScopes(ctx)...) // merge hinted scopes with challenged scopes - scopes := GetScopesPerHost(ctx, registry) - if len(scopes) == 0 { - // fallback to get scopes - scopes = GetScopes(ctx) - } - if scope := params["scope"]; scope != "" { - scopes = append(scopes, strings.Split(scope, " ")...) + if paramScope := params["scope"]; paramScope != "" { + scopes = append(scopes, strings.Split(paramScope, " ")...) scopes = CleanScopes(scopes) } key := strings.Join(scopes, " ") // attempt the cache again if there is a scope change if key != attemptedKey { - if token, err := cache.GetToken(ctx, registry, SchemeBearer, key); err == nil { + if token, err := cache.GetToken(ctx, host, SchemeBearer, key); err == nil { req = originalReq.Clone(ctx) req.Header.Set("Authorization", "Bearer "+token) if err := rewindRequestBody(req); err != nil { @@ -263,8 +256,8 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { // attempt with credentials realm := params["realm"] service := params["service"] - token, err := cache.Set(ctx, registry, SchemeBearer, key, func(ctx context.Context) (string, error) { - return c.fetchBearerToken(ctx, registry, realm, service, scopes) + token, err := cache.Set(ctx, host, SchemeBearer, key, func(ctx context.Context) (string, error) { + return c.fetchBearerToken(ctx, host, realm, service, scopes) }) if err != nil { return nil, fmt.Errorf("%s %q: %w", resp.Request.Method, resp.Request.URL, err) diff --git a/registry/remote/auth/scope.go b/registry/remote/auth/scope.go index dd0ef82d..776a18fe 100644 --- a/registry/remote/auth/scope.go +++ b/registry/remote/auth/scope.go @@ -59,7 +59,7 @@ func ScopeRepository(repository string, actions ...string) string { // scopesContextKey is the context key for scopes. type scopesContextKey struct{} -type ScopesPerRegistryContextKey struct{} +type scopesPerHostContextKey struct{} // WithScopes returns a context with scopes added. Scopes are de-duplicated. // Scopes are used as hints for the auth client to fetch bearer tokens with @@ -105,13 +105,13 @@ func GetScopes(ctx context.Context) []string { func WithScopesPerHost(ctx context.Context, host string, scopes ...string) context.Context { var regMap map[string][]string var ok bool - regMap, ok = ctx.Value(ScopesPerRegistryContextKey{}).(map[string][]string) + regMap, ok = ctx.Value(scopesPerHostContextKey{}).(map[string][]string) if !ok { regMap = make(map[string][]string) } scopes = CleanScopes(scopes) regMap[host] = scopes - return context.WithValue(ctx, ScopesPerRegistryContextKey{}, regMap) + return context.WithValue(ctx, scopesPerHostContextKey{}, regMap) } func AppendScopesPerHost(ctx context.Context, host string, scopes ...string) context.Context { @@ -124,7 +124,7 @@ func AppendScopesPerHost(ctx context.Context, host string, scopes ...string) con } func GetScopesPerHost(ctx context.Context, host string) []string { - if regMap, ok := ctx.Value(ScopesPerRegistryContextKey{}).(map[string][]string); ok { + if regMap, ok := ctx.Value(scopesPerHostContextKey{}).(map[string][]string); ok { return append([]string(nil), regMap[host]...) } return nil diff --git a/registry/remote/registry.go b/registry/remote/registry.go index 8ae538d9..f0c417e4 100644 --- a/registry/remote/registry.go +++ b/registry/remote/registry.go @@ -127,7 +127,7 @@ func (r *Registry) Ping(ctx context.Context) error { // // Reference: https://docs.docker.com/registry/spec/api/#catalog func (r *Registry) Repositories(ctx context.Context, last string, fn func(repos []string) error) error { - ctx = auth.AppendScopes(ctx, auth.ScopeRegistryCatalog) + ctx = auth.AppendScopesPerHost(ctx, r.Reference.Host(), auth.ScopeRegistryCatalog) url := buildRegistryCatalogURL(r.PlainHTTP, r.Reference) var err error for err == nil { diff --git a/registry/remote/repository.go b/registry/remote/repository.go index fc4f6bf6..65eea1e1 100644 --- a/registry/remote/repository.go +++ b/registry/remote/repository.go @@ -37,7 +37,6 @@ import ( "oras.land/oras-go/v2/internal/cas" "oras.land/oras-go/v2/internal/httputil" "oras.land/oras-go/v2/internal/ioutil" - "oras.land/oras-go/v2/internal/registryutil" "oras.land/oras-go/v2/internal/slices" "oras.land/oras-go/v2/internal/spec" "oras.land/oras-go/v2/internal/syncutil" @@ -392,7 +391,7 @@ func (r *Repository) ParseReference(reference string) (registry.Reference, error // - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#content-discovery // - https://docs.docker.com/registry/spec/api/#tags func (r *Repository) Tags(ctx context.Context, last string, fn func(tags []string) error) error { - ctx = registryutil.WithScopeHint(ctx, r.Reference, auth.ActionPull) + ctx = auth.WithScopeHints(ctx, r.Reference, auth.ActionPull) url := buildRepositoryTagListURL(r.PlainHTTP, r.Reference) var err error for err == nil { @@ -509,7 +508,7 @@ func (r *Repository) Referrers(ctx context.Context, desc ocispec.Descriptor, art func (r *Repository) referrersByAPI(ctx context.Context, desc ocispec.Descriptor, artifactType string, fn func(referrers []ocispec.Descriptor) error) error { ref := r.Reference ref.Reference = desc.Digest.String() - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull) url := buildReferrersURL(r.PlainHTTP, ref, artifactType) var err error @@ -643,7 +642,7 @@ func (r *Repository) pingReferrers(ctx context.Context) (bool, error) { ref := r.Reference ref.Reference = zeroDigest - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull) url := buildReferrersURL(r.PlainHTTP, ref, "") req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) @@ -677,7 +676,7 @@ func (r *Repository) pingReferrers(ctx context.Context) (bool, error) { func (r *Repository) delete(ctx context.Context, target ocispec.Descriptor, isManifest bool) error { ref := r.Reference ref.Reference = target.Digest.String() - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionDelete) + ctx = auth.WithScopeHints(ctx, ref, auth.ActionDelete) buildURL := buildRepositoryBlobURL if isManifest { buildURL = buildRepositoryManifestURL @@ -713,7 +712,7 @@ type blobStore struct { func (s *blobStore) Fetch(ctx context.Context, target ocispec.Descriptor) (rc io.ReadCloser, err error) { ref := s.repo.Reference ref.Reference = target.Digest.String() - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull) url := buildRepositoryBlobURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -755,12 +754,12 @@ func (s *blobStore) Fetch(ctx context.Context, target ocispec.Descriptor) (rc io func (s *blobStore) Mount(ctx context.Context, desc ocispec.Descriptor, fromRepo string, getContent func() (io.ReadCloser, error)) error { // pushing usually requires both pull and push actions. // Reference: https://github.com/distribution/distribution/blob/v2.7.1/registry/handlers/app.go#L921-L930 - ctx = registryutil.WithScopeHint(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) + ctx = auth.WithScopeHints(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) // We also need pull access to the source repo. fromRef := s.repo.Reference fromRef.Repository = fromRepo - ctx = registryutil.WithScopeHint(ctx, fromRef, auth.ActionPull) + ctx = auth.WithScopeHints(ctx, fromRef, auth.ActionPull) url := buildRepositoryBlobMountURL(s.repo.PlainHTTP, s.repo.Reference, desc.Digest, fromRepo) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) @@ -833,7 +832,7 @@ func (s *blobStore) Push(ctx context.Context, expected ocispec.Descriptor, conte // start an upload // pushing usually requires both pull and push actions. // Reference: https://github.com/distribution/distribution/blob/v2.7.1/registry/handlers/app.go#L921-L930 - ctx = registryutil.WithScopeHint(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) + ctx = auth.WithScopeHints(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) url := buildRepositoryBlobUploadURL(s.repo.PlainHTTP, s.repo.Reference) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) if err != nil { @@ -934,7 +933,7 @@ func (s *blobStore) Resolve(ctx context.Context, reference string) (ocispec.Desc if err != nil { return ocispec.Descriptor{}, err } - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull) url := buildRepositoryBlobURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) if err != nil { @@ -969,7 +968,7 @@ func (s *blobStore) FetchReference(ctx context.Context, reference string) (desc return ocispec.Descriptor{}, nil, err } - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull) url := buildRepositoryBlobURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -1044,7 +1043,7 @@ type manifestStore struct { func (s *manifestStore) Fetch(ctx context.Context, target ocispec.Descriptor) (rc io.ReadCloser, err error) { ref := s.repo.Reference ref.Reference = target.Digest.String() - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -1170,7 +1169,7 @@ func (s *manifestStore) Resolve(ctx context.Context, reference string) (ocispec. if err != nil { return ocispec.Descriptor{}, err } - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) if err != nil { @@ -1202,7 +1201,7 @@ func (s *manifestStore) FetchReference(ctx context.Context, reference string) (d return ocispec.Descriptor{}, nil, err } - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -1245,7 +1244,7 @@ func (s *manifestStore) Tag(ctx context.Context, desc ocispec.Descriptor, refere return err } - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull, auth.ActionPush) rc, err := s.Fetch(ctx, desc) if err != nil { return err @@ -1270,7 +1269,7 @@ func (s *manifestStore) push(ctx context.Context, expected ocispec.Descriptor, c ref.Reference = reference // pushing usually requires both pull and push actions. // Reference: https://github.com/distribution/distribution/blob/v2.7.1/registry/handlers/app.go#L921-L930 - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull, auth.ActionPush) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) // unwrap the content for optimizations of built-in types. body := ioutil.UnwrapNopCloser(content) From fb139e5f8574e7d1a8d08ef927e13e5ce32e0fca Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Wed, 20 Sep 2023 17:24:14 +0800 Subject: [PATCH 09/29] move + test Signed-off-by: Lixia (Sylvia) Lei --- registry/remote/auth/scope.go | 26 ++++++----- registry/remote/auth/scope_test.go | 70 +++++++++++++++++++++++++++++- 2 files changed, 82 insertions(+), 14 deletions(-) diff --git a/registry/remote/auth/scope.go b/registry/remote/auth/scope.go index 776a18fe..4d24b6e0 100644 --- a/registry/remote/auth/scope.go +++ b/registry/remote/auth/scope.go @@ -20,6 +20,7 @@ import ( "sort" "strings" + "oras.land/oras-go/v2/internal/slices" "oras.land/oras-go/v2/registry" ) @@ -56,11 +57,18 @@ func ScopeRepository(repository string, actions ...string) string { }, ":") } +// WithScopeHints adds a hinted scope to the context. +func WithScopeHints(ctx context.Context, ref registry.Reference, actions ...string) context.Context { + if len(actions) == 0 { + return ctx + } + scope := ScopeRepository(ref.Repository, actions...) + return AppendScopesPerHost(ctx, ref.Host(), scope) +} + // scopesContextKey is the context key for scopes. type scopesContextKey struct{} -type scopesPerHostContextKey struct{} - // WithScopes returns a context with scopes added. Scopes are de-duplicated. // Scopes are used as hints for the auth client to fetch bearer tokens with // larger scopes. @@ -97,11 +105,13 @@ func AppendScopes(ctx context.Context, scopes ...string) context.Context { // GetScopes returns the scopes in the context. func GetScopes(ctx context.Context) []string { if scopes, ok := ctx.Value(scopesContextKey{}).([]string); ok { - return append([]string(nil), scopes...) + return slices.Clone(scopes) } return nil } +type scopesPerHostContextKey struct{} + func WithScopesPerHost(ctx context.Context, host string, scopes ...string) context.Context { var regMap map[string][]string var ok bool @@ -118,25 +128,17 @@ func AppendScopesPerHost(ctx context.Context, host string, scopes ...string) con if len(scopes) == 0 { return ctx } - oldScopes := GetScopesPerHost(ctx, host) return WithScopesPerHost(ctx, host, append(oldScopes, scopes...)...) } func GetScopesPerHost(ctx context.Context, host string) []string { if regMap, ok := ctx.Value(scopesPerHostContextKey{}).(map[string][]string); ok { - return append([]string(nil), regMap[host]...) + return slices.Clone(regMap[host]) } return nil } -// TODO: where to put this? -// WithScopeHints adds a hinted scope to the context. -func WithScopeHints(ctx context.Context, ref registry.Reference, actions ...string) context.Context { - scope := ScopeRepository(ref.Repository, actions...) - return AppendScopesPerHost(ctx, ref.Host(), scope) -} - // CleanScopes merges and sort the actions in ascending order if the scopes have // the same resource type and name. The final scopes are sorted in ascending // order. In other words, the scopes passed in are de-duplicated and sorted. diff --git a/registry/remote/auth/scope_test.go b/registry/remote/auth/scope_test.go index 087e7ec5..97389fe0 100644 --- a/registry/remote/auth/scope_test.go +++ b/registry/remote/auth/scope_test.go @@ -19,6 +19,8 @@ import ( "context" "reflect" "testing" + + "oras.land/oras-go/v2/registry" ) func TestScopeRepository(t *testing.T) { @@ -103,6 +105,70 @@ func TestScopeRepository(t *testing.T) { } } +func TestWithScopeHints(t *testing.T) { + ctx := context.Background() + ref1, err := registry.ParseReference("registry.example.com/foo") + if err != nil { + t.Fatal("registry.ParseReference() error =", err) + } + ref2, err := registry.ParseReference("docker.io/foo") + if err != nil { + t.Fatal("registry.ParseReference() error =", err) + } + + // with single scope + want1 := []string{ + "repository:foo:pull", + } + want2 := []string{ + "repository:foo:push", + } + ctx = WithScopeHints(ctx, ref1, ActionPull) + ctx = WithScopeHints(ctx, ref2, ActionPush) + if got := GetScopesPerHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want1) + } + if got := GetScopesPerHost(ctx, ref2.Host()); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want2) + } + + // with duplicated scopes + scopes1 := []string{ + ActionDelete, + ActionDelete, + ActionPull, + } + want1 = []string{ + "repository:foo:delete,pull", + } + scopes2 := []string{ + ActionPush, + ActionPush, + ActionDelete, + } + want2 = []string{ + "repository:foo:delete,push", + } + ctx = WithScopeHints(ctx, ref1, scopes1...) + ctx = WithScopeHints(ctx, ref2, scopes2...) + if got := GetScopesPerHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want1) + } + if got := GetScopesPerHost(ctx, ref2.Host()); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want2) + } + + // append empty scopes + ctx = WithScopeHints(ctx, ref1) + ctx = WithScopeHints(ctx, ref2) + if got := GetScopesPerHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want1) + } + if got := GetScopesPerHost(ctx, ref2.Host()); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want2) + } +} + func TestWithScopes(t *testing.T) { ctx := context.Background() @@ -184,7 +250,7 @@ func TestAppendScopes(t *testing.T) { } } -func TestWithScopesPerRegistry(t *testing.T) { +func TestWithScopesPerHost(t *testing.T) { ctx := context.Background() reg1 := "registry1.example.com" reg2 := "registry2.example.com" @@ -263,7 +329,7 @@ func TestWithScopesPerRegistry(t *testing.T) { } } -func TestAppendScopesPerRegistry(t *testing.T) { +func TestAppendScopesPerHost(t *testing.T) { ctx := context.Background() reg1 := "registry1.example.com" reg2 := "registry2.example.com" From bced6518613108a031b1ee1fa041869246fdf79b Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Wed, 20 Sep 2023 18:38:40 +0800 Subject: [PATCH 10/29] WithRepositoryScopes Signed-off-by: Lixia (Sylvia) Lei --- content.go | 4 ++-- registry/remote/auth/scope.go | 4 ++-- registry/remote/auth/scope_test.go | 12 ++++++------ registry/remote/repository.go | 30 +++++++++++++++--------------- 4 files changed, 25 insertions(+), 25 deletions(-) diff --git a/content.go b/content.go index 2d1ff3f7..40647b8e 100644 --- a/content.go +++ b/content.go @@ -90,7 +90,7 @@ func TagN(ctx context.Context, target Target, srcReference string, dstReferences if err != nil { return ocispec.Descriptor{}, err } - ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull, auth.ActionPush) } desc, contentBytes, err := FetchBytes(ctx, target, srcReference, FetchBytesOptions{ @@ -148,7 +148,7 @@ func Tag(ctx context.Context, target Target, src, dst string) (ocispec.Descripto if err != nil { return ocispec.Descriptor{}, err } - ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull, auth.ActionPush) } desc, rc, err := refFetcher.FetchReference(ctx, src) if err != nil { diff --git a/registry/remote/auth/scope.go b/registry/remote/auth/scope.go index 4d24b6e0..859e22b6 100644 --- a/registry/remote/auth/scope.go +++ b/registry/remote/auth/scope.go @@ -57,8 +57,8 @@ func ScopeRepository(repository string, actions ...string) string { }, ":") } -// WithScopeHints adds a hinted scope to the context. -func WithScopeHints(ctx context.Context, ref registry.Reference, actions ...string) context.Context { +// WithRepositoryScopes adds a hinted scope to the context. +func WithRepositoryScopes(ctx context.Context, ref registry.Reference, actions ...string) context.Context { if len(actions) == 0 { return ctx } diff --git a/registry/remote/auth/scope_test.go b/registry/remote/auth/scope_test.go index 97389fe0..75106bc6 100644 --- a/registry/remote/auth/scope_test.go +++ b/registry/remote/auth/scope_test.go @@ -123,8 +123,8 @@ func TestWithScopeHints(t *testing.T) { want2 := []string{ "repository:foo:push", } - ctx = WithScopeHints(ctx, ref1, ActionPull) - ctx = WithScopeHints(ctx, ref2, ActionPush) + ctx = WithRepositoryScopes(ctx, ref1, ActionPull) + ctx = WithRepositoryScopes(ctx, ref2, ActionPush) if got := GetScopesPerHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want1) } @@ -149,8 +149,8 @@ func TestWithScopeHints(t *testing.T) { want2 = []string{ "repository:foo:delete,push", } - ctx = WithScopeHints(ctx, ref1, scopes1...) - ctx = WithScopeHints(ctx, ref2, scopes2...) + ctx = WithRepositoryScopes(ctx, ref1, scopes1...) + ctx = WithRepositoryScopes(ctx, ref2, scopes2...) if got := GetScopesPerHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want1) } @@ -159,8 +159,8 @@ func TestWithScopeHints(t *testing.T) { } // append empty scopes - ctx = WithScopeHints(ctx, ref1) - ctx = WithScopeHints(ctx, ref2) + ctx = WithRepositoryScopes(ctx, ref1) + ctx = WithRepositoryScopes(ctx, ref2) if got := GetScopesPerHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want1) } diff --git a/registry/remote/repository.go b/registry/remote/repository.go index 65eea1e1..a6ccb3fe 100644 --- a/registry/remote/repository.go +++ b/registry/remote/repository.go @@ -391,7 +391,7 @@ func (r *Repository) ParseReference(reference string) (registry.Reference, error // - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#content-discovery // - https://docs.docker.com/registry/spec/api/#tags func (r *Repository) Tags(ctx context.Context, last string, fn func(tags []string) error) error { - ctx = auth.WithScopeHints(ctx, r.Reference, auth.ActionPull) + ctx = auth.WithRepositoryScopes(ctx, r.Reference, auth.ActionPull) url := buildRepositoryTagListURL(r.PlainHTTP, r.Reference) var err error for err == nil { @@ -508,7 +508,7 @@ func (r *Repository) Referrers(ctx context.Context, desc ocispec.Descriptor, art func (r *Repository) referrersByAPI(ctx context.Context, desc ocispec.Descriptor, artifactType string, fn func(referrers []ocispec.Descriptor) error) error { ref := r.Reference ref.Reference = desc.Digest.String() - ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull) + ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull) url := buildReferrersURL(r.PlainHTTP, ref, artifactType) var err error @@ -642,7 +642,7 @@ func (r *Repository) pingReferrers(ctx context.Context) (bool, error) { ref := r.Reference ref.Reference = zeroDigest - ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull) + ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull) url := buildReferrersURL(r.PlainHTTP, ref, "") req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) @@ -676,7 +676,7 @@ func (r *Repository) pingReferrers(ctx context.Context) (bool, error) { func (r *Repository) delete(ctx context.Context, target ocispec.Descriptor, isManifest bool) error { ref := r.Reference ref.Reference = target.Digest.String() - ctx = auth.WithScopeHints(ctx, ref, auth.ActionDelete) + ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionDelete) buildURL := buildRepositoryBlobURL if isManifest { buildURL = buildRepositoryManifestURL @@ -712,7 +712,7 @@ type blobStore struct { func (s *blobStore) Fetch(ctx context.Context, target ocispec.Descriptor) (rc io.ReadCloser, err error) { ref := s.repo.Reference ref.Reference = target.Digest.String() - ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull) + ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull) url := buildRepositoryBlobURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -754,12 +754,12 @@ func (s *blobStore) Fetch(ctx context.Context, target ocispec.Descriptor) (rc io func (s *blobStore) Mount(ctx context.Context, desc ocispec.Descriptor, fromRepo string, getContent func() (io.ReadCloser, error)) error { // pushing usually requires both pull and push actions. // Reference: https://github.com/distribution/distribution/blob/v2.7.1/registry/handlers/app.go#L921-L930 - ctx = auth.WithScopeHints(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) + ctx = auth.WithRepositoryScopes(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) // We also need pull access to the source repo. fromRef := s.repo.Reference fromRef.Repository = fromRepo - ctx = auth.WithScopeHints(ctx, fromRef, auth.ActionPull) + ctx = auth.WithRepositoryScopes(ctx, fromRef, auth.ActionPull) url := buildRepositoryBlobMountURL(s.repo.PlainHTTP, s.repo.Reference, desc.Digest, fromRepo) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) @@ -832,7 +832,7 @@ func (s *blobStore) Push(ctx context.Context, expected ocispec.Descriptor, conte // start an upload // pushing usually requires both pull and push actions. // Reference: https://github.com/distribution/distribution/blob/v2.7.1/registry/handlers/app.go#L921-L930 - ctx = auth.WithScopeHints(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) + ctx = auth.WithRepositoryScopes(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) url := buildRepositoryBlobUploadURL(s.repo.PlainHTTP, s.repo.Reference) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) if err != nil { @@ -933,7 +933,7 @@ func (s *blobStore) Resolve(ctx context.Context, reference string) (ocispec.Desc if err != nil { return ocispec.Descriptor{}, err } - ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull) + ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull) url := buildRepositoryBlobURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) if err != nil { @@ -968,7 +968,7 @@ func (s *blobStore) FetchReference(ctx context.Context, reference string) (desc return ocispec.Descriptor{}, nil, err } - ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull) + ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull) url := buildRepositoryBlobURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -1043,7 +1043,7 @@ type manifestStore struct { func (s *manifestStore) Fetch(ctx context.Context, target ocispec.Descriptor) (rc io.ReadCloser, err error) { ref := s.repo.Reference ref.Reference = target.Digest.String() - ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull) + ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -1169,7 +1169,7 @@ func (s *manifestStore) Resolve(ctx context.Context, reference string) (ocispec. if err != nil { return ocispec.Descriptor{}, err } - ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull) + ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) if err != nil { @@ -1201,7 +1201,7 @@ func (s *manifestStore) FetchReference(ctx context.Context, reference string) (d return ocispec.Descriptor{}, nil, err } - ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull) + ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -1244,7 +1244,7 @@ func (s *manifestStore) Tag(ctx context.Context, desc ocispec.Descriptor, refere return err } - ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull, auth.ActionPush) rc, err := s.Fetch(ctx, desc) if err != nil { return err @@ -1269,7 +1269,7 @@ func (s *manifestStore) push(ctx context.Context, expected ocispec.Descriptor, c ref.Reference = reference // pushing usually requires both pull and push actions. // Reference: https://github.com/distribution/distribution/blob/v2.7.1/registry/handlers/app.go#L921-L930 - ctx = auth.WithScopeHints(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull, auth.ActionPush) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) // unwrap the content for optimizations of built-in types. body := ioutil.UnwrapNopCloser(content) From 744e55bb128371b4b2f29141ad7d25a5854cd530 Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Wed, 20 Sep 2023 19:42:35 +0800 Subject: [PATCH 11/29] 2 client tests Signed-off-by: Lixia (Sylvia) Lei --- registry/remote/auth/client_test.go | 492 ++++++++++++++++++++++++++++ 1 file changed, 492 insertions(+) diff --git a/registry/remote/auth/client_test.go b/registry/remote/auth/client_test.go index 9e5ed69d..2d3fd2d0 100644 --- a/registry/remote/auth/client_test.go +++ b/registry/remote/auth/client_test.go @@ -449,6 +449,205 @@ func TestClient_Do_Bearer_AccessToken_Cached(t *testing.T) { } } +func TestClient_Do_Bearer_AccessToken_Cached_PerHost(t *testing.T) { + as := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + })) + defer as.Close() + // set up server 1 + var requestCount1, wantRequestCount1 int64 + var successCount1, wantSuccessCount1 int64 + var service1 string + scope1 := "repository:test:pull" + accessToken1 := "test/access/token/1" + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount1, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken1 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as.URL, service1, scope1) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount1, 1) + })) + defer ts1.Close() + uri1, err := url.Parse(ts1.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service1 = uri1.Host + client1 := &Client{ + Credential: StaticCredential(uri1.Host, Credential{ + AccessToken: accessToken1, + }), + Cache: NewCache(), + } + + // set up server 2 + var requestCount2, wantRequestCount2 int64 + var successCount2, wantSuccessCount2 int64 + var service2 string + scope2 := "repository:test:pull,push" + accessToken2 := "test/access/token/2" + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount2, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken2 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as.URL, service2, scope2) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount2, 1) + })) + defer ts2.Close() + uri2, err := url.Parse(ts2.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service2 = uri2.Host + client2 := &Client{ + Credential: StaticCredential(uri2.Host, Credential{ + AccessToken: accessToken2, + }), + Cache: NewCache(), + } + + ctx := context.Background() + ctx = WithScopesPerHost(ctx, uri1.Host, scope1) + ctx = WithScopesPerHost(ctx, uri2.Host, scope2) + // first request to server 1 + req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err := client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + // first request to server 2 + req2, err := http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err := client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount1 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + + // repeated request to server 1 + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1++; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + // repeated request to server 2 + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2++; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + + // credential change for server 1 + accessToken1 = "test/access/token/1/new" + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client1.Credential = StaticCredential(uri1.Host, Credential{ + AccessToken: accessToken1, + }) + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + // credential change for server 2 + accessToken2 = "test/access/token/2/new" + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client2.Credential = StaticCredential(uri2.Host, Credential{ + AccessToken: accessToken2, + }) + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } +} + func TestClient_Do_Bearer_Auth(t *testing.T) { username := "test_user" password := "test_password" @@ -725,6 +924,299 @@ func TestClient_Do_Bearer_Auth_Cached(t *testing.T) { } } +func TestClient_Do_Bearer_Auth_Cached_PerHost(t *testing.T) { + // set up server 1 + username1 := "test_user1" + password1 := "test_password1" + accessToken1 := "test/access/token/1" + var requestCount1, wantRequestCount1 int64 + var successCount1, wantSuccessCount1 int64 + var authCount1, wantAuthCount1 int64 + var service1 string + scopes1 := []string{ + "repository:dst:pull,push", + "repository:src:pull", + } + as1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + header := "Basic " + base64.StdEncoding.EncodeToString([]byte(username1+":"+password1)) + if auth := r.Header.Get("Authorization"); auth != header { + t.Errorf("unexpected auth: got %s, want %s", auth, header) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.URL.Query().Get("service"); got != service1 { + t.Errorf("unexpected service: got %s, want %s", got, service1) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.URL.Query()["scope"]; !reflect.DeepEqual(got, scopes1) { + t.Errorf("unexpected scope: got %s, want %s", got, scopes1) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount1, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken1); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as1.Close() + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount1, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken1 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as1.URL, service1, strings.Join(scopes1, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount1, 1) + })) + defer ts1.Close() + uri1, err := url.Parse(ts1.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service1 = uri1.Host + client1 := &Client{ + Credential: StaticCredential(uri1.Host, Credential{ + Username: username1, + Password: password1, + }), + Cache: NewCache(), + } + + // set up server 2 + username2 := "test_user2" + password2 := "test_password2" + accessToken2 := "test/access/token/1" + var requestCount2, wantRequestCount2 int64 + var successCount2, wantSuccessCount2 int64 + var authCount2, wantAuthCount2 int64 + var service2 string + scopes2 := []string{ + "repository:dst2:pull,push", + "repository:src2:pull", + } + as2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + header := "Basic " + base64.StdEncoding.EncodeToString([]byte(username2+":"+password2)) + if auth := r.Header.Get("Authorization"); auth != header { + t.Errorf("unexpected auth: got %s, want %s", auth, header) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.URL.Query().Get("service"); got != service2 { + t.Errorf("unexpected service: got %s, want %s", got, service2) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.URL.Query()["scope"]; !reflect.DeepEqual(got, scopes2) { + t.Errorf("unexpected scope: got %s, want %s", got, scopes2) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount2, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken2); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as1.Close() + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount2, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken2 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as2.URL, service2, strings.Join(scopes2, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount2, 1) + })) + defer ts2.Close() + uri2, err := url.Parse(ts2.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service2 = uri2.Host + client2 := &Client{ + Credential: StaticCredential(uri2.Host, Credential{ + Username: username2, + Password: password2, + }), + Cache: NewCache(), + } + + ctx := context.Background() + ctx = WithScopesPerHost(ctx, uri1.Host, scopes1...) + ctx = WithScopesPerHost(ctx, uri2.Host, scopes2...) + // first request to server 1 + req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err := client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + + // first request to server 2 + req2, err := http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err := client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // repeated request to server 1 + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1++; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + + // repeated request to server 2 + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2++; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // credential change for server 1 + username1 = "test_user1_new" + password1 = "test_password1_new" + accessToken1 = "test/access/token/1/new" + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client1.Credential = StaticCredential(uri1.Host, Credential{ + Username: username1, + Password: password1, + }) + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + + // credential change for server 2 + username2 = "test_user2_new" + password2 = "test_password2_new" + accessToken2 = "test/access/token/2/new" + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client2.Credential = StaticCredential(uri2.Host, Credential{ + Username: username2, + Password: password2, + }) + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } +} + func TestClient_Do_Bearer_OAuth2_Password(t *testing.T) { username := "test_user" password := "test_password" From 0d2c1d3ef958537b7fd1ed45b28df5c5a8e8f948 Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Thu, 21 Sep 2023 11:15:54 +0800 Subject: [PATCH 12/29] test oauth password Signed-off-by: Lixia (Sylvia) Lei --- registry/remote/auth/client_test.go | 333 +++++++++++++++++++++++++++- 1 file changed, 330 insertions(+), 3 deletions(-) diff --git a/registry/remote/auth/client_test.go b/registry/remote/auth/client_test.go index 2d3fd2d0..80101f6b 100644 --- a/registry/remote/auth/client_test.go +++ b/registry/remote/auth/client_test.go @@ -934,7 +934,6 @@ func TestClient_Do_Bearer_Auth_Cached_PerHost(t *testing.T) { var authCount1, wantAuthCount1 int64 var service1 string scopes1 := []string{ - "repository:dst:pull,push", "repository:src:pull", } as1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -1005,8 +1004,7 @@ func TestClient_Do_Bearer_Auth_Cached_PerHost(t *testing.T) { var authCount2, wantAuthCount2 int64 var service2 string scopes2 := []string{ - "repository:dst2:pull,push", - "repository:src2:pull", + "repository:dst:pull,push", } as2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet || r.URL.Path != "/" { @@ -1535,6 +1533,335 @@ func TestClient_Do_Bearer_OAuth2_Password_Cached(t *testing.T) { } } +func TestClient_Do_Bearer_OAuth2_Password_Cached_PerHost(t *testing.T) { + // set up server 1 + username1 := "test_user1" + password1 := "test_password1" + accessToken1 := "test/access/token/1" + var requestCount1, wantRequestCount1 int64 + var successCount1, wantSuccessCount1 int64 + var authCount1, wantAuthCount1 int64 + var service1 string + scopes1 := []string{ + "repository:src:pull", + } + as1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "password" { + t.Errorf("unexpected grant type: %v, want %v", got, "password") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service1 { + t.Errorf("unexpected service: %v, want %v", got, service1) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scope := strings.Join(scopes1, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("username"); got != username1 { + t.Errorf("unexpected username: %v, want %v", got, username1) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("password"); got != password1 { + t.Errorf("unexpected password: %v, want %v", got, password1) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount1, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken1); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as1.Close() + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount1, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken1 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as1.URL, service1, strings.Join(scopes1, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount1, 1) + })) + defer ts1.Close() + uri1, err := url.Parse(ts1.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service1 = uri1.Host + client1 := &Client{ + Credential: StaticCredential(uri1.Host, Credential{ + Username: username1, + Password: password1, + }), + ForceAttemptOAuth2: true, + Cache: NewCache(), + } + // set up server 2 + username2 := "test_user2" + password2 := "test_password2" + accessToken2 := "test/access/token/2" + var requestCount2, wantRequestCount2 int64 + var successCount2, wantSuccessCount2 int64 + var authCount2, wantAuthCount2 int64 + var service2 string + scopes2 := []string{ + "repository:dst:pull,push", + } + as2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "password" { + t.Errorf("unexpected grant type: %v, want %v", got, "password") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service2 { + t.Errorf("unexpected service: %v, want %v", got, service2) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scope := strings.Join(scopes2, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("username"); got != username2 { + t.Errorf("unexpected username: %v, want %v", got, username2) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("password"); got != password2 { + t.Errorf("unexpected password: %v, want %v", got, password2) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount2, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken2); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as2.Close() + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount2, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken2 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as2.URL, service2, strings.Join(scopes2, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount2, 1) + })) + defer ts2.Close() + uri2, err := url.Parse(ts2.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service2 = uri2.Host + client2 := &Client{ + Credential: StaticCredential(uri2.Host, Credential{ + Username: username2, + Password: password2, + }), + ForceAttemptOAuth2: true, + Cache: NewCache(), + } + + ctx := context.Background() + ctx = WithScopesPerHost(ctx, uri1.Host, scopes1...) + ctx = WithScopesPerHost(ctx, uri2.Host, scopes2...) + // first request to server 1 + req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err := client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + // first request to server 2 + req2, err := http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err := client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // repeated request to server 1 + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1++; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + // repeated request to server 2 + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2++; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // credential change for server 1 + username1 = "test_user1_new" + password1 = "test_password1_new" + accessToken1 = "test/access/token/1/new" + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client1.Credential = StaticCredential(uri1.Host, Credential{ + Username: username1, + Password: password1, + }) + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + // credential change for server 2 + username2 = "test_user2_new" + password2 = "test_password2_new" + accessToken2 = "test/access/token/2/new" + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client2.Credential = StaticCredential(uri2.Host, Credential{ + Username: username2, + Password: password2, + }) + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } +} + func TestClient_Do_Bearer_OAuth2_RefreshToken(t *testing.T) { refreshToken := "test/refresh/token" accessToken := "test/access/token" From 3325e2690f577ec7c10143c09ca743ebbfad6657 Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Thu, 21 Sep 2023 12:42:57 +0800 Subject: [PATCH 13/29] test refresh token Signed-off-by: Lixia (Sylvia) Lei --- registry/remote/auth/client_test.go | 311 ++++++++++++++++++++++++++++ 1 file changed, 311 insertions(+) diff --git a/registry/remote/auth/client_test.go b/registry/remote/auth/client_test.go index 80101f6b..9e4ffade 100644 --- a/registry/remote/auth/client_test.go +++ b/registry/remote/auth/client_test.go @@ -2162,6 +2162,317 @@ func TestClient_Do_Bearer_OAuth2_RefreshToken_Cached(t *testing.T) { } } +func TestClient_Do_Bearer_OAuth2_RefreshToken_Cached_PerHost(t *testing.T) { + // set up server 1 + refreshToken1 := "test/refresh/token/1" + accessToken1 := "test/access/token/1" + var requestCount1, wantRequestCount1 int64 + var successCount1, wantSuccessCount1 int64 + var authCount1, wantAuthCount1 int64 + var service1 string + scopes1 := []string{ + "repository:src:pull", + } + as1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "refresh_token" { + t.Errorf("unexpected grant type: %v, want %v", got, "refresh_token") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service1 { + t.Errorf("unexpected service: %v, want %v", got, service1) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scope := strings.Join(scopes1, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("refresh_token"); got != refreshToken1 { + t.Errorf("unexpected refresh token: %v, want %v", got, refreshToken1) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount1, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken1); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as1.Close() + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount1, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken1 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as1.URL, service1, strings.Join(scopes1, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount1, 1) + })) + defer ts1.Close() + uri1, err := url.Parse(ts1.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service1 = uri1.Host + client1 := &Client{ + Credential: StaticCredential(uri1.Host, Credential{ + RefreshToken: refreshToken1, + }), + Cache: NewCache(), + } + + // set up server 2 + refreshToken2 := "test/refresh/token/1" + accessToken2 := "test/access/token/1" + var requestCount2, wantRequestCount2 int64 + var successCount2, wantSuccessCount2 int64 + var authCount2, wantAuthCount2 int64 + var service2 string + scopes2 := []string{ + "repository:dst:pull,push", + } + as2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "refresh_token" { + t.Errorf("unexpected grant type: %v, want %v", got, "refresh_token") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service2 { + t.Errorf("unexpected service: %v, want %v", got, service2) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scope := strings.Join(scopes2, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("refresh_token"); got != refreshToken2 { + t.Errorf("unexpected refresh token: %v, want %v", got, refreshToken2) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount2, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken2); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as2.Close() + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount2, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken2 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as2.URL, service2, strings.Join(scopes2, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount2, 1) + })) + defer ts2.Close() + uri2, err := url.Parse(ts2.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service2 = uri2.Host + client2 := &Client{ + Credential: StaticCredential(uri2.Host, Credential{ + RefreshToken: refreshToken2, + }), + Cache: NewCache(), + } + + ctx := context.Background() + ctx = WithScopesPerHost(ctx, uri1.Host, scopes1...) + ctx = WithScopesPerHost(ctx, uri2.Host, scopes2...) + // first request to server 1 + req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err := client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + + // first request to server 2 + req2, err := http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err := client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // repeated request to server 1 + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1++; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + // repeated request to server 2 + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2++; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // credential change to server 1 + refreshToken1 = "test/refresh/token/1/new" + accessToken1 = "test/access/token/1/new" + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client1.Credential = StaticCredential(uri1.Host, Credential{ + RefreshToken: refreshToken1, + }) + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + // credential change to server 2 + refreshToken2 = "test/refresh/token/2/new" + accessToken2 = "test/access/token/2/new" + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client2.Credential = StaticCredential(uri2.Host, Credential{ + RefreshToken: refreshToken2, + }) + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } +} + func TestClient_Do_Token_Expire(t *testing.T) { refreshToken := "test/refresh/token" accessToken := "test/access/token" From 6a98c3cf2e858b8e9627b4582a8aaf252df50aa0 Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Thu, 21 Sep 2023 12:58:45 +0800 Subject: [PATCH 14/29] test expire Signed-off-by: Lixia (Sylvia) Lei --- registry/remote/auth/client_test.go | 259 ++++++++++++++++++++++++++++ 1 file changed, 259 insertions(+) diff --git a/registry/remote/auth/client_test.go b/registry/remote/auth/client_test.go index 9e4ffade..d42a75a5 100644 --- a/registry/remote/auth/client_test.go +++ b/registry/remote/auth/client_test.go @@ -2612,6 +2612,265 @@ func TestClient_Do_Token_Expire(t *testing.T) { } } +func TestClient_Do_Token_Expire_PerHost(t *testing.T) { + // set up server 1 + refreshToken1 := "test/refresh/token/1" + accessToken1 := "test/access/token/1" + var requestCount1, wantRequestCount1 int64 + var successCount1, wantSuccessCount1 int64 + var authCount1, wantAuthCount1 int64 + var service1 string + scopes1 := []string{ + "repository:src:pull", + } + as1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "refresh_token" { + t.Errorf("unexpected grant type: %v, want %v", got, "refresh_token") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service1 { + t.Errorf("unexpected service: %v, want %v", got, service1) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scope := strings.Join(scopes1, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("refresh_token"); got != refreshToken1 { + t.Errorf("unexpected refresh token: %v, want %v", got, refreshToken1) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount1, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken1); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as1.Close() + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount1, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken1 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as1.URL, service1, strings.Join(scopes1, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount1, 1) + })) + defer ts1.Close() + uri1, err := url.Parse(ts1.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service1 = uri1.Host + client1 := &Client{ + Credential: StaticCredential(uri1.Host, Credential{ + RefreshToken: refreshToken1, + }), + Cache: NewCache(), + } + // set up server 2 + refreshToken2 := "test/refresh/token/2" + accessToken2 := "test/access/token/2" + var requestCount2, wantRequestCount2 int64 + var successCount2, wantSuccessCount2 int64 + var authCount2, wantAuthCount2 int64 + var service2 string + scopes2 := []string{ + "repository:dst:pull,push", + } + as2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "refresh_token" { + t.Errorf("unexpected grant type: %v, want %v", got, "refresh_token") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service2 { + t.Errorf("unexpected service: %v, want %v", got, service2) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scope := strings.Join(scopes2, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("refresh_token"); got != refreshToken2 { + t.Errorf("unexpected refresh token: %v, want %v", got, refreshToken2) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount2, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken2); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as2.Close() + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount2, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken2 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as2.URL, service2, strings.Join(scopes2, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount2, 1) + })) + defer ts2.Close() + uri2, err := url.Parse(ts2.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service2 = uri2.Host + client2 := &Client{ + Credential: StaticCredential(uri2.Host, Credential{ + RefreshToken: refreshToken2, + }), + Cache: NewCache(), + } + + ctx := context.Background() + ctx = WithScopesPerHost(ctx, uri1.Host, scopes1...) + ctx = WithScopesPerHost(ctx, uri2.Host, scopes2...) + // first request to server 1 + req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err := client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + + // first request to server 2 + req2, err := http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err := client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // invalidate the access token and request again to server 1 + accessToken1 = "test/access/token/1/new" + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + // invalidate the access token and request again to server 2 + accessToken2 = "test/access/token/2/new" + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } +} + func TestClient_Do_Scope_Hint_Mismatch(t *testing.T) { username := "test_user" password := "test_password" From f7180dc490475a64d5b823d935fa0ea6f204f16d Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Thu, 21 Sep 2023 14:16:32 +0800 Subject: [PATCH 15/29] client tests Signed-off-by: Lixia (Sylvia) Lei --- registry/remote/auth/client_test.go | 287 ++++++++++++++++++++++++++++ 1 file changed, 287 insertions(+) diff --git a/registry/remote/auth/client_test.go b/registry/remote/auth/client_test.go index d42a75a5..74db0dba 100644 --- a/registry/remote/auth/client_test.go +++ b/registry/remote/auth/client_test.go @@ -3022,6 +3022,293 @@ func TestClient_Do_Scope_Hint_Mismatch(t *testing.T) { } } +func TestClient_Do_Scope_Hint_Mismatch_PerHost(t *testing.T) { + // set up server 1 + username1 := "test_user1" + password1 := "test_password1" + accessToken1 := "test/access/token/1" + var requestCount1, wantRequestCount1 int64 + var successCount1, wantSuccessCount1 int64 + var authCount1, wantAuthCount1 int64 + var service1 string + scopes1 := []string{ + "repository:dst:pull,push", + "repository:src:pull", + } + scope1 := "repository:test1:delete" + as1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "password" { + t.Errorf("unexpected grant type: %v, want %v", got, "password") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service1 { + t.Errorf("unexpected service: %v, want %v", got, service1) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scopes := CleanScopes(append([]string{scope1}, scopes1...)) + scope := strings.Join(scopes, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("username"); got != username1 { + t.Errorf("unexpected username: %v, want %v", got, username1) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("password"); got != password1 { + t.Errorf("unexpected password: %v, want %v", got, password1) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount1, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken1); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as1.Close() + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount1, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken1 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as1.URL, service1, scope1) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount1, 1) + })) + defer ts1.Close() + uri1, err := url.Parse(ts1.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service1 = uri1.Host + client1 := &Client{ + Credential: StaticCredential(uri1.Host, Credential{ + Username: username1, + Password: password1, + }), + ForceAttemptOAuth2: true, + Cache: NewCache(), + } + + // set up server 1 + username2 := "test_user2" + password2 := "test_password2" + accessToken2 := "test/access/token/2" + var requestCount2, wantRequestCount2 int64 + var successCount2, wantSuccessCount2 int64 + var authCount2, wantAuthCount2 int64 + var service2 string + scopes2 := []string{ + "repository:dst:pull,push", + "repository:src:pull", + } + scope2 := "repository:test2:delete" + as2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "password" { + t.Errorf("unexpected grant type: %v, want %v", got, "password") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service2 { + t.Errorf("unexpected service: %v, want %v", got, service2) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scopes := CleanScopes(append([]string{scope2}, scopes2...)) + scope := strings.Join(scopes, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("username"); got != username2 { + t.Errorf("unexpected username: %v, want %v", got, username2) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("password"); got != password2 { + t.Errorf("unexpected password: %v, want %v", got, password2) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount2, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken2); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as2.Close() + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount2, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken2 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as2.URL, service2, scope2) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount2, 1) + })) + defer ts1.Close() + uri2, err := url.Parse(ts2.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service2 = uri2.Host + client2 := &Client{ + Credential: StaticCredential(uri2.Host, Credential{ + Username: username2, + Password: password2, + }), + ForceAttemptOAuth2: true, + Cache: NewCache(), + } + + ctx := context.Background() + ctx = WithScopesPerHost(ctx, uri1.Host, scopes1...) + ctx = WithScopesPerHost(ctx, uri2.Host, scopes2...) + // first request to server 1 + req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err := client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + + // first request to server 1 + req2, err := http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err := client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // repeated request to server 1 + // although the actual scope does not match the hinted scopes, the client + // with cache cannot avoid a request to obtain a challenge but can prevent + // a repeated call to the authorization server. + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + + // repeated request to server 2 + // although the actual scope does not match the hinted scopes, the client + // with cache cannot avoid a request to obtain a challenge but can prevent + // a repeated call to the authorization server. + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } +} + func TestClient_Do_Invalid_Credential_Basic(t *testing.T) { username := "test_user" password := "test_password" From 448ae21e1ace5c825ccd89a451bf4e7677ea67ae Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Thu, 21 Sep 2023 14:42:16 +0800 Subject: [PATCH 16/29] add docs Signed-off-by: Lixia (Sylvia) Lei --- content.go | 4 ++-- registry/remote/auth/scope.go | 31 ++++++++++++++++++++++++++++-- registry/remote/auth/scope_test.go | 12 ++++++------ registry/remote/repository.go | 30 ++++++++++++++--------------- 4 files changed, 52 insertions(+), 25 deletions(-) diff --git a/content.go b/content.go index 40647b8e..92794969 100644 --- a/content.go +++ b/content.go @@ -90,7 +90,7 @@ func TagN(ctx context.Context, target Target, srcReference string, dstReferences if err != nil { return ocispec.Descriptor{}, err } - ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull, auth.ActionPush) } desc, contentBytes, err := FetchBytes(ctx, target, srcReference, FetchBytesOptions{ @@ -148,7 +148,7 @@ func Tag(ctx context.Context, target Target, src, dst string) (ocispec.Descripto if err != nil { return ocispec.Descriptor{}, err } - ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull, auth.ActionPush) } desc, rc, err := refFetcher.FetchReference(ctx, src) if err != nil { diff --git a/registry/remote/auth/scope.go b/registry/remote/auth/scope.go index 859e22b6..cc060266 100644 --- a/registry/remote/auth/scope.go +++ b/registry/remote/auth/scope.go @@ -57,8 +57,10 @@ func ScopeRepository(repository string, actions ...string) string { }, ":") } -// WithRepositoryScopes adds a hinted scope to the context. -func WithRepositoryScopes(ctx context.Context, ref registry.Reference, actions ...string) context.Context { +// AppendScopeHints appends a repository scope with the given actions +// to the existing scopes in the context for the given registry and returns +// a new context. +func AppendScopeHints(ctx context.Context, ref registry.Reference, actions ...string) context.Context { if len(actions) == 0 { return ctx } @@ -110,8 +112,28 @@ func GetScopes(ctx context.Context) []string { return nil } +// scopesPerHostContextKey is the context key for per-host scopes. type scopesPerHostContextKey struct{} +// WithScopesPerHost returns a context with per-host scopes added. +// Scopes are de-duplicated. +// Scopes are used as hints for the auth client to fetch bearer tokens with +// larger scopes. +// +// For example, uploading blob to the repository "hello-world" does HEAD request +// first then POST and PUT. The HEAD request will return a challenge for scope +// `repository:hello-world:pull`, and the auth client will fetch a token for +// that challenge. Later, the POST request will return a challenge for scope +// `repository:hello-world:push`, and the auth client will fetch a token for +// that challenge again. By invoking `WithScopes()` with the scope +// `repository:hello-world:pull,push`, the auth client with cache is hinted to +// fetch a token via a single token fetch request for all the HEAD, POST, PUT +// requests. +// +// Passing an empty list of scopes will virtually remove the scope hints in the +// context for the given host. +// +// Reference: https://docs.docker.com/registry/spec/auth/scope/ func WithScopesPerHost(ctx context.Context, host string, scopes ...string) context.Context { var regMap map[string][]string var ok bool @@ -124,6 +146,10 @@ func WithScopesPerHost(ctx context.Context, host string, scopes ...string) conte return context.WithValue(ctx, scopesPerHostContextKey{}, regMap) } +// AppendScopesPerHost appends additional scopes to the existing scopes +// in the context for the given host and returns a new context. +// The resulted scopes are de-duplicated. +// The append operation does modify the existing scope in the context passed in. func AppendScopesPerHost(ctx context.Context, host string, scopes ...string) context.Context { if len(scopes) == 0 { return ctx @@ -132,6 +158,7 @@ func AppendScopesPerHost(ctx context.Context, host string, scopes ...string) con return WithScopesPerHost(ctx, host, append(oldScopes, scopes...)...) } +// GetScopesPerHost returns the scopes in the context for the given host. func GetScopesPerHost(ctx context.Context, host string) []string { if regMap, ok := ctx.Value(scopesPerHostContextKey{}).(map[string][]string); ok { return slices.Clone(regMap[host]) diff --git a/registry/remote/auth/scope_test.go b/registry/remote/auth/scope_test.go index 75106bc6..44c07f79 100644 --- a/registry/remote/auth/scope_test.go +++ b/registry/remote/auth/scope_test.go @@ -123,8 +123,8 @@ func TestWithScopeHints(t *testing.T) { want2 := []string{ "repository:foo:push", } - ctx = WithRepositoryScopes(ctx, ref1, ActionPull) - ctx = WithRepositoryScopes(ctx, ref2, ActionPush) + ctx = AppendScopeHints(ctx, ref1, ActionPull) + ctx = AppendScopeHints(ctx, ref2, ActionPush) if got := GetScopesPerHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want1) } @@ -149,8 +149,8 @@ func TestWithScopeHints(t *testing.T) { want2 = []string{ "repository:foo:delete,push", } - ctx = WithRepositoryScopes(ctx, ref1, scopes1...) - ctx = WithRepositoryScopes(ctx, ref2, scopes2...) + ctx = AppendScopeHints(ctx, ref1, scopes1...) + ctx = AppendScopeHints(ctx, ref2, scopes2...) if got := GetScopesPerHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want1) } @@ -159,8 +159,8 @@ func TestWithScopeHints(t *testing.T) { } // append empty scopes - ctx = WithRepositoryScopes(ctx, ref1) - ctx = WithRepositoryScopes(ctx, ref2) + ctx = AppendScopeHints(ctx, ref1) + ctx = AppendScopeHints(ctx, ref2) if got := GetScopesPerHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want1) } diff --git a/registry/remote/repository.go b/registry/remote/repository.go index a6ccb3fe..0e157068 100644 --- a/registry/remote/repository.go +++ b/registry/remote/repository.go @@ -391,7 +391,7 @@ func (r *Repository) ParseReference(reference string) (registry.Reference, error // - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#content-discovery // - https://docs.docker.com/registry/spec/api/#tags func (r *Repository) Tags(ctx context.Context, last string, fn func(tags []string) error) error { - ctx = auth.WithRepositoryScopes(ctx, r.Reference, auth.ActionPull) + ctx = auth.AppendScopeHints(ctx, r.Reference, auth.ActionPull) url := buildRepositoryTagListURL(r.PlainHTTP, r.Reference) var err error for err == nil { @@ -508,7 +508,7 @@ func (r *Repository) Referrers(ctx context.Context, desc ocispec.Descriptor, art func (r *Repository) referrersByAPI(ctx context.Context, desc ocispec.Descriptor, artifactType string, fn func(referrers []ocispec.Descriptor) error) error { ref := r.Reference ref.Reference = desc.Digest.String() - ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull) + ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull) url := buildReferrersURL(r.PlainHTTP, ref, artifactType) var err error @@ -642,7 +642,7 @@ func (r *Repository) pingReferrers(ctx context.Context) (bool, error) { ref := r.Reference ref.Reference = zeroDigest - ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull) + ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull) url := buildReferrersURL(r.PlainHTTP, ref, "") req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) @@ -676,7 +676,7 @@ func (r *Repository) pingReferrers(ctx context.Context) (bool, error) { func (r *Repository) delete(ctx context.Context, target ocispec.Descriptor, isManifest bool) error { ref := r.Reference ref.Reference = target.Digest.String() - ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionDelete) + ctx = auth.AppendScopeHints(ctx, ref, auth.ActionDelete) buildURL := buildRepositoryBlobURL if isManifest { buildURL = buildRepositoryManifestURL @@ -712,7 +712,7 @@ type blobStore struct { func (s *blobStore) Fetch(ctx context.Context, target ocispec.Descriptor) (rc io.ReadCloser, err error) { ref := s.repo.Reference ref.Reference = target.Digest.String() - ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull) + ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull) url := buildRepositoryBlobURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -754,12 +754,12 @@ func (s *blobStore) Fetch(ctx context.Context, target ocispec.Descriptor) (rc io func (s *blobStore) Mount(ctx context.Context, desc ocispec.Descriptor, fromRepo string, getContent func() (io.ReadCloser, error)) error { // pushing usually requires both pull and push actions. // Reference: https://github.com/distribution/distribution/blob/v2.7.1/registry/handlers/app.go#L921-L930 - ctx = auth.WithRepositoryScopes(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendScopeHints(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) // We also need pull access to the source repo. fromRef := s.repo.Reference fromRef.Repository = fromRepo - ctx = auth.WithRepositoryScopes(ctx, fromRef, auth.ActionPull) + ctx = auth.AppendScopeHints(ctx, fromRef, auth.ActionPull) url := buildRepositoryBlobMountURL(s.repo.PlainHTTP, s.repo.Reference, desc.Digest, fromRepo) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) @@ -832,7 +832,7 @@ func (s *blobStore) Push(ctx context.Context, expected ocispec.Descriptor, conte // start an upload // pushing usually requires both pull and push actions. // Reference: https://github.com/distribution/distribution/blob/v2.7.1/registry/handlers/app.go#L921-L930 - ctx = auth.WithRepositoryScopes(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendScopeHints(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) url := buildRepositoryBlobUploadURL(s.repo.PlainHTTP, s.repo.Reference) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) if err != nil { @@ -933,7 +933,7 @@ func (s *blobStore) Resolve(ctx context.Context, reference string) (ocispec.Desc if err != nil { return ocispec.Descriptor{}, err } - ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull) + ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull) url := buildRepositoryBlobURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) if err != nil { @@ -968,7 +968,7 @@ func (s *blobStore) FetchReference(ctx context.Context, reference string) (desc return ocispec.Descriptor{}, nil, err } - ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull) + ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull) url := buildRepositoryBlobURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -1043,7 +1043,7 @@ type manifestStore struct { func (s *manifestStore) Fetch(ctx context.Context, target ocispec.Descriptor) (rc io.ReadCloser, err error) { ref := s.repo.Reference ref.Reference = target.Digest.String() - ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull) + ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -1169,7 +1169,7 @@ func (s *manifestStore) Resolve(ctx context.Context, reference string) (ocispec. if err != nil { return ocispec.Descriptor{}, err } - ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull) + ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) if err != nil { @@ -1201,7 +1201,7 @@ func (s *manifestStore) FetchReference(ctx context.Context, reference string) (d return ocispec.Descriptor{}, nil, err } - ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull) + ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -1244,7 +1244,7 @@ func (s *manifestStore) Tag(ctx context.Context, desc ocispec.Descriptor, refere return err } - ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull, auth.ActionPush) rc, err := s.Fetch(ctx, desc) if err != nil { return err @@ -1269,7 +1269,7 @@ func (s *manifestStore) push(ctx context.Context, expected ocispec.Descriptor, c ref.Reference = reference // pushing usually requires both pull and push actions. // Reference: https://github.com/distribution/distribution/blob/v2.7.1/registry/handlers/app.go#L921-L930 - ctx = auth.WithRepositoryScopes(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull, auth.ActionPush) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) // unwrap the content for optimizations of built-in types. body := ioutil.UnwrapNopCloser(content) From 66f3841b5c66ff71df15e6f13065d0b76baa9a05 Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Thu, 21 Sep 2023 15:08:19 +0800 Subject: [PATCH 17/29] update doc Signed-off-by: Lixia (Sylvia) Lei --- registry/remote/auth/scope.go | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/registry/remote/auth/scope.go b/registry/remote/auth/scope.go index cc060266..8ae74d03 100644 --- a/registry/remote/auth/scope.go +++ b/registry/remote/auth/scope.go @@ -57,9 +57,18 @@ func ScopeRepository(repository string, actions ...string) string { }, ":") } -// AppendScopeHints appends a repository scope with the given actions -// to the existing scopes in the context for the given registry and returns -// a new context. +// AppendScopeHints appends repository scope hints with the given actions +// for the auth client to fetch bearer tokens with larger scopes. +// +// For example, uploading blob to the repository "hello-world" does HEAD request +// first then POST and PUT. The HEAD request will return a challenge for scope +// `repository:hello-world:pull`, and the auth client will fetch a token for +// that challenge. Later, the POST request will return a challenge for scope +// `repository:hello-world:push`, and the auth client will fetch a token for +// that challenge again. By invoking `AppendScopeHints()` with the actions +// [ActionPull] and [ActionPush] for the repository `hello-world`, +// the auth client with cache is hinted to fetch a token via a single token +// fetch request for all the HEAD, POST, PUT requests. func AppendScopeHints(ctx context.Context, ref registry.Reference, actions ...string) context.Context { if len(actions) == 0 { return ctx From c5179541e695a49cdc004ce55c34d234e730b6ae Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Thu, 21 Sep 2023 15:10:41 +0800 Subject: [PATCH 18/29] minor reword Signed-off-by: Lixia (Sylvia) Lei --- registry/remote/auth/scope.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/registry/remote/auth/scope.go b/registry/remote/auth/scope.go index 8ae74d03..61e13b28 100644 --- a/registry/remote/auth/scope.go +++ b/registry/remote/auth/scope.go @@ -65,7 +65,7 @@ func ScopeRepository(repository string, actions ...string) string { // `repository:hello-world:pull`, and the auth client will fetch a token for // that challenge. Later, the POST request will return a challenge for scope // `repository:hello-world:push`, and the auth client will fetch a token for -// that challenge again. By invoking `AppendScopeHints()` with the actions +// that challenge again. By invoking AppendScopeHints with the actions // [ActionPull] and [ActionPush] for the repository `hello-world`, // the auth client with cache is hinted to fetch a token via a single token // fetch request for all the HEAD, POST, PUT requests. @@ -89,7 +89,7 @@ type scopesContextKey struct{} // `repository:hello-world:pull`, and the auth client will fetch a token for // that challenge. Later, the POST request will return a challenge for scope // `repository:hello-world:push`, and the auth client will fetch a token for -// that challenge again. By invoking `WithScopes()` with the scope +// that challenge again. By invoking WithScopes with the scope // `repository:hello-world:pull,push`, the auth client with cache is hinted to // fetch a token via a single token fetch request for all the HEAD, POST, PUT // requests. @@ -134,7 +134,7 @@ type scopesPerHostContextKey struct{} // `repository:hello-world:pull`, and the auth client will fetch a token for // that challenge. Later, the POST request will return a challenge for scope // `repository:hello-world:push`, and the auth client will fetch a token for -// that challenge again. By invoking `WithScopes()` with the scope +// that challenge again. By invoking WithScopesPerHost with the scope // `repository:hello-world:pull,push`, the auth client with cache is hinted to // fetch a token via a single token fetch request for all the HEAD, POST, PUT // requests. From 799ba84b29e324e374098fa3fe2a948476d847b8 Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Thu, 21 Sep 2023 15:20:21 +0800 Subject: [PATCH 19/29] rename Signed-off-by: Lixia (Sylvia) Lei --- registry/remote/auth/scope.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/registry/remote/auth/scope.go b/registry/remote/auth/scope.go index 61e13b28..99b62d8f 100644 --- a/registry/remote/auth/scope.go +++ b/registry/remote/auth/scope.go @@ -57,8 +57,10 @@ func ScopeRepository(repository string, actions ...string) string { }, ":") } -// AppendScopeHints appends repository scope hints with the given actions -// for the auth client to fetch bearer tokens with larger scopes. +// AppendScopeHints returns a new context containing scope hints for the auth +// client to fetch bearer tokens with the given actions on the repository. +// If called multiple times, the new scopes will be appended to the existing +// scopes. The resulted scopes are de-duplicated. // // For example, uploading blob to the repository "hello-world" does HEAD request // first then POST and PUT. The HEAD request will return a challenge for scope From 5a766a88885d37f20de3822efbac701727e0478e Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Thu, 21 Sep 2023 15:42:07 +0800 Subject: [PATCH 20/29] fix Signed-off-by: Lixia (Sylvia) Lei --- registry/remote/auth/client.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/registry/remote/auth/client.go b/registry/remote/auth/client.go index 22e7b8aa..59df4137 100644 --- a/registry/remote/auth/client.go +++ b/registry/remote/auth/client.go @@ -189,6 +189,7 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { case SchemeBearer: // merge per-host scopes with generic scopes scopes := append(GetScopesPerHost(ctx, host), GetScopes(ctx)...) + scopes = CleanScopes(scopes) attemptedKey = strings.Join(scopes, " ") token, err := cache.GetToken(ctx, host, SchemeBearer, attemptedKey) if err == nil { @@ -226,11 +227,11 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { // merge per-host scopes with generic scopes scopes := append(GetScopesPerHost(ctx, host), GetScopes(ctx)...) - // merge hinted scopes with challenged scopes if paramScope := params["scope"]; paramScope != "" { + // merge hinted scopes with challenged scopes scopes = append(scopes, strings.Split(paramScope, " ")...) - scopes = CleanScopes(scopes) } + scopes = CleanScopes(scopes) key := strings.Join(scopes, " ") // attempt the cache again if there is a scope change From 29f3497e399084f1865aeb27c52f76d3010b7ada Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Thu, 21 Sep 2023 16:00:25 +0800 Subject: [PATCH 21/29] optimize Signed-off-by: Lixia (Sylvia) Lei --- registry/remote/auth/client.go | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/registry/remote/auth/client.go b/registry/remote/auth/client.go index 59df4137..aea376ed 100644 --- a/registry/remote/auth/client.go +++ b/registry/remote/auth/client.go @@ -188,8 +188,11 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { } case SchemeBearer: // merge per-host scopes with generic scopes - scopes := append(GetScopesPerHost(ctx, host), GetScopes(ctx)...) - scopes = CleanScopes(scopes) + scopes := GetScopesPerHost(ctx, host) + if moreScopes := GetScopes(ctx); len(moreScopes) > 0 { + scopes = append(scopes, moreScopes...) + scopes = CleanScopes(scopes) + } attemptedKey = strings.Join(scopes, " ") token, err := cache.GetToken(ctx, host, SchemeBearer, attemptedKey) if err == nil { @@ -225,13 +228,20 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { case SchemeBearer: resp.Body.Close() - // merge per-host scopes with generic scopes - scopes := append(GetScopesPerHost(ctx, host), GetScopes(ctx)...) + scopes := GetScopesPerHost(ctx, host) + cleanScopeLen := len(scopes) + if moreScopes := GetScopes(ctx); len(moreScopes) > 0 { + // merge per-host scopes with generic scopes + scopes = append(scopes, moreScopes...) + } if paramScope := params["scope"]; paramScope != "" { // merge hinted scopes with challenged scopes scopes = append(scopes, strings.Split(paramScope, " ")...) } - scopes = CleanScopes(scopes) + if len(scopes) > cleanScopeLen { + // re-clean the scopes + scopes = CleanScopes(scopes) + } key := strings.Join(scopes, " ") // attempt the cache again if there is a scope change From 78c68b8d90e25f19f816ee49a3f70d82ca03f6dc Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Thu, 21 Sep 2023 16:59:33 +0800 Subject: [PATCH 22/29] fix race? Signed-off-by: Lixia (Sylvia) Lei --- internal/maps/maps.go | 28 ++++++++++++++++++++++++++++ registry/remote/auth/scope.go | 22 ++++++++++++---------- 2 files changed, 40 insertions(+), 10 deletions(-) create mode 100644 internal/maps/maps.go diff --git a/internal/maps/maps.go b/internal/maps/maps.go new file mode 100644 index 00000000..8db27fcb --- /dev/null +++ b/internal/maps/maps.go @@ -0,0 +1,28 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package maps + +// Copy copies all key/value pairs in src adding them to dst. +// When a key in src is already present in dst, +// the value in dst will be overwritten by the value associated +// with the key in src. +// +// Reference: https://pkg.go.dev/maps@go1.21.1#Copy +func Copy[M1 ~map[K]V, M2 ~map[K]V, K comparable, V any](dst M1, src M2) { + for k, v := range src { + dst[k] = v + } +} diff --git a/registry/remote/auth/scope.go b/registry/remote/auth/scope.go index 99b62d8f..7d8bb2c9 100644 --- a/registry/remote/auth/scope.go +++ b/registry/remote/auth/scope.go @@ -20,6 +20,7 @@ import ( "sort" "strings" + "oras.land/oras-go/v2/internal/maps" "oras.land/oras-go/v2/internal/slices" "oras.land/oras-go/v2/registry" ) @@ -146,15 +147,16 @@ type scopesPerHostContextKey struct{} // // Reference: https://docs.docker.com/registry/spec/auth/scope/ func WithScopesPerHost(ctx context.Context, host string, scopes ...string) context.Context { - var regMap map[string][]string - var ok bool - regMap, ok = ctx.Value(scopesPerHostContextKey{}).(map[string][]string) - if !ok { - regMap = make(map[string][]string) + var scopesByHost map[string][]string + if old, ok := ctx.Value(scopesPerHostContextKey{}).(map[string][]string); ok { + scopesByHost = make(map[string][]string, len(old)) + maps.Copy(scopesByHost, old) + } else { + scopesByHost = make(map[string][]string, 1) } - scopes = CleanScopes(scopes) - regMap[host] = scopes - return context.WithValue(ctx, scopesPerHostContextKey{}, regMap) + + scopesByHost[host] = CleanScopes(scopes) + return context.WithValue(ctx, scopesPerHostContextKey{}, scopesByHost) } // AppendScopesPerHost appends additional scopes to the existing scopes @@ -171,8 +173,8 @@ func AppendScopesPerHost(ctx context.Context, host string, scopes ...string) con // GetScopesPerHost returns the scopes in the context for the given host. func GetScopesPerHost(ctx context.Context, host string) []string { - if regMap, ok := ctx.Value(scopesPerHostContextKey{}).(map[string][]string); ok { - return slices.Clone(regMap[host]) + if scopesByHost, ok := ctx.Value(scopesPerHostContextKey{}).(map[string][]string); ok { + return slices.Clone(scopesByHost[host]) } return nil } From 60eba30b8735368bd924b018bfe03af143c47ade Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Mon, 25 Sep 2023 15:45:26 +0800 Subject: [PATCH 23/29] refactor Signed-off-by: Lixia (Sylvia) Lei --- internal/maps/maps.go | 28 ---------------------------- registry/remote/auth/scope.go | 19 +++++-------------- 2 files changed, 5 insertions(+), 42 deletions(-) delete mode 100644 internal/maps/maps.go diff --git a/internal/maps/maps.go b/internal/maps/maps.go deleted file mode 100644 index 8db27fcb..00000000 --- a/internal/maps/maps.go +++ /dev/null @@ -1,28 +0,0 @@ -/* -Copyright The ORAS Authors. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package maps - -// Copy copies all key/value pairs in src adding them to dst. -// When a key in src is already present in dst, -// the value in dst will be overwritten by the value associated -// with the key in src. -// -// Reference: https://pkg.go.dev/maps@go1.21.1#Copy -func Copy[M1 ~map[K]V, M2 ~map[K]V, K comparable, V any](dst M1, src M2) { - for k, v := range src { - dst[k] = v - } -} diff --git a/registry/remote/auth/scope.go b/registry/remote/auth/scope.go index 7d8bb2c9..ce5b255a 100644 --- a/registry/remote/auth/scope.go +++ b/registry/remote/auth/scope.go @@ -20,7 +20,6 @@ import ( "sort" "strings" - "oras.land/oras-go/v2/internal/maps" "oras.land/oras-go/v2/internal/slices" "oras.land/oras-go/v2/registry" ) @@ -125,7 +124,7 @@ func GetScopes(ctx context.Context) []string { } // scopesPerHostContextKey is the context key for per-host scopes. -type scopesPerHostContextKey struct{} +type scopesPerHostContextKey string // WithScopesPerHost returns a context with per-host scopes added. // Scopes are de-duplicated. @@ -147,16 +146,8 @@ type scopesPerHostContextKey struct{} // // Reference: https://docs.docker.com/registry/spec/auth/scope/ func WithScopesPerHost(ctx context.Context, host string, scopes ...string) context.Context { - var scopesByHost map[string][]string - if old, ok := ctx.Value(scopesPerHostContextKey{}).(map[string][]string); ok { - scopesByHost = make(map[string][]string, len(old)) - maps.Copy(scopesByHost, old) - } else { - scopesByHost = make(map[string][]string, 1) - } - - scopesByHost[host] = CleanScopes(scopes) - return context.WithValue(ctx, scopesPerHostContextKey{}, scopesByHost) + scopes = CleanScopes(scopes) + return context.WithValue(ctx, scopesPerHostContextKey(host), scopes) } // AppendScopesPerHost appends additional scopes to the existing scopes @@ -173,8 +164,8 @@ func AppendScopesPerHost(ctx context.Context, host string, scopes ...string) con // GetScopesPerHost returns the scopes in the context for the given host. func GetScopesPerHost(ctx context.Context, host string) []string { - if scopesByHost, ok := ctx.Value(scopesPerHostContextKey{}).(map[string][]string); ok { - return slices.Clone(scopesByHost[host]) + if scopes, ok := ctx.Value(scopesPerHostContextKey(host)).([]string); ok { + return slices.Clone(scopes) } return nil } From 783cffc36cda8b666099bd80d214079474098c0c Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Mon, 25 Sep 2023 16:28:04 +0800 Subject: [PATCH 24/29] rename Signed-off-by: Lixia (Sylvia) Lei --- registry/remote/auth/client.go | 4 +- registry/remote/auth/client_test.go | 24 +++++----- registry/remote/auth/scope.go | 28 ++++++------ registry/remote/auth/scope_test.go | 68 ++++++++++++++--------------- registry/remote/registry.go | 2 +- 5 files changed, 63 insertions(+), 63 deletions(-) diff --git a/registry/remote/auth/client.go b/registry/remote/auth/client.go index aea376ed..a9a49d38 100644 --- a/registry/remote/auth/client.go +++ b/registry/remote/auth/client.go @@ -188,7 +188,7 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { } case SchemeBearer: // merge per-host scopes with generic scopes - scopes := GetScopesPerHost(ctx, host) + scopes := GetScopesForHost(ctx, host) if moreScopes := GetScopes(ctx); len(moreScopes) > 0 { scopes = append(scopes, moreScopes...) scopes = CleanScopes(scopes) @@ -228,7 +228,7 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { case SchemeBearer: resp.Body.Close() - scopes := GetScopesPerHost(ctx, host) + scopes := GetScopesForHost(ctx, host) cleanScopeLen := len(scopes) if moreScopes := GetScopes(ctx); len(moreScopes) > 0 { // merge per-host scopes with generic scopes diff --git a/registry/remote/auth/client_test.go b/registry/remote/auth/client_test.go index 74db0dba..de879863 100644 --- a/registry/remote/auth/client_test.go +++ b/registry/remote/auth/client_test.go @@ -526,8 +526,8 @@ func TestClient_Do_Bearer_AccessToken_Cached_PerHost(t *testing.T) { } ctx := context.Background() - ctx = WithScopesPerHost(ctx, uri1.Host, scope1) - ctx = WithScopesPerHost(ctx, uri2.Host, scope2) + ctx = WithScopesForHost(ctx, uri1.Host, scope1) + ctx = WithScopesForHost(ctx, uri2.Host, scope2) // first request to server 1 req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) if err != nil { @@ -1066,8 +1066,8 @@ func TestClient_Do_Bearer_Auth_Cached_PerHost(t *testing.T) { } ctx := context.Background() - ctx = WithScopesPerHost(ctx, uri1.Host, scopes1...) - ctx = WithScopesPerHost(ctx, uri2.Host, scopes2...) + ctx = WithScopesForHost(ctx, uri1.Host, scopes1...) + ctx = WithScopesForHost(ctx, uri2.Host, scopes2...) // first request to server 1 req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) if err != nil { @@ -1716,8 +1716,8 @@ func TestClient_Do_Bearer_OAuth2_Password_Cached_PerHost(t *testing.T) { } ctx := context.Background() - ctx = WithScopesPerHost(ctx, uri1.Host, scopes1...) - ctx = WithScopesPerHost(ctx, uri2.Host, scopes2...) + ctx = WithScopesForHost(ctx, uri1.Host, scopes1...) + ctx = WithScopesForHost(ctx, uri2.Host, scopes2...) // first request to server 1 req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) if err != nil { @@ -2330,8 +2330,8 @@ func TestClient_Do_Bearer_OAuth2_RefreshToken_Cached_PerHost(t *testing.T) { } ctx := context.Background() - ctx = WithScopesPerHost(ctx, uri1.Host, scopes1...) - ctx = WithScopesPerHost(ctx, uri2.Host, scopes2...) + ctx = WithScopesForHost(ctx, uri1.Host, scopes1...) + ctx = WithScopesForHost(ctx, uri2.Host, scopes2...) // first request to server 1 req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) if err != nil { @@ -2779,8 +2779,8 @@ func TestClient_Do_Token_Expire_PerHost(t *testing.T) { } ctx := context.Background() - ctx = WithScopesPerHost(ctx, uri1.Host, scopes1...) - ctx = WithScopesPerHost(ctx, uri2.Host, scopes2...) + ctx = WithScopesForHost(ctx, uri1.Host, scopes1...) + ctx = WithScopesForHost(ctx, uri2.Host, scopes2...) // first request to server 1 req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) if err != nil { @@ -3212,8 +3212,8 @@ func TestClient_Do_Scope_Hint_Mismatch_PerHost(t *testing.T) { } ctx := context.Background() - ctx = WithScopesPerHost(ctx, uri1.Host, scopes1...) - ctx = WithScopesPerHost(ctx, uri2.Host, scopes2...) + ctx = WithScopesForHost(ctx, uri1.Host, scopes1...) + ctx = WithScopesForHost(ctx, uri2.Host, scopes2...) // first request to server 1 req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) if err != nil { diff --git a/registry/remote/auth/scope.go b/registry/remote/auth/scope.go index ce5b255a..41cb3e2e 100644 --- a/registry/remote/auth/scope.go +++ b/registry/remote/auth/scope.go @@ -76,7 +76,7 @@ func AppendScopeHints(ctx context.Context, ref registry.Reference, actions ...st return ctx } scope := ScopeRepository(ref.Repository, actions...) - return AppendScopesPerHost(ctx, ref.Host(), scope) + return AppendScopesForHost(ctx, ref.Host(), scope) } // scopesContextKey is the context key for scopes. @@ -123,10 +123,10 @@ func GetScopes(ctx context.Context) []string { return nil } -// scopesPerHostContextKey is the context key for per-host scopes. -type scopesPerHostContextKey string +// scopesForHostContextKey is the context key for per-host scopes. +type scopesForHostContextKey string -// WithScopesPerHost returns a context with per-host scopes added. +// WithScopesForHost returns a context with per-host scopes added. // Scopes are de-duplicated. // Scopes are used as hints for the auth client to fetch bearer tokens with // larger scopes. @@ -136,7 +136,7 @@ type scopesPerHostContextKey string // `repository:hello-world:pull`, and the auth client will fetch a token for // that challenge. Later, the POST request will return a challenge for scope // `repository:hello-world:push`, and the auth client will fetch a token for -// that challenge again. By invoking WithScopesPerHost with the scope +// that challenge again. By invoking WithScopesForHost with the scope // `repository:hello-world:pull,push`, the auth client with cache is hinted to // fetch a token via a single token fetch request for all the HEAD, POST, PUT // requests. @@ -145,26 +145,26 @@ type scopesPerHostContextKey string // context for the given host. // // Reference: https://docs.docker.com/registry/spec/auth/scope/ -func WithScopesPerHost(ctx context.Context, host string, scopes ...string) context.Context { +func WithScopesForHost(ctx context.Context, host string, scopes ...string) context.Context { scopes = CleanScopes(scopes) - return context.WithValue(ctx, scopesPerHostContextKey(host), scopes) + return context.WithValue(ctx, scopesForHostContextKey(host), scopes) } -// AppendScopesPerHost appends additional scopes to the existing scopes +// AppendScopesForHost appends additional scopes to the existing scopes // in the context for the given host and returns a new context. // The resulted scopes are de-duplicated. // The append operation does modify the existing scope in the context passed in. -func AppendScopesPerHost(ctx context.Context, host string, scopes ...string) context.Context { +func AppendScopesForHost(ctx context.Context, host string, scopes ...string) context.Context { if len(scopes) == 0 { return ctx } - oldScopes := GetScopesPerHost(ctx, host) - return WithScopesPerHost(ctx, host, append(oldScopes, scopes...)...) + oldScopes := GetScopesForHost(ctx, host) + return WithScopesForHost(ctx, host, append(oldScopes, scopes...)...) } -// GetScopesPerHost returns the scopes in the context for the given host. -func GetScopesPerHost(ctx context.Context, host string) []string { - if scopes, ok := ctx.Value(scopesPerHostContextKey(host)).([]string); ok { +// GetScopesForHost returns the scopes in the context for the given host. +func GetScopesForHost(ctx context.Context, host string) []string { + if scopes, ok := ctx.Value(scopesForHostContextKey(host)).([]string); ok { return slices.Clone(scopes) } return nil diff --git a/registry/remote/auth/scope_test.go b/registry/remote/auth/scope_test.go index 44c07f79..be901499 100644 --- a/registry/remote/auth/scope_test.go +++ b/registry/remote/auth/scope_test.go @@ -125,10 +125,10 @@ func TestWithScopeHints(t *testing.T) { } ctx = AppendScopeHints(ctx, ref1, ActionPull) ctx = AppendScopeHints(ctx, ref2, ActionPush) - if got := GetScopesPerHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { + if got := GetScopesForHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want1) } - if got := GetScopesPerHost(ctx, ref2.Host()); !reflect.DeepEqual(got, want2) { + if got := GetScopesForHost(ctx, ref2.Host()); !reflect.DeepEqual(got, want2) { t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want2) } @@ -151,20 +151,20 @@ func TestWithScopeHints(t *testing.T) { } ctx = AppendScopeHints(ctx, ref1, scopes1...) ctx = AppendScopeHints(ctx, ref2, scopes2...) - if got := GetScopesPerHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { + if got := GetScopesForHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want1) } - if got := GetScopesPerHost(ctx, ref2.Host()); !reflect.DeepEqual(got, want2) { + if got := GetScopesForHost(ctx, ref2.Host()); !reflect.DeepEqual(got, want2) { t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want2) } // append empty scopes ctx = AppendScopeHints(ctx, ref1) ctx = AppendScopeHints(ctx, ref2) - if got := GetScopesPerHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { + if got := GetScopesForHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want1) } - if got := GetScopesPerHost(ctx, ref2.Host()); !reflect.DeepEqual(got, want2) { + if got := GetScopesForHost(ctx, ref2.Host()); !reflect.DeepEqual(got, want2) { t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want2) } } @@ -262,12 +262,12 @@ func TestWithScopesPerHost(t *testing.T) { want2 := []string{ "repository:foo:push", } - ctx = WithScopesPerHost(ctx, reg1, want1...) - ctx = WithScopesPerHost(ctx, reg2, want2...) - if got := GetScopesPerHost(ctx, reg1); !reflect.DeepEqual(got, want1) { + ctx = WithScopesForHost(ctx, reg1, want1...) + ctx = WithScopesForHost(ctx, reg2, want2...) + if got := GetScopesForHost(ctx, reg1); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want1) } - if got := GetScopesPerHost(ctx, reg2); !reflect.DeepEqual(got, want2) { + if got := GetScopesForHost(ctx, reg2); !reflect.DeepEqual(got, want2) { t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want2) } @@ -278,12 +278,12 @@ func TestWithScopesPerHost(t *testing.T) { want2 = []string{ "repository:bar:pull", } - ctx = WithScopesPerHost(ctx, reg1, want1...) - ctx = WithScopesPerHost(ctx, reg2, want2...) - if got := GetScopesPerHost(ctx, reg1); !reflect.DeepEqual(got, want1) { + ctx = WithScopesForHost(ctx, reg1, want1...) + ctx = WithScopesForHost(ctx, reg2, want2...) + if got := GetScopesForHost(ctx, reg1); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want1) } - if got := GetScopesPerHost(ctx, reg2); !reflect.DeepEqual(got, want2) { + if got := GetScopesForHost(ctx, reg2); !reflect.DeepEqual(got, want2) { t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want2) } @@ -308,23 +308,23 @@ func TestWithScopesPerHost(t *testing.T) { "repository:goodbye-world:pull,push", "repository:nginx:delete", } - ctx = WithScopesPerHost(ctx, reg1, scopes1...) - ctx = WithScopesPerHost(ctx, reg2, scopes2...) - if got := GetScopesPerHost(ctx, reg1); !reflect.DeepEqual(got, want1) { + ctx = WithScopesForHost(ctx, reg1, scopes1...) + ctx = WithScopesForHost(ctx, reg2, scopes2...) + if got := GetScopesForHost(ctx, reg1); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want1) } - if got := GetScopesPerHost(ctx, reg2); !reflect.DeepEqual(got, want2) { + if got := GetScopesForHost(ctx, reg2); !reflect.DeepEqual(got, want2) { t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want2) } // clean scopes var want []string - ctx = WithScopesPerHost(ctx, reg1, want...) - ctx = WithScopesPerHost(ctx, reg2, want...) - if got := GetScopesPerHost(ctx, reg1); !reflect.DeepEqual(got, want) { + ctx = WithScopesForHost(ctx, reg1, want...) + ctx = WithScopesForHost(ctx, reg2, want...) + if got := GetScopesForHost(ctx, reg1); !reflect.DeepEqual(got, want) { t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want) } - if got := GetScopesPerHost(ctx, reg2); !reflect.DeepEqual(got, want) { + if got := GetScopesForHost(ctx, reg2); !reflect.DeepEqual(got, want) { t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want) } } @@ -341,12 +341,12 @@ func TestAppendScopesPerHost(t *testing.T) { want2 := []string{ "repository:foo:push", } - ctx = AppendScopesPerHost(ctx, reg1, want1...) - ctx = AppendScopesPerHost(ctx, reg2, want2...) - if got := GetScopesPerHost(ctx, reg1); !reflect.DeepEqual(got, want1) { + ctx = AppendScopesForHost(ctx, reg1, want1...) + ctx = AppendScopesForHost(ctx, reg2, want2...) + if got := GetScopesForHost(ctx, reg1); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want1) } - if got := GetScopesPerHost(ctx, reg2); !reflect.DeepEqual(got, want2) { + if got := GetScopesForHost(ctx, reg2); !reflect.DeepEqual(got, want2) { t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want2) } @@ -373,22 +373,22 @@ func TestAppendScopesPerHost(t *testing.T) { "repository:goodbye-world:pull,push", "repository:nginx:delete", } - ctx = AppendScopesPerHost(ctx, reg1, scopes1...) - ctx = AppendScopesPerHost(ctx, reg2, scopes2...) - if got := GetScopesPerHost(ctx, reg1); !reflect.DeepEqual(got, want1) { + ctx = AppendScopesForHost(ctx, reg1, scopes1...) + ctx = AppendScopesForHost(ctx, reg2, scopes2...) + if got := GetScopesForHost(ctx, reg1); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want1) } - if got := GetScopesPerHost(ctx, reg2); !reflect.DeepEqual(got, want2) { + if got := GetScopesForHost(ctx, reg2); !reflect.DeepEqual(got, want2) { t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want2) } // append empty scopes - ctx = AppendScopesPerHost(ctx, reg1) - ctx = AppendScopesPerHost(ctx, reg2) - if got := GetScopesPerHost(ctx, reg1); !reflect.DeepEqual(got, want1) { + ctx = AppendScopesForHost(ctx, reg1) + ctx = AppendScopesForHost(ctx, reg2) + if got := GetScopesForHost(ctx, reg1); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want1) } - if got := GetScopesPerHost(ctx, reg2); !reflect.DeepEqual(got, want2) { + if got := GetScopesForHost(ctx, reg2); !reflect.DeepEqual(got, want2) { t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want2) } } diff --git a/registry/remote/registry.go b/registry/remote/registry.go index f0c417e4..d1334042 100644 --- a/registry/remote/registry.go +++ b/registry/remote/registry.go @@ -127,7 +127,7 @@ func (r *Registry) Ping(ctx context.Context) error { // // Reference: https://docs.docker.com/registry/spec/api/#catalog func (r *Registry) Repositories(ctx context.Context, last string, fn func(repos []string) error) error { - ctx = auth.AppendScopesPerHost(ctx, r.Reference.Host(), auth.ScopeRegistryCatalog) + ctx = auth.AppendScopesForHost(ctx, r.Reference.Host(), auth.ScopeRegistryCatalog) url := buildRegistryCatalogURL(r.PlainHTTP, r.Reference) var err error for err == nil { From 32e48667ac7a4ca7377023a62e5ed125cbc37862 Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Mon, 25 Sep 2023 16:39:11 +0800 Subject: [PATCH 25/29] update doc Signed-off-by: Lixia (Sylvia) Lei --- registry/remote/auth/client.go | 4 ++-- registry/remote/auth/scope.go | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/registry/remote/auth/client.go b/registry/remote/auth/client.go index a9a49d38..0303b8b9 100644 --- a/registry/remote/auth/client.go +++ b/registry/remote/auth/client.go @@ -187,7 +187,7 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { req.Header.Set("Authorization", "Basic "+token) } case SchemeBearer: - // merge per-host scopes with generic scopes + // merge per-host scopes with global scopes scopes := GetScopesForHost(ctx, host) if moreScopes := GetScopes(ctx); len(moreScopes) > 0 { scopes = append(scopes, moreScopes...) @@ -231,7 +231,7 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { scopes := GetScopesForHost(ctx, host) cleanScopeLen := len(scopes) if moreScopes := GetScopes(ctx); len(moreScopes) > 0 { - // merge per-host scopes with generic scopes + // merge per-host scopes with global scopes scopes = append(scopes, moreScopes...) } if paramScope := params["scope"]; paramScope != "" { diff --git a/registry/remote/auth/scope.go b/registry/remote/auth/scope.go index 41cb3e2e..bd98b6c6 100644 --- a/registry/remote/auth/scope.go +++ b/registry/remote/auth/scope.go @@ -162,7 +162,8 @@ func AppendScopesForHost(ctx context.Context, host string, scopes ...string) con return WithScopesForHost(ctx, host, append(oldScopes, scopes...)...) } -// GetScopesForHost returns the scopes in the context for the given host. +// GetScopesForHost returns the scopes in the context for the given host, +// excluding the global scopes added by [WithScopes] and [AppendScopes]. func GetScopesForHost(ctx context.Context, host string) []string { if scopes, ok := ctx.Value(scopesForHostContextKey(host)).([]string); ok { return slices.Clone(scopes) From fd3da4a50b6299492fa1d252a20cccb7e2829c69 Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Mon, 25 Sep 2023 17:04:54 +0800 Subject: [PATCH 26/29] add a private function Signed-off-by: Lixia (Sylvia) Lei --- registry/remote/auth/client.go | 17 ++--------------- registry/remote/auth/scope.go | 20 +++++++++++++++++++- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/registry/remote/auth/client.go b/registry/remote/auth/client.go index 0303b8b9..27bf182d 100644 --- a/registry/remote/auth/client.go +++ b/registry/remote/auth/client.go @@ -187,12 +187,7 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { req.Header.Set("Authorization", "Basic "+token) } case SchemeBearer: - // merge per-host scopes with global scopes - scopes := GetScopesForHost(ctx, host) - if moreScopes := GetScopes(ctx); len(moreScopes) > 0 { - scopes = append(scopes, moreScopes...) - scopes = CleanScopes(scopes) - } + scopes := getAllScopesForHost(ctx, host) attemptedKey = strings.Join(scopes, " ") token, err := cache.GetToken(ctx, host, SchemeBearer, attemptedKey) if err == nil { @@ -228,18 +223,10 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { case SchemeBearer: resp.Body.Close() - scopes := GetScopesForHost(ctx, host) - cleanScopeLen := len(scopes) - if moreScopes := GetScopes(ctx); len(moreScopes) > 0 { - // merge per-host scopes with global scopes - scopes = append(scopes, moreScopes...) - } + scopes := getAllScopesForHost(ctx, host) if paramScope := params["scope"]; paramScope != "" { // merge hinted scopes with challenged scopes scopes = append(scopes, strings.Split(paramScope, " ")...) - } - if len(scopes) > cleanScopeLen { - // re-clean the scopes scopes = CleanScopes(scopes) } key := strings.Join(scopes, " ") diff --git a/registry/remote/auth/scope.go b/registry/remote/auth/scope.go index bd98b6c6..0c042c4a 100644 --- a/registry/remote/auth/scope.go +++ b/registry/remote/auth/scope.go @@ -163,7 +163,7 @@ func AppendScopesForHost(ctx context.Context, host string, scopes ...string) con } // GetScopesForHost returns the scopes in the context for the given host, -// excluding the global scopes added by [WithScopes] and [AppendScopes]. +// excluding global scopes added by [WithScopes] and [AppendScopes]. func GetScopesForHost(ctx context.Context, host string) []string { if scopes, ok := ctx.Value(scopesForHostContextKey(host)).([]string); ok { return slices.Clone(scopes) @@ -171,6 +171,24 @@ func GetScopesForHost(ctx context.Context, host string) []string { return nil } +// getAllScopesForHost returns the scopes in the context for the given host, +// including global scopes added by [WithScopes] and [AppendScopes]. +func getAllScopesForHost(ctx context.Context, host string) []string { + scopes := GetScopesForHost(ctx, host) + globalScopes := GetScopes(ctx) + + switch { + case len(scopes) == 0: + return globalScopes + case len(globalScopes) == 0: + return scopes + default: + // re-clean the scopes + allScopes := append(scopes, globalScopes...) + return CleanScopes(allScopes) + } +} + // CleanScopes merges and sort the actions in ascending order if the scopes have // the same resource type and name. The final scopes are sorted in ascending // order. In other words, the scopes passed in are de-duplicated and sorted. From 56cc8e78b7d2cb1bfb6bbbee80995b9af5571f1f Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Mon, 25 Sep 2023 17:23:02 +0800 Subject: [PATCH 27/29] test new function Signed-off-by: Lixia (Sylvia) Lei --- registry/remote/auth/scope_test.go | 68 ++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/registry/remote/auth/scope_test.go b/registry/remote/auth/scope_test.go index be901499..0b00c5c7 100644 --- a/registry/remote/auth/scope_test.go +++ b/registry/remote/auth/scope_test.go @@ -658,3 +658,71 @@ func Test_cleanActions(t *testing.T) { }) } } + +func Test_getAllScopesForHost(t *testing.T) { + host := "registry.example.com" + tests := []struct { + name string + scopes []string + globalScopes []string + want []string + }{ + { + name: "Empty per-host scopes", + scopes: []string{}, + globalScopes: []string{ + "repository:hello-world:push", + "repository:alpine:delete", + "repository:hello-world:pull", + "repository:alpine:delete", + }, + want: []string{ + "repository:alpine:delete", + "repository:hello-world:pull,push", + }, + }, + { + name: "Empty global scopes", + scopes: []string{ + "repository:hello-world:push", + "repository:alpine:delete", + "repository:hello-world:pull", + "repository:alpine:delete", + }, + globalScopes: []string{}, + want: []string{ + "repository:alpine:delete", + "repository:hello-world:pull,push", + }, + }, + { + name: "Per-host scopes + global scopes", + scopes: []string{ + "repository:hello-world:push", + "repository:alpine:delete", + "repository:hello-world:pull", + "repository:alpine:delete", + }, + globalScopes: []string{ + "repository:foo:pull", + "repository:hello-world:pull", + "repository:alpine:pull", + }, + want: []string{ + "repository:alpine:delete,pull", + "repository:foo:pull", + "repository:hello-world:pull,push", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + ctx = WithScopesForHost(ctx, host, tt.scopes...) + ctx = WithScopes(ctx, tt.globalScopes...) + if got := getAllScopesForHost(ctx, host); !reflect.DeepEqual(got, tt.want) { + t.Errorf("getAllScopesForHost() = %v, want %v", got, tt.want) + } + }) + } +} From e6fa97b48b9f73031d930a52b7443bb27ce69d60 Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Tue, 26 Sep 2023 18:46:06 +0800 Subject: [PATCH 28/29] address comments Signed-off-by: Lixia (Sylvia) Lei --- content.go | 4 ++-- registry/remote/auth/scope.go | 22 +++++++++++----------- registry/remote/auth/scope_test.go | 12 ++++++------ registry/remote/repository.go | 30 +++++++++++++++--------------- 4 files changed, 34 insertions(+), 34 deletions(-) diff --git a/content.go b/content.go index 92794969..b8bf2638 100644 --- a/content.go +++ b/content.go @@ -90,7 +90,7 @@ func TagN(ctx context.Context, target Target, srcReference string, dstReferences if err != nil { return ocispec.Descriptor{}, err } - ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull, auth.ActionPush) } desc, contentBytes, err := FetchBytes(ctx, target, srcReference, FetchBytesOptions{ @@ -148,7 +148,7 @@ func Tag(ctx context.Context, target Target, src, dst string) (ocispec.Descripto if err != nil { return ocispec.Descriptor{}, err } - ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull, auth.ActionPush) } desc, rc, err := refFetcher.FetchReference(ctx, src) if err != nil { diff --git a/registry/remote/auth/scope.go b/registry/remote/auth/scope.go index 0c042c4a..8b7ec263 100644 --- a/registry/remote/auth/scope.go +++ b/registry/remote/auth/scope.go @@ -57,8 +57,8 @@ func ScopeRepository(repository string, actions ...string) string { }, ":") } -// AppendScopeHints returns a new context containing scope hints for the auth -// client to fetch bearer tokens with the given actions on the repository. +// AppendRepositoryScope returns a new context containing scope hints for the +// auth client to fetch bearer tokens with the given actions on the repository. // If called multiple times, the new scopes will be appended to the existing // scopes. The resulted scopes are de-duplicated. // @@ -67,11 +67,11 @@ func ScopeRepository(repository string, actions ...string) string { // `repository:hello-world:pull`, and the auth client will fetch a token for // that challenge. Later, the POST request will return a challenge for scope // `repository:hello-world:push`, and the auth client will fetch a token for -// that challenge again. By invoking AppendScopeHints with the actions +// that challenge again. By invoking AppendRepositoryScope with the actions // [ActionPull] and [ActionPush] for the repository `hello-world`, // the auth client with cache is hinted to fetch a token via a single token // fetch request for all the HEAD, POST, PUT requests. -func AppendScopeHints(ctx context.Context, ref registry.Reference, actions ...string) context.Context { +func AppendRepositoryScope(ctx context.Context, ref registry.Reference, actions ...string) context.Context { if len(actions) == 0 { return ctx } @@ -177,16 +177,16 @@ func getAllScopesForHost(ctx context.Context, host string) []string { scopes := GetScopesForHost(ctx, host) globalScopes := GetScopes(ctx) - switch { - case len(scopes) == 0: + if len(scopes) == 0 { return globalScopes - case len(globalScopes) == 0: + } + if len(globalScopes) == 0 { return scopes - default: - // re-clean the scopes - allScopes := append(scopes, globalScopes...) - return CleanScopes(allScopes) } + + // re-clean the scopes + allScopes := append(scopes, globalScopes...) + return CleanScopes(allScopes) } // CleanScopes merges and sort the actions in ascending order if the scopes have diff --git a/registry/remote/auth/scope_test.go b/registry/remote/auth/scope_test.go index 0b00c5c7..0c75e128 100644 --- a/registry/remote/auth/scope_test.go +++ b/registry/remote/auth/scope_test.go @@ -123,8 +123,8 @@ func TestWithScopeHints(t *testing.T) { want2 := []string{ "repository:foo:push", } - ctx = AppendScopeHints(ctx, ref1, ActionPull) - ctx = AppendScopeHints(ctx, ref2, ActionPush) + ctx = AppendRepositoryScope(ctx, ref1, ActionPull) + ctx = AppendRepositoryScope(ctx, ref2, ActionPush) if got := GetScopesForHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want1) } @@ -149,8 +149,8 @@ func TestWithScopeHints(t *testing.T) { want2 = []string{ "repository:foo:delete,push", } - ctx = AppendScopeHints(ctx, ref1, scopes1...) - ctx = AppendScopeHints(ctx, ref2, scopes2...) + ctx = AppendRepositoryScope(ctx, ref1, scopes1...) + ctx = AppendRepositoryScope(ctx, ref2, scopes2...) if got := GetScopesForHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want1) } @@ -159,8 +159,8 @@ func TestWithScopeHints(t *testing.T) { } // append empty scopes - ctx = AppendScopeHints(ctx, ref1) - ctx = AppendScopeHints(ctx, ref2) + ctx = AppendRepositoryScope(ctx, ref1) + ctx = AppendRepositoryScope(ctx, ref2) if got := GetScopesForHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want1) } diff --git a/registry/remote/repository.go b/registry/remote/repository.go index 0e157068..5373492b 100644 --- a/registry/remote/repository.go +++ b/registry/remote/repository.go @@ -391,7 +391,7 @@ func (r *Repository) ParseReference(reference string) (registry.Reference, error // - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#content-discovery // - https://docs.docker.com/registry/spec/api/#tags func (r *Repository) Tags(ctx context.Context, last string, fn func(tags []string) error) error { - ctx = auth.AppendScopeHints(ctx, r.Reference, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, r.Reference, auth.ActionPull) url := buildRepositoryTagListURL(r.PlainHTTP, r.Reference) var err error for err == nil { @@ -508,7 +508,7 @@ func (r *Repository) Referrers(ctx context.Context, desc ocispec.Descriptor, art func (r *Repository) referrersByAPI(ctx context.Context, desc ocispec.Descriptor, artifactType string, fn func(referrers []ocispec.Descriptor) error) error { ref := r.Reference ref.Reference = desc.Digest.String() - ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildReferrersURL(r.PlainHTTP, ref, artifactType) var err error @@ -642,7 +642,7 @@ func (r *Repository) pingReferrers(ctx context.Context) (bool, error) { ref := r.Reference ref.Reference = zeroDigest - ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildReferrersURL(r.PlainHTTP, ref, "") req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) @@ -676,7 +676,7 @@ func (r *Repository) pingReferrers(ctx context.Context) (bool, error) { func (r *Repository) delete(ctx context.Context, target ocispec.Descriptor, isManifest bool) error { ref := r.Reference ref.Reference = target.Digest.String() - ctx = auth.AppendScopeHints(ctx, ref, auth.ActionDelete) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionDelete) buildURL := buildRepositoryBlobURL if isManifest { buildURL = buildRepositoryManifestURL @@ -712,7 +712,7 @@ type blobStore struct { func (s *blobStore) Fetch(ctx context.Context, target ocispec.Descriptor) (rc io.ReadCloser, err error) { ref := s.repo.Reference ref.Reference = target.Digest.String() - ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildRepositoryBlobURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -754,12 +754,12 @@ func (s *blobStore) Fetch(ctx context.Context, target ocispec.Descriptor) (rc io func (s *blobStore) Mount(ctx context.Context, desc ocispec.Descriptor, fromRepo string, getContent func() (io.ReadCloser, error)) error { // pushing usually requires both pull and push actions. // Reference: https://github.com/distribution/distribution/blob/v2.7.1/registry/handlers/app.go#L921-L930 - ctx = auth.AppendScopeHints(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendRepositoryScope(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) // We also need pull access to the source repo. fromRef := s.repo.Reference fromRef.Repository = fromRepo - ctx = auth.AppendScopeHints(ctx, fromRef, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, fromRef, auth.ActionPull) url := buildRepositoryBlobMountURL(s.repo.PlainHTTP, s.repo.Reference, desc.Digest, fromRepo) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) @@ -832,7 +832,7 @@ func (s *blobStore) Push(ctx context.Context, expected ocispec.Descriptor, conte // start an upload // pushing usually requires both pull and push actions. // Reference: https://github.com/distribution/distribution/blob/v2.7.1/registry/handlers/app.go#L921-L930 - ctx = auth.AppendScopeHints(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendRepositoryScope(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) url := buildRepositoryBlobUploadURL(s.repo.PlainHTTP, s.repo.Reference) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) if err != nil { @@ -933,7 +933,7 @@ func (s *blobStore) Resolve(ctx context.Context, reference string) (ocispec.Desc if err != nil { return ocispec.Descriptor{}, err } - ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildRepositoryBlobURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) if err != nil { @@ -968,7 +968,7 @@ func (s *blobStore) FetchReference(ctx context.Context, reference string) (desc return ocispec.Descriptor{}, nil, err } - ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildRepositoryBlobURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -1043,7 +1043,7 @@ type manifestStore struct { func (s *manifestStore) Fetch(ctx context.Context, target ocispec.Descriptor) (rc io.ReadCloser, err error) { ref := s.repo.Reference ref.Reference = target.Digest.String() - ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -1169,7 +1169,7 @@ func (s *manifestStore) Resolve(ctx context.Context, reference string) (ocispec. if err != nil { return ocispec.Descriptor{}, err } - ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) if err != nil { @@ -1201,7 +1201,7 @@ func (s *manifestStore) FetchReference(ctx context.Context, reference string) (d return ocispec.Descriptor{}, nil, err } - ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -1244,7 +1244,7 @@ func (s *manifestStore) Tag(ctx context.Context, desc ocispec.Descriptor, refere return err } - ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull, auth.ActionPush) rc, err := s.Fetch(ctx, desc) if err != nil { return err @@ -1269,7 +1269,7 @@ func (s *manifestStore) push(ctx context.Context, expected ocispec.Descriptor, c ref.Reference = reference // pushing usually requires both pull and push actions. // Reference: https://github.com/distribution/distribution/blob/v2.7.1/registry/handlers/app.go#L921-L930 - ctx = auth.AppendScopeHints(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull, auth.ActionPush) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) // unwrap the content for optimizations of built-in types. body := ioutil.UnwrapNopCloser(content) From db4eaa8a86b19d1177063a929da08c9a2489b861 Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Tue, 26 Sep 2023 19:27:31 +0800 Subject: [PATCH 29/29] expose GetAllScopesForHost Signed-off-by: Lixia (Sylvia) Lei --- registry/remote/auth/client.go | 4 ++-- registry/remote/auth/scope.go | 5 ++--- registry/remote/auth/scope_test.go | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/registry/remote/auth/client.go b/registry/remote/auth/client.go index 27bf182d..58355161 100644 --- a/registry/remote/auth/client.go +++ b/registry/remote/auth/client.go @@ -187,7 +187,7 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { req.Header.Set("Authorization", "Basic "+token) } case SchemeBearer: - scopes := getAllScopesForHost(ctx, host) + scopes := GetAllScopesForHost(ctx, host) attemptedKey = strings.Join(scopes, " ") token, err := cache.GetToken(ctx, host, SchemeBearer, attemptedKey) if err == nil { @@ -223,7 +223,7 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { case SchemeBearer: resp.Body.Close() - scopes := getAllScopesForHost(ctx, host) + scopes := GetAllScopesForHost(ctx, host) if paramScope := params["scope"]; paramScope != "" { // merge hinted scopes with challenged scopes scopes = append(scopes, strings.Split(paramScope, " ")...) diff --git a/registry/remote/auth/scope.go b/registry/remote/auth/scope.go index 8b7ec263..fabc2af2 100644 --- a/registry/remote/auth/scope.go +++ b/registry/remote/auth/scope.go @@ -171,9 +171,9 @@ func GetScopesForHost(ctx context.Context, host string) []string { return nil } -// getAllScopesForHost returns the scopes in the context for the given host, +// GetAllScopesForHost returns the scopes in the context for the given host, // including global scopes added by [WithScopes] and [AppendScopes]. -func getAllScopesForHost(ctx context.Context, host string) []string { +func GetAllScopesForHost(ctx context.Context, host string) []string { scopes := GetScopesForHost(ctx, host) globalScopes := GetScopes(ctx) @@ -183,7 +183,6 @@ func getAllScopesForHost(ctx context.Context, host string) []string { if len(globalScopes) == 0 { return scopes } - // re-clean the scopes allScopes := append(scopes, globalScopes...) return CleanScopes(allScopes) diff --git a/registry/remote/auth/scope_test.go b/registry/remote/auth/scope_test.go index 0c75e128..ca9fe339 100644 --- a/registry/remote/auth/scope_test.go +++ b/registry/remote/auth/scope_test.go @@ -720,7 +720,7 @@ func Test_getAllScopesForHost(t *testing.T) { ctx := context.Background() ctx = WithScopesForHost(ctx, host, tt.scopes...) ctx = WithScopes(ctx, tt.globalScopes...) - if got := getAllScopesForHost(ctx, host); !reflect.DeepEqual(got, tt.want) { + if got := GetAllScopesForHost(ctx, host); !reflect.DeepEqual(got, tt.want) { t.Errorf("getAllScopesForHost() = %v, want %v", got, tt.want) } })