From a428ca67f59b94f7365298870bcac78c769b80bd Mon Sep 17 00:00:00 2001 From: "Lixia (Sylvia) Lei" Date: Wed, 27 Sep 2023 10:57:56 +0800 Subject: [PATCH] feat: support per-host scope hints (#604) The purpose of this PR is to fix the bug, but new APIs are needed to avoid breaking changes. 1. Introduce `auth.WithScopesForHost` 2. Introduce `auth.AppendScopesForHost` 3. Introduce `auth.GetScopesForHost` and `auth.GetAllScopesForHost` 4. Introduce `auth.AppendRepositoryScope` Resolves: #581 Signed-off-by: Lixia (Sylvia) Lei --- content.go | 5 +- internal/registryutil/auth.go | 29 - registry/remote/auth/client.go | 28 +- registry/remote/auth/client_test.go | 1936 +++++++++++++++++++++++++-- registry/remote/auth/scope.go | 94 +- registry/remote/auth/scope_test.go | 277 ++++ registry/remote/registry.go | 2 +- registry/remote/repository.go | 31 +- 8 files changed, 2207 insertions(+), 195 deletions(-) delete mode 100644 internal/registryutil/auth.go diff --git a/content.go b/content.go index 53eb6c75..b8bf2638 100644 --- a/content.go +++ b/content.go @@ -29,7 +29,6 @@ import ( "oras.land/oras-go/v2/internal/docker" "oras.land/oras-go/v2/internal/interfaces" "oras.land/oras-go/v2/internal/platform" - "oras.land/oras-go/v2/internal/registryutil" "oras.land/oras-go/v2/internal/syncutil" "oras.land/oras-go/v2/registry" "oras.land/oras-go/v2/registry/remote/auth" @@ -91,7 +90,7 @@ func TagN(ctx context.Context, target Target, srcReference string, dstReferences if err != nil { return ocispec.Descriptor{}, err } - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull, auth.ActionPush) } desc, contentBytes, err := FetchBytes(ctx, target, srcReference, FetchBytesOptions{ @@ -149,7 +148,7 @@ func Tag(ctx context.Context, target Target, src, dst string) (ocispec.Descripto if err != nil { return ocispec.Descriptor{}, err } - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull, auth.ActionPush) } desc, rc, err := refFetcher.FetchReference(ctx, src) if err != nil { diff --git a/internal/registryutil/auth.go b/internal/registryutil/auth.go deleted file mode 100644 index 4a601f0c..00000000 --- a/internal/registryutil/auth.go +++ /dev/null @@ -1,29 +0,0 @@ -/* -Copyright The ORAS Authors. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package registryutil - -import ( - "context" - - "oras.land/oras-go/v2/registry" - "oras.land/oras-go/v2/registry/remote/auth" -) - -// WithScopeHint adds a hinted scope to the context. -func WithScopeHint(ctx context.Context, ref registry.Reference, actions ...string) context.Context { - scope := auth.ScopeRepository(ref.Repository, actions...) - return auth.AppendScopes(ctx, scope) -} diff --git a/registry/remote/auth/client.go b/registry/remote/auth/client.go index b4b0261a..58355161 100644 --- a/registry/remote/auth/client.go +++ b/registry/remote/auth/client.go @@ -177,19 +177,19 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { // attempt cached auth token var attemptedKey string cache := c.cache() - registry := originalReq.Host - scheme, err := cache.GetScheme(ctx, registry) + host := originalReq.Host + scheme, err := cache.GetScheme(ctx, host) if err == nil { switch scheme { case SchemeBasic: - token, err := cache.GetToken(ctx, registry, SchemeBasic, "") + token, err := cache.GetToken(ctx, host, SchemeBasic, "") if err == nil { req.Header.Set("Authorization", "Basic "+token) } case SchemeBearer: - scopes := GetScopes(ctx) + scopes := GetAllScopesForHost(ctx, host) attemptedKey = strings.Join(scopes, " ") - token, err := cache.GetToken(ctx, registry, SchemeBearer, attemptedKey) + token, err := cache.GetToken(ctx, host, SchemeBearer, attemptedKey) if err == nil { req.Header.Set("Authorization", "Bearer "+token) } @@ -211,8 +211,8 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { case SchemeBasic: resp.Body.Close() - token, err := cache.Set(ctx, registry, SchemeBasic, "", func(ctx context.Context) (string, error) { - return c.fetchBasicAuth(ctx, registry) + token, err := cache.Set(ctx, host, SchemeBasic, "", func(ctx context.Context) (string, error) { + return c.fetchBasicAuth(ctx, host) }) if err != nil { return nil, fmt.Errorf("%s %q: %w", resp.Request.Method, resp.Request.URL, err) @@ -223,17 +223,17 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { case SchemeBearer: resp.Body.Close() - // merge hinted scopes with challenged scopes - scopes := GetScopes(ctx) - if scope := params["scope"]; scope != "" { - scopes = append(scopes, strings.Split(scope, " ")...) + scopes := GetAllScopesForHost(ctx, host) + if paramScope := params["scope"]; paramScope != "" { + // merge hinted scopes with challenged scopes + scopes = append(scopes, strings.Split(paramScope, " ")...) scopes = CleanScopes(scopes) } key := strings.Join(scopes, " ") // attempt the cache again if there is a scope change if key != attemptedKey { - if token, err := cache.GetToken(ctx, registry, SchemeBearer, key); err == nil { + if token, err := cache.GetToken(ctx, host, SchemeBearer, key); err == nil { req = originalReq.Clone(ctx) req.Header.Set("Authorization", "Bearer "+token) if err := rewindRequestBody(req); err != nil { @@ -254,8 +254,8 @@ func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { // attempt with credentials realm := params["realm"] service := params["service"] - token, err := cache.Set(ctx, registry, SchemeBearer, key, func(ctx context.Context) (string, error) { - return c.fetchBearerToken(ctx, registry, realm, service, scopes) + token, err := cache.Set(ctx, host, SchemeBearer, key, func(ctx context.Context) (string, error) { + return c.fetchBearerToken(ctx, host, realm, service, scopes) }) if err != nil { return nil, fmt.Errorf("%s %q: %w", resp.Request.Method, resp.Request.URL, err) diff --git a/registry/remote/auth/client_test.go b/registry/remote/auth/client_test.go index 9e5ed69d..de879863 100644 --- a/registry/remote/auth/client_test.go +++ b/registry/remote/auth/client_test.go @@ -449,6 +449,205 @@ func TestClient_Do_Bearer_AccessToken_Cached(t *testing.T) { } } +func TestClient_Do_Bearer_AccessToken_Cached_PerHost(t *testing.T) { + as := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + })) + defer as.Close() + // set up server 1 + var requestCount1, wantRequestCount1 int64 + var successCount1, wantSuccessCount1 int64 + var service1 string + scope1 := "repository:test:pull" + accessToken1 := "test/access/token/1" + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount1, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken1 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as.URL, service1, scope1) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount1, 1) + })) + defer ts1.Close() + uri1, err := url.Parse(ts1.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service1 = uri1.Host + client1 := &Client{ + Credential: StaticCredential(uri1.Host, Credential{ + AccessToken: accessToken1, + }), + Cache: NewCache(), + } + + // set up server 2 + var requestCount2, wantRequestCount2 int64 + var successCount2, wantSuccessCount2 int64 + var service2 string + scope2 := "repository:test:pull,push" + accessToken2 := "test/access/token/2" + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount2, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken2 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as.URL, service2, scope2) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount2, 1) + })) + defer ts2.Close() + uri2, err := url.Parse(ts2.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service2 = uri2.Host + client2 := &Client{ + Credential: StaticCredential(uri2.Host, Credential{ + AccessToken: accessToken2, + }), + Cache: NewCache(), + } + + ctx := context.Background() + ctx = WithScopesForHost(ctx, uri1.Host, scope1) + ctx = WithScopesForHost(ctx, uri2.Host, scope2) + // first request to server 1 + req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err := client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + // first request to server 2 + req2, err := http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err := client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount1 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + + // repeated request to server 1 + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1++; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + // repeated request to server 2 + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2++; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + + // credential change for server 1 + accessToken1 = "test/access/token/1/new" + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client1.Credential = StaticCredential(uri1.Host, Credential{ + AccessToken: accessToken1, + }) + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + // credential change for server 2 + accessToken2 = "test/access/token/2/new" + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client2.Credential = StaticCredential(uri2.Host, Credential{ + AccessToken: accessToken2, + }) + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } +} + func TestClient_Do_Bearer_Auth(t *testing.T) { username := "test_user" password := "test_password" @@ -725,6 +924,297 @@ func TestClient_Do_Bearer_Auth_Cached(t *testing.T) { } } +func TestClient_Do_Bearer_Auth_Cached_PerHost(t *testing.T) { + // set up server 1 + username1 := "test_user1" + password1 := "test_password1" + accessToken1 := "test/access/token/1" + var requestCount1, wantRequestCount1 int64 + var successCount1, wantSuccessCount1 int64 + var authCount1, wantAuthCount1 int64 + var service1 string + scopes1 := []string{ + "repository:src:pull", + } + as1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + header := "Basic " + base64.StdEncoding.EncodeToString([]byte(username1+":"+password1)) + if auth := r.Header.Get("Authorization"); auth != header { + t.Errorf("unexpected auth: got %s, want %s", auth, header) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.URL.Query().Get("service"); got != service1 { + t.Errorf("unexpected service: got %s, want %s", got, service1) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.URL.Query()["scope"]; !reflect.DeepEqual(got, scopes1) { + t.Errorf("unexpected scope: got %s, want %s", got, scopes1) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount1, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken1); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as1.Close() + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount1, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken1 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as1.URL, service1, strings.Join(scopes1, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount1, 1) + })) + defer ts1.Close() + uri1, err := url.Parse(ts1.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service1 = uri1.Host + client1 := &Client{ + Credential: StaticCredential(uri1.Host, Credential{ + Username: username1, + Password: password1, + }), + Cache: NewCache(), + } + + // set up server 2 + username2 := "test_user2" + password2 := "test_password2" + accessToken2 := "test/access/token/1" + var requestCount2, wantRequestCount2 int64 + var successCount2, wantSuccessCount2 int64 + var authCount2, wantAuthCount2 int64 + var service2 string + scopes2 := []string{ + "repository:dst:pull,push", + } + as2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + header := "Basic " + base64.StdEncoding.EncodeToString([]byte(username2+":"+password2)) + if auth := r.Header.Get("Authorization"); auth != header { + t.Errorf("unexpected auth: got %s, want %s", auth, header) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.URL.Query().Get("service"); got != service2 { + t.Errorf("unexpected service: got %s, want %s", got, service2) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.URL.Query()["scope"]; !reflect.DeepEqual(got, scopes2) { + t.Errorf("unexpected scope: got %s, want %s", got, scopes2) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount2, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken2); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as1.Close() + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount2, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken2 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as2.URL, service2, strings.Join(scopes2, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount2, 1) + })) + defer ts2.Close() + uri2, err := url.Parse(ts2.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service2 = uri2.Host + client2 := &Client{ + Credential: StaticCredential(uri2.Host, Credential{ + Username: username2, + Password: password2, + }), + Cache: NewCache(), + } + + ctx := context.Background() + ctx = WithScopesForHost(ctx, uri1.Host, scopes1...) + ctx = WithScopesForHost(ctx, uri2.Host, scopes2...) + // first request to server 1 + req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err := client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + + // first request to server 2 + req2, err := http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err := client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // repeated request to server 1 + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1++; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + + // repeated request to server 2 + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2++; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // credential change for server 1 + username1 = "test_user1_new" + password1 = "test_password1_new" + accessToken1 = "test/access/token/1/new" + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client1.Credential = StaticCredential(uri1.Host, Credential{ + Username: username1, + Password: password1, + }) + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + + // credential change for server 2 + username2 = "test_user2_new" + password2 = "test_password2_new" + accessToken2 = "test/access/token/2/new" + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client2.Credential = StaticCredential(uri2.Host, Credential{ + Username: username2, + Password: password2, + }) + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } +} + func TestClient_Do_Bearer_OAuth2_Password(t *testing.T) { username := "test_user" password := "test_password" @@ -1043,18 +1533,19 @@ func TestClient_Do_Bearer_OAuth2_Password_Cached(t *testing.T) { } } -func TestClient_Do_Bearer_OAuth2_RefreshToken(t *testing.T) { - refreshToken := "test/refresh/token" - accessToken := "test/access/token" - var requestCount, wantRequestCount int64 - var successCount, wantSuccessCount int64 - var authCount, wantAuthCount int64 - var service string - scopes := []string{ - "repository:dst:pull,push", +func TestClient_Do_Bearer_OAuth2_Password_Cached_PerHost(t *testing.T) { + // set up server 1 + username1 := "test_user1" + password1 := "test_password1" + accessToken1 := "test/access/token/1" + var requestCount1, wantRequestCount1 int64 + var successCount1, wantSuccessCount1 int64 + var authCount1, wantAuthCount1 int64 + var service1 string + scopes1 := []string{ "repository:src:pull", } - as := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + as1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost || r.URL.Path != "/" { t.Error("unexecuted attempt of authorization service") w.WriteHeader(http.StatusUnauthorized) @@ -1065,13 +1556,13 @@ func TestClient_Do_Bearer_OAuth2_RefreshToken(t *testing.T) { w.WriteHeader(http.StatusUnauthorized) return } - if got := r.PostForm.Get("grant_type"); got != "refresh_token" { - t.Errorf("unexpected grant type: %v, want %v", got, "refresh_token") + if got := r.PostForm.Get("grant_type"); got != "password" { + t.Errorf("unexpected grant type: %v, want %v", got, "password") w.WriteHeader(http.StatusUnauthorized) return } - if got := r.PostForm.Get("service"); got != service { - t.Errorf("unexpected service: %v, want %v", got, service) + if got := r.PostForm.Get("service"); got != service1 { + t.Errorf("unexpected service: %v, want %v", got, service1) w.WriteHeader(http.StatusUnauthorized) return } @@ -1080,108 +1571,298 @@ func TestClient_Do_Bearer_OAuth2_RefreshToken(t *testing.T) { w.WriteHeader(http.StatusUnauthorized) return } - scope := strings.Join(scopes, " ") + scope := strings.Join(scopes1, " ") if got := r.PostForm.Get("scope"); got != scope { t.Errorf("unexpected scope: %v, want %v", got, scope) w.WriteHeader(http.StatusUnauthorized) return } - if got := r.PostForm.Get("refresh_token"); got != refreshToken { - t.Errorf("unexpected refresh token: %v, want %v", got, refreshToken) + if got := r.PostForm.Get("username"); got != username1 { + t.Errorf("unexpected username: %v, want %v", got, username1) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("password"); got != password1 { + t.Errorf("unexpected password: %v, want %v", got, password1) w.WriteHeader(http.StatusUnauthorized) return } - atomic.AddInt64(&authCount, 1) - if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken); err != nil { + atomic.AddInt64(&authCount1, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken1); err != nil { t.Errorf("failed to write %q: %v", r.URL, err) } })) - defer as.Close() - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt64(&requestCount, 1) + defer as1.Close() + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount1, 1) if r.Method != http.MethodGet || r.URL.Path != "/" { t.Errorf("unexpected access: %s %s", r.Method, r.URL) w.WriteHeader(http.StatusNotFound) return } - header := "Bearer " + accessToken + header := "Bearer " + accessToken1 if auth := r.Header.Get("Authorization"); auth != header { - challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as.URL, service, strings.Join(scopes, " ")) + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as1.URL, service1, strings.Join(scopes1, " ")) w.Header().Set("Www-Authenticate", challenge) w.WriteHeader(http.StatusUnauthorized) return } - atomic.AddInt64(&successCount, 1) + atomic.AddInt64(&successCount1, 1) })) - defer ts.Close() - uri, err := url.Parse(ts.URL) + defer ts1.Close() + uri1, err := url.Parse(ts1.URL) if err != nil { t.Fatalf("invalid test http server: %v", err) } - service = uri.Host + service1 = uri1.Host + client1 := &Client{ + Credential: StaticCredential(uri1.Host, Credential{ + Username: username1, + Password: password1, + }), + ForceAttemptOAuth2: true, + Cache: NewCache(), + } + // set up server 2 + username2 := "test_user2" + password2 := "test_password2" + accessToken2 := "test/access/token/2" + var requestCount2, wantRequestCount2 int64 + var successCount2, wantSuccessCount2 int64 + var authCount2, wantAuthCount2 int64 + var service2 string + scopes2 := []string{ + "repository:dst:pull,push", + } + as2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "password" { + t.Errorf("unexpected grant type: %v, want %v", got, "password") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service2 { + t.Errorf("unexpected service: %v, want %v", got, service2) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scope := strings.Join(scopes2, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("username"); got != username2 { + t.Errorf("unexpected username: %v, want %v", got, username2) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("password"); got != password2 { + t.Errorf("unexpected password: %v, want %v", got, password2) + w.WriteHeader(http.StatusUnauthorized) + return + } - client := &Client{ - Credential: func(ctx context.Context, reg string) (Credential, error) { - if reg != uri.Host { - err := fmt.Errorf("registry mismatch: got %v, want %v", reg, uri.Host) - t.Error(err) - return EmptyCredential, err - } - return Credential{ - RefreshToken: refreshToken, - }, nil - }, + atomic.AddInt64(&authCount2, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken2); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as2.Close() + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount2, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken2 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as2.URL, service2, strings.Join(scopes2, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount2, 1) + })) + defer ts2.Close() + uri2, err := url.Parse(ts2.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service2 = uri2.Host + client2 := &Client{ + Credential: StaticCredential(uri2.Host, Credential{ + Username: username2, + Password: password2, + }), + ForceAttemptOAuth2: true, + Cache: NewCache(), } - // first request - req, err := http.NewRequest(http.MethodGet, ts.URL, nil) + ctx := context.Background() + ctx = WithScopesForHost(ctx, uri1.Host, scopes1...) + ctx = WithScopesForHost(ctx, uri2.Host, scopes2...) + // first request to server 1 + req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) if err != nil { t.Fatalf("failed to create test request: %v", err) } - resp, err := client.Do(req) + resp1, err := client1.Do(req1) if err != nil { t.Fatalf("Client.Do() error = %v", err) } - if resp.StatusCode != http.StatusOK { - t.Errorf("Client.Do() = %v, want %v", resp.StatusCode, http.StatusOK) + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) } - if wantRequestCount += 2; requestCount != wantRequestCount { - t.Errorf("unexpected number of requests: %d, want %d", requestCount, wantRequestCount) + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) } - if wantSuccessCount++; successCount != wantSuccessCount { - t.Errorf("unexpected number of successful requests: %d, want %d", successCount, wantSuccessCount) + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) } - if wantAuthCount++; authCount != wantAuthCount { - t.Errorf("unexpected number of auth requests: %d, want %d", authCount, wantAuthCount) + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + // first request to server 2 + req2, err := http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err := client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) } - // credential change - refreshToken = "test/refresh/token/2" - accessToken = "test/access/token/2" - req, err = http.NewRequest(http.MethodGet, ts.URL, nil) + // repeated request to server 1 + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) if err != nil { t.Fatalf("failed to create test request: %v", err) } - resp, err = client.Do(req) + resp1, err = client1.Do(req1) if err != nil { t.Fatalf("Client.Do() error = %v", err) } - if resp.StatusCode != http.StatusOK { - t.Errorf("Client.Do() = %v, want %v", resp.StatusCode, http.StatusOK) + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) } - if wantRequestCount += 2; requestCount != wantRequestCount { - t.Errorf("unexpected number of requests: %d, want %d", requestCount, wantRequestCount) + if wantRequestCount1++; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) } - if wantSuccessCount++; successCount != wantSuccessCount { - t.Errorf("unexpected number of successful requests: %d, want %d", successCount, wantSuccessCount) + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) } - if wantAuthCount++; authCount != wantAuthCount { - t.Errorf("unexpected number of auth requests: %d, want %d", authCount, wantAuthCount) + if authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + // repeated request to server 2 + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2++; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // credential change for server 1 + username1 = "test_user1_new" + password1 = "test_password1_new" + accessToken1 = "test/access/token/1/new" + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client1.Credential = StaticCredential(uri1.Host, Credential{ + Username: username1, + Password: password1, + }) + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + // credential change for server 2 + username2 = "test_user2_new" + password2 = "test_password2_new" + accessToken2 = "test/access/token/2/new" + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client2.Credential = StaticCredential(uri2.Host, Credential{ + Username: username2, + Password: password2, + }) + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) } } -func TestClient_Do_Bearer_OAuth2_RefreshToken_Cached(t *testing.T) { +func TestClient_Do_Bearer_OAuth2_RefreshToken(t *testing.T) { refreshToken := "test/refresh/token" accessToken := "test/access/token" var requestCount, wantRequestCount int64 @@ -1270,12 +1951,10 @@ func TestClient_Do_Bearer_OAuth2_RefreshToken_Cached(t *testing.T) { RefreshToken: refreshToken, }, nil }, - Cache: NewCache(), } // first request - ctx := WithScopes(context.Background(), scopes...) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil) + req, err := http.NewRequest(http.MethodGet, ts.URL, nil) if err != nil { t.Fatalf("failed to create test request: %v", err) } @@ -1296,32 +1975,10 @@ func TestClient_Do_Bearer_OAuth2_RefreshToken_Cached(t *testing.T) { t.Errorf("unexpected number of auth requests: %d, want %d", authCount, wantAuthCount) } - // repeated request - req, err = http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil) - if err != nil { - t.Fatalf("failed to create test request: %v", err) - } - resp, err = client.Do(req) - if err != nil { - t.Fatalf("Client.Do() error = %v", err) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("Client.Do() = %v, want %v", resp.StatusCode, http.StatusOK) - } - if wantRequestCount++; requestCount != wantRequestCount { - t.Errorf("unexpected number of requests: %d, want %d", requestCount, wantRequestCount) - } - if wantSuccessCount++; successCount != wantSuccessCount { - t.Errorf("unexpected number of successful requests: %d, want %d", successCount, wantSuccessCount) - } - if authCount != wantAuthCount { - t.Errorf("unexpected number of auth requests: %d, want %d", authCount, wantAuthCount) - } - // credential change refreshToken = "test/refresh/token/2" accessToken = "test/access/token/2" - req, err = http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil) + req, err = http.NewRequest(http.MethodGet, ts.URL, nil) if err != nil { t.Fatalf("failed to create test request: %v", err) } @@ -1343,7 +2000,7 @@ func TestClient_Do_Bearer_OAuth2_RefreshToken_Cached(t *testing.T) { } } -func TestClient_Do_Token_Expire(t *testing.T) { +func TestClient_Do_Bearer_OAuth2_RefreshToken_Cached(t *testing.T) { refreshToken := "test/refresh/token" accessToken := "test/access/token" var requestCount, wantRequestCount int64 @@ -1458,7 +2115,30 @@ func TestClient_Do_Token_Expire(t *testing.T) { t.Errorf("unexpected number of auth requests: %d, want %d", authCount, wantAuthCount) } - // invalidate the access token and request again + // repeated request + req, err = http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp, err = client.Do(req) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp.StatusCode, http.StatusOK) + } + if wantRequestCount++; requestCount != wantRequestCount { + t.Errorf("unexpected number of requests: %d, want %d", requestCount, wantRequestCount) + } + if wantSuccessCount++; successCount != wantSuccessCount { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount, wantSuccessCount) + } + if authCount != wantAuthCount { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount, wantAuthCount) + } + + // credential change + refreshToken = "test/refresh/token/2" accessToken = "test/access/token/2" req, err = http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil) if err != nil { @@ -1482,20 +2162,18 @@ func TestClient_Do_Token_Expire(t *testing.T) { } } -func TestClient_Do_Scope_Hint_Mismatch(t *testing.T) { - username := "test_user" - password := "test_password" - accessToken := "test/access/token" - var requestCount, wantRequestCount int64 - var successCount, wantSuccessCount int64 - var authCount, wantAuthCount int64 - var service string - scopes := []string{ - "repository:dst:pull,push", +func TestClient_Do_Bearer_OAuth2_RefreshToken_Cached_PerHost(t *testing.T) { + // set up server 1 + refreshToken1 := "test/refresh/token/1" + accessToken1 := "test/access/token/1" + var requestCount1, wantRequestCount1 int64 + var successCount1, wantSuccessCount1 int64 + var authCount1, wantAuthCount1 int64 + var service1 string + scopes1 := []string{ "repository:src:pull", } - scope := "repository:test:delete" - as := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + as1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost || r.URL.Path != "/" { t.Error("unexecuted attempt of authorization service") w.WriteHeader(http.StatusUnauthorized) @@ -1506,13 +2184,13 @@ func TestClient_Do_Scope_Hint_Mismatch(t *testing.T) { w.WriteHeader(http.StatusUnauthorized) return } - if got := r.PostForm.Get("grant_type"); got != "password" { - t.Errorf("unexpected grant type: %v, want %v", got, "password") + if got := r.PostForm.Get("grant_type"); got != "refresh_token" { + t.Errorf("unexpected grant type: %v, want %v", got, "refresh_token") w.WriteHeader(http.StatusUnauthorized) return } - if got := r.PostForm.Get("service"); got != service { - t.Errorf("unexpected service: %v, want %v", got, service) + if got := r.PostForm.Get("service"); got != service1 { + t.Errorf("unexpected service: %v, want %v", got, service1) w.WriteHeader(http.StatusUnauthorized) return } @@ -1521,54 +2199,765 @@ func TestClient_Do_Scope_Hint_Mismatch(t *testing.T) { w.WriteHeader(http.StatusUnauthorized) return } - scopes := CleanScopes(append([]string{scope}, scopes...)) - scope := strings.Join(scopes, " ") + scope := strings.Join(scopes1, " ") if got := r.PostForm.Get("scope"); got != scope { t.Errorf("unexpected scope: %v, want %v", got, scope) w.WriteHeader(http.StatusUnauthorized) return } - if got := r.PostForm.Get("username"); got != username { - t.Errorf("unexpected username: %v, want %v", got, username) - w.WriteHeader(http.StatusUnauthorized) - return - } - if got := r.PostForm.Get("password"); got != password { - t.Errorf("unexpected password: %v, want %v", got, password) + if got := r.PostForm.Get("refresh_token"); got != refreshToken1 { + t.Errorf("unexpected refresh token: %v, want %v", got, refreshToken1) w.WriteHeader(http.StatusUnauthorized) return } - atomic.AddInt64(&authCount, 1) - if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken); err != nil { + atomic.AddInt64(&authCount1, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken1); err != nil { t.Errorf("failed to write %q: %v", r.URL, err) } })) - defer as.Close() - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt64(&requestCount, 1) + defer as1.Close() + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount1, 1) if r.Method != http.MethodGet || r.URL.Path != "/" { t.Errorf("unexpected access: %s %s", r.Method, r.URL) w.WriteHeader(http.StatusNotFound) return } - header := "Bearer " + accessToken + header := "Bearer " + accessToken1 if auth := r.Header.Get("Authorization"); auth != header { - challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as.URL, service, scope) + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as1.URL, service1, strings.Join(scopes1, " ")) w.Header().Set("Www-Authenticate", challenge) w.WriteHeader(http.StatusUnauthorized) return } - atomic.AddInt64(&successCount, 1) + atomic.AddInt64(&successCount1, 1) })) - defer ts.Close() - uri, err := url.Parse(ts.URL) + defer ts1.Close() + uri1, err := url.Parse(ts1.URL) if err != nil { t.Fatalf("invalid test http server: %v", err) } - service = uri.Host - - client := &Client{ + service1 = uri1.Host + client1 := &Client{ + Credential: StaticCredential(uri1.Host, Credential{ + RefreshToken: refreshToken1, + }), + Cache: NewCache(), + } + + // set up server 2 + refreshToken2 := "test/refresh/token/1" + accessToken2 := "test/access/token/1" + var requestCount2, wantRequestCount2 int64 + var successCount2, wantSuccessCount2 int64 + var authCount2, wantAuthCount2 int64 + var service2 string + scopes2 := []string{ + "repository:dst:pull,push", + } + as2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "refresh_token" { + t.Errorf("unexpected grant type: %v, want %v", got, "refresh_token") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service2 { + t.Errorf("unexpected service: %v, want %v", got, service2) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scope := strings.Join(scopes2, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("refresh_token"); got != refreshToken2 { + t.Errorf("unexpected refresh token: %v, want %v", got, refreshToken2) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount2, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken2); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as2.Close() + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount2, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken2 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as2.URL, service2, strings.Join(scopes2, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount2, 1) + })) + defer ts2.Close() + uri2, err := url.Parse(ts2.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service2 = uri2.Host + client2 := &Client{ + Credential: StaticCredential(uri2.Host, Credential{ + RefreshToken: refreshToken2, + }), + Cache: NewCache(), + } + + ctx := context.Background() + ctx = WithScopesForHost(ctx, uri1.Host, scopes1...) + ctx = WithScopesForHost(ctx, uri2.Host, scopes2...) + // first request to server 1 + req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err := client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + + // first request to server 2 + req2, err := http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err := client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // repeated request to server 1 + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1++; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + // repeated request to server 2 + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2++; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // credential change to server 1 + refreshToken1 = "test/refresh/token/1/new" + accessToken1 = "test/access/token/1/new" + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client1.Credential = StaticCredential(uri1.Host, Credential{ + RefreshToken: refreshToken1, + }) + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + // credential change to server 2 + refreshToken2 = "test/refresh/token/2/new" + accessToken2 = "test/access/token/2/new" + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + client2.Credential = StaticCredential(uri2.Host, Credential{ + RefreshToken: refreshToken2, + }) + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } +} + +func TestClient_Do_Token_Expire(t *testing.T) { + refreshToken := "test/refresh/token" + accessToken := "test/access/token" + var requestCount, wantRequestCount int64 + var successCount, wantSuccessCount int64 + var authCount, wantAuthCount int64 + var service string + scopes := []string{ + "repository:dst:pull,push", + "repository:src:pull", + } + as := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "refresh_token" { + t.Errorf("unexpected grant type: %v, want %v", got, "refresh_token") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service { + t.Errorf("unexpected service: %v, want %v", got, service) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scope := strings.Join(scopes, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("refresh_token"); got != refreshToken { + t.Errorf("unexpected refresh token: %v, want %v", got, refreshToken) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as.Close() + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as.URL, service, strings.Join(scopes, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount, 1) + })) + defer ts.Close() + uri, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service = uri.Host + + client := &Client{ + Credential: func(ctx context.Context, reg string) (Credential, error) { + if reg != uri.Host { + err := fmt.Errorf("registry mismatch: got %v, want %v", reg, uri.Host) + t.Error(err) + return EmptyCredential, err + } + return Credential{ + RefreshToken: refreshToken, + }, nil + }, + Cache: NewCache(), + } + + // first request + ctx := WithScopes(context.Background(), scopes...) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp.StatusCode, http.StatusOK) + } + if wantRequestCount += 2; requestCount != wantRequestCount { + t.Errorf("unexpected number of requests: %d, want %d", requestCount, wantRequestCount) + } + if wantSuccessCount++; successCount != wantSuccessCount { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount, wantSuccessCount) + } + if wantAuthCount++; authCount != wantAuthCount { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount, wantAuthCount) + } + + // invalidate the access token and request again + accessToken = "test/access/token/2" + req, err = http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp, err = client.Do(req) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp.StatusCode, http.StatusOK) + } + if wantRequestCount += 2; requestCount != wantRequestCount { + t.Errorf("unexpected number of requests: %d, want %d", requestCount, wantRequestCount) + } + if wantSuccessCount++; successCount != wantSuccessCount { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount, wantSuccessCount) + } + if wantAuthCount++; authCount != wantAuthCount { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount, wantAuthCount) + } +} + +func TestClient_Do_Token_Expire_PerHost(t *testing.T) { + // set up server 1 + refreshToken1 := "test/refresh/token/1" + accessToken1 := "test/access/token/1" + var requestCount1, wantRequestCount1 int64 + var successCount1, wantSuccessCount1 int64 + var authCount1, wantAuthCount1 int64 + var service1 string + scopes1 := []string{ + "repository:src:pull", + } + as1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "refresh_token" { + t.Errorf("unexpected grant type: %v, want %v", got, "refresh_token") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service1 { + t.Errorf("unexpected service: %v, want %v", got, service1) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scope := strings.Join(scopes1, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("refresh_token"); got != refreshToken1 { + t.Errorf("unexpected refresh token: %v, want %v", got, refreshToken1) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount1, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken1); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as1.Close() + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount1, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken1 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as1.URL, service1, strings.Join(scopes1, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount1, 1) + })) + defer ts1.Close() + uri1, err := url.Parse(ts1.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service1 = uri1.Host + client1 := &Client{ + Credential: StaticCredential(uri1.Host, Credential{ + RefreshToken: refreshToken1, + }), + Cache: NewCache(), + } + // set up server 2 + refreshToken2 := "test/refresh/token/2" + accessToken2 := "test/access/token/2" + var requestCount2, wantRequestCount2 int64 + var successCount2, wantSuccessCount2 int64 + var authCount2, wantAuthCount2 int64 + var service2 string + scopes2 := []string{ + "repository:dst:pull,push", + } + as2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "refresh_token" { + t.Errorf("unexpected grant type: %v, want %v", got, "refresh_token") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service2 { + t.Errorf("unexpected service: %v, want %v", got, service2) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scope := strings.Join(scopes2, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("refresh_token"); got != refreshToken2 { + t.Errorf("unexpected refresh token: %v, want %v", got, refreshToken2) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount2, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken2); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as2.Close() + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount2, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken2 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as2.URL, service2, strings.Join(scopes2, " ")) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount2, 1) + })) + defer ts2.Close() + uri2, err := url.Parse(ts2.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service2 = uri2.Host + client2 := &Client{ + Credential: StaticCredential(uri2.Host, Credential{ + RefreshToken: refreshToken2, + }), + Cache: NewCache(), + } + + ctx := context.Background() + ctx = WithScopesForHost(ctx, uri1.Host, scopes1...) + ctx = WithScopesForHost(ctx, uri2.Host, scopes2...) + // first request to server 1 + req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err := client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + + // first request to server 2 + req2, err := http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err := client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // invalidate the access token and request again to server 1 + accessToken1 = "test/access/token/1/new" + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + // invalidate the access token and request again to server 2 + accessToken2 = "test/access/token/2/new" + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } +} + +func TestClient_Do_Scope_Hint_Mismatch(t *testing.T) { + username := "test_user" + password := "test_password" + accessToken := "test/access/token" + var requestCount, wantRequestCount int64 + var successCount, wantSuccessCount int64 + var authCount, wantAuthCount int64 + var service string + scopes := []string{ + "repository:dst:pull,push", + "repository:src:pull", + } + scope := "repository:test:delete" + as := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "password" { + t.Errorf("unexpected grant type: %v, want %v", got, "password") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service { + t.Errorf("unexpected service: %v, want %v", got, service) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scopes := CleanScopes(append([]string{scope}, scopes...)) + scope := strings.Join(scopes, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("username"); got != username { + t.Errorf("unexpected username: %v, want %v", got, username) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("password"); got != password { + t.Errorf("unexpected password: %v, want %v", got, password) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as.Close() + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as.URL, service, scope) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount, 1) + })) + defer ts.Close() + uri, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service = uri.Host + + client := &Client{ Credential: func(ctx context.Context, reg string) (Credential, error) { if reg != uri.Host { err := fmt.Errorf("registry mismatch: got %v, want %v", reg, uri.Host) @@ -1633,6 +3022,293 @@ func TestClient_Do_Scope_Hint_Mismatch(t *testing.T) { } } +func TestClient_Do_Scope_Hint_Mismatch_PerHost(t *testing.T) { + // set up server 1 + username1 := "test_user1" + password1 := "test_password1" + accessToken1 := "test/access/token/1" + var requestCount1, wantRequestCount1 int64 + var successCount1, wantSuccessCount1 int64 + var authCount1, wantAuthCount1 int64 + var service1 string + scopes1 := []string{ + "repository:dst:pull,push", + "repository:src:pull", + } + scope1 := "repository:test1:delete" + as1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "password" { + t.Errorf("unexpected grant type: %v, want %v", got, "password") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service1 { + t.Errorf("unexpected service: %v, want %v", got, service1) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scopes := CleanScopes(append([]string{scope1}, scopes1...)) + scope := strings.Join(scopes, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("username"); got != username1 { + t.Errorf("unexpected username: %v, want %v", got, username1) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("password"); got != password1 { + t.Errorf("unexpected password: %v, want %v", got, password1) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount1, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken1); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as1.Close() + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount1, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken1 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as1.URL, service1, scope1) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount1, 1) + })) + defer ts1.Close() + uri1, err := url.Parse(ts1.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service1 = uri1.Host + client1 := &Client{ + Credential: StaticCredential(uri1.Host, Credential{ + Username: username1, + Password: password1, + }), + ForceAttemptOAuth2: true, + Cache: NewCache(), + } + + // set up server 1 + username2 := "test_user2" + password2 := "test_password2" + accessToken2 := "test/access/token/2" + var requestCount2, wantRequestCount2 int64 + var successCount2, wantSuccessCount2 int64 + var authCount2, wantAuthCount2 int64 + var service2 string + scopes2 := []string{ + "repository:dst:pull,push", + "repository:src:pull", + } + scope2 := "repository:test2:delete" + as2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/" { + t.Error("unexecuted attempt of authorization service") + w.WriteHeader(http.StatusUnauthorized) + return + } + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("grant_type"); got != "password" { + t.Errorf("unexpected grant type: %v, want %v", got, "password") + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("service"); got != service2 { + t.Errorf("unexpected service: %v, want %v", got, service2) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("client_id"); got != defaultClientID { + t.Errorf("unexpected client id: %v, want %v", got, defaultClientID) + w.WriteHeader(http.StatusUnauthorized) + return + } + scopes := CleanScopes(append([]string{scope2}, scopes2...)) + scope := strings.Join(scopes, " ") + if got := r.PostForm.Get("scope"); got != scope { + t.Errorf("unexpected scope: %v, want %v", got, scope) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("username"); got != username2 { + t.Errorf("unexpected username: %v, want %v", got, username2) + w.WriteHeader(http.StatusUnauthorized) + return + } + if got := r.PostForm.Get("password"); got != password2 { + t.Errorf("unexpected password: %v, want %v", got, password2) + w.WriteHeader(http.StatusUnauthorized) + return + } + + atomic.AddInt64(&authCount2, 1) + if _, err := fmt.Fprintf(w, `{"access_token":%q}`, accessToken2); err != nil { + t.Errorf("failed to write %q: %v", r.URL, err) + } + })) + defer as2.Close() + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount2, 1) + if r.Method != http.MethodGet || r.URL.Path != "/" { + t.Errorf("unexpected access: %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusNotFound) + return + } + header := "Bearer " + accessToken2 + if auth := r.Header.Get("Authorization"); auth != header { + challenge := fmt.Sprintf("Bearer realm=%q,service=%q,scope=%q", as2.URL, service2, scope2) + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return + } + atomic.AddInt64(&successCount2, 1) + })) + defer ts1.Close() + uri2, err := url.Parse(ts2.URL) + if err != nil { + t.Fatalf("invalid test http server: %v", err) + } + service2 = uri2.Host + client2 := &Client{ + Credential: StaticCredential(uri2.Host, Credential{ + Username: username2, + Password: password2, + }), + ForceAttemptOAuth2: true, + Cache: NewCache(), + } + + ctx := context.Background() + ctx = WithScopesForHost(ctx, uri1.Host, scopes1...) + ctx = WithScopesForHost(ctx, uri2.Host, scopes2...) + // first request to server 1 + req1, err := http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err := client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if wantAuthCount1++; authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + + // first request to server 1 + req2, err := http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err := client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if wantAuthCount2++; authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } + + // repeated request to server 1 + // although the actual scope does not match the hinted scopes, the client + // with cache cannot avoid a request to obtain a challenge but can prevent + // a repeated call to the authorization server. + req1, err = http.NewRequestWithContext(ctx, http.MethodGet, ts1.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp1, err = client1.Do(req1) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp1.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp1.StatusCode, http.StatusOK) + } + if wantRequestCount1 += 2; requestCount1 != wantRequestCount1 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount1, wantRequestCount1) + } + if wantSuccessCount1++; successCount1 != wantSuccessCount1 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount1, wantSuccessCount1) + } + if authCount1 != wantAuthCount1 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount1, wantAuthCount1) + } + + // repeated request to server 2 + // although the actual scope does not match the hinted scopes, the client + // with cache cannot avoid a request to obtain a challenge but can prevent + // a repeated call to the authorization server. + req2, err = http.NewRequestWithContext(ctx, http.MethodGet, ts2.URL, nil) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + resp2, err = client2.Do(req2) + if err != nil { + t.Fatalf("Client.Do() error = %v", err) + } + if resp2.StatusCode != http.StatusOK { + t.Errorf("Client.Do() = %v, want %v", resp2.StatusCode, http.StatusOK) + } + if wantRequestCount2 += 2; requestCount2 != wantRequestCount2 { + t.Errorf("unexpected number of requests: %d, want %d", requestCount2, wantRequestCount2) + } + if wantSuccessCount2++; successCount2 != wantSuccessCount2 { + t.Errorf("unexpected number of successful requests: %d, want %d", successCount2, wantSuccessCount2) + } + if authCount2 != wantAuthCount2 { + t.Errorf("unexpected number of auth requests: %d, want %d", authCount2, wantAuthCount2) + } +} + func TestClient_Do_Invalid_Credential_Basic(t *testing.T) { username := "test_user" password := "test_password" diff --git a/registry/remote/auth/scope.go b/registry/remote/auth/scope.go index 24a0f898..fabc2af2 100644 --- a/registry/remote/auth/scope.go +++ b/registry/remote/auth/scope.go @@ -19,6 +19,9 @@ import ( "context" "sort" "strings" + + "oras.land/oras-go/v2/internal/slices" + "oras.land/oras-go/v2/registry" ) // Actions used in scopes. @@ -54,6 +57,28 @@ func ScopeRepository(repository string, actions ...string) string { }, ":") } +// AppendRepositoryScope returns a new context containing scope hints for the +// auth client to fetch bearer tokens with the given actions on the repository. +// If called multiple times, the new scopes will be appended to the existing +// scopes. The resulted scopes are de-duplicated. +// +// For example, uploading blob to the repository "hello-world" does HEAD request +// first then POST and PUT. The HEAD request will return a challenge for scope +// `repository:hello-world:pull`, and the auth client will fetch a token for +// that challenge. Later, the POST request will return a challenge for scope +// `repository:hello-world:push`, and the auth client will fetch a token for +// that challenge again. By invoking AppendRepositoryScope with the actions +// [ActionPull] and [ActionPush] for the repository `hello-world`, +// the auth client with cache is hinted to fetch a token via a single token +// fetch request for all the HEAD, POST, PUT requests. +func AppendRepositoryScope(ctx context.Context, ref registry.Reference, actions ...string) context.Context { + if len(actions) == 0 { + return ctx + } + scope := ScopeRepository(ref.Repository, actions...) + return AppendScopesForHost(ctx, ref.Host(), scope) +} + // scopesContextKey is the context key for scopes. type scopesContextKey struct{} @@ -66,7 +91,7 @@ type scopesContextKey struct{} // `repository:hello-world:pull`, and the auth client will fetch a token for // that challenge. Later, the POST request will return a challenge for scope // `repository:hello-world:push`, and the auth client will fetch a token for -// that challenge again. By invoking `WithScopes()` with the scope +// that challenge again. By invoking WithScopes with the scope // `repository:hello-world:pull,push`, the auth client with cache is hinted to // fetch a token via a single token fetch request for all the HEAD, POST, PUT // requests. @@ -93,11 +118,76 @@ func AppendScopes(ctx context.Context, scopes ...string) context.Context { // GetScopes returns the scopes in the context. func GetScopes(ctx context.Context) []string { if scopes, ok := ctx.Value(scopesContextKey{}).([]string); ok { - return append([]string(nil), scopes...) + return slices.Clone(scopes) + } + return nil +} + +// scopesForHostContextKey is the context key for per-host scopes. +type scopesForHostContextKey string + +// WithScopesForHost returns a context with per-host scopes added. +// Scopes are de-duplicated. +// Scopes are used as hints for the auth client to fetch bearer tokens with +// larger scopes. +// +// For example, uploading blob to the repository "hello-world" does HEAD request +// first then POST and PUT. The HEAD request will return a challenge for scope +// `repository:hello-world:pull`, and the auth client will fetch a token for +// that challenge. Later, the POST request will return a challenge for scope +// `repository:hello-world:push`, and the auth client will fetch a token for +// that challenge again. By invoking WithScopesForHost with the scope +// `repository:hello-world:pull,push`, the auth client with cache is hinted to +// fetch a token via a single token fetch request for all the HEAD, POST, PUT +// requests. +// +// Passing an empty list of scopes will virtually remove the scope hints in the +// context for the given host. +// +// Reference: https://docs.docker.com/registry/spec/auth/scope/ +func WithScopesForHost(ctx context.Context, host string, scopes ...string) context.Context { + scopes = CleanScopes(scopes) + return context.WithValue(ctx, scopesForHostContextKey(host), scopes) +} + +// AppendScopesForHost appends additional scopes to the existing scopes +// in the context for the given host and returns a new context. +// The resulted scopes are de-duplicated. +// The append operation does modify the existing scope in the context passed in. +func AppendScopesForHost(ctx context.Context, host string, scopes ...string) context.Context { + if len(scopes) == 0 { + return ctx + } + oldScopes := GetScopesForHost(ctx, host) + return WithScopesForHost(ctx, host, append(oldScopes, scopes...)...) +} + +// GetScopesForHost returns the scopes in the context for the given host, +// excluding global scopes added by [WithScopes] and [AppendScopes]. +func GetScopesForHost(ctx context.Context, host string) []string { + if scopes, ok := ctx.Value(scopesForHostContextKey(host)).([]string); ok { + return slices.Clone(scopes) } return nil } +// GetAllScopesForHost returns the scopes in the context for the given host, +// including global scopes added by [WithScopes] and [AppendScopes]. +func GetAllScopesForHost(ctx context.Context, host string) []string { + scopes := GetScopesForHost(ctx, host) + globalScopes := GetScopes(ctx) + + if len(scopes) == 0 { + return globalScopes + } + if len(globalScopes) == 0 { + return scopes + } + // re-clean the scopes + allScopes := append(scopes, globalScopes...) + return CleanScopes(allScopes) +} + // CleanScopes merges and sort the actions in ascending order if the scopes have // the same resource type and name. The final scopes are sorted in ascending // order. In other words, the scopes passed in are de-duplicated and sorted. diff --git a/registry/remote/auth/scope_test.go b/registry/remote/auth/scope_test.go index ac41ad7b..ca9fe339 100644 --- a/registry/remote/auth/scope_test.go +++ b/registry/remote/auth/scope_test.go @@ -19,6 +19,8 @@ import ( "context" "reflect" "testing" + + "oras.land/oras-go/v2/registry" ) func TestScopeRepository(t *testing.T) { @@ -103,6 +105,70 @@ func TestScopeRepository(t *testing.T) { } } +func TestWithScopeHints(t *testing.T) { + ctx := context.Background() + ref1, err := registry.ParseReference("registry.example.com/foo") + if err != nil { + t.Fatal("registry.ParseReference() error =", err) + } + ref2, err := registry.ParseReference("docker.io/foo") + if err != nil { + t.Fatal("registry.ParseReference() error =", err) + } + + // with single scope + want1 := []string{ + "repository:foo:pull", + } + want2 := []string{ + "repository:foo:push", + } + ctx = AppendRepositoryScope(ctx, ref1, ActionPull) + ctx = AppendRepositoryScope(ctx, ref2, ActionPush) + if got := GetScopesForHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want1) + } + if got := GetScopesForHost(ctx, ref2.Host()); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want2) + } + + // with duplicated scopes + scopes1 := []string{ + ActionDelete, + ActionDelete, + ActionPull, + } + want1 = []string{ + "repository:foo:delete,pull", + } + scopes2 := []string{ + ActionPush, + ActionPush, + ActionDelete, + } + want2 = []string{ + "repository:foo:delete,push", + } + ctx = AppendRepositoryScope(ctx, ref1, scopes1...) + ctx = AppendRepositoryScope(ctx, ref2, scopes2...) + if got := GetScopesForHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want1) + } + if got := GetScopesForHost(ctx, ref2.Host()); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want2) + } + + // append empty scopes + ctx = AppendRepositoryScope(ctx, ref1) + ctx = AppendRepositoryScope(ctx, ref2) + if got := GetScopesForHost(ctx, ref1.Host()); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want1) + } + if got := GetScopesForHost(ctx, ref2.Host()); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(WithScopeHints()) = %v, want %v", got, want2) + } +} + func TestWithScopes(t *testing.T) { ctx := context.Background() @@ -184,6 +250,149 @@ func TestAppendScopes(t *testing.T) { } } +func TestWithScopesPerHost(t *testing.T) { + ctx := context.Background() + reg1 := "registry1.example.com" + reg2 := "registry2.example.com" + + // with single scope + want1 := []string{ + "repository:foo:pull", + } + want2 := []string{ + "repository:foo:push", + } + ctx = WithScopesForHost(ctx, reg1, want1...) + ctx = WithScopesForHost(ctx, reg2, want2...) + if got := GetScopesForHost(ctx, reg1); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want1) + } + if got := GetScopesForHost(ctx, reg2); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want2) + } + + // overwrite scopes + want1 = []string{ + "repository:bar:push", + } + want2 = []string{ + "repository:bar:pull", + } + ctx = WithScopesForHost(ctx, reg1, want1...) + ctx = WithScopesForHost(ctx, reg2, want2...) + if got := GetScopesForHost(ctx, reg1); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want1) + } + if got := GetScopesForHost(ctx, reg2); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want2) + } + + // overwrite scopes with de-duplication + scopes1 := []string{ + "repository:hello-world:push", + "repository:alpine:delete", + "repository:hello-world:pull", + "repository:alpine:delete", + } + want1 = []string{ + "repository:alpine:delete", + "repository:hello-world:pull,push", + } + scopes2 := []string{ + "repository:goodbye-world:push", + "repository:nginx:delete", + "repository:goodbye-world:pull", + "repository:nginx:delete", + } + want2 = []string{ + "repository:goodbye-world:pull,push", + "repository:nginx:delete", + } + ctx = WithScopesForHost(ctx, reg1, scopes1...) + ctx = WithScopesForHost(ctx, reg2, scopes2...) + if got := GetScopesForHost(ctx, reg1); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want1) + } + if got := GetScopesForHost(ctx, reg2); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want2) + } + + // clean scopes + var want []string + ctx = WithScopesForHost(ctx, reg1, want...) + ctx = WithScopesForHost(ctx, reg2, want...) + if got := GetScopesForHost(ctx, reg1); !reflect.DeepEqual(got, want) { + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want) + } + if got := GetScopesForHost(ctx, reg2); !reflect.DeepEqual(got, want) { + t.Errorf("GetScopesPerRegistry(WithScopesPerRegistry()) = %v, want %v", got, want) + } +} + +func TestAppendScopesPerHost(t *testing.T) { + ctx := context.Background() + reg1 := "registry1.example.com" + reg2 := "registry2.example.com" + + // with single scope + want1 := []string{ + "repository:foo:pull", + } + want2 := []string{ + "repository:foo:push", + } + ctx = AppendScopesForHost(ctx, reg1, want1...) + ctx = AppendScopesForHost(ctx, reg2, want2...) + if got := GetScopesForHost(ctx, reg1); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want1) + } + if got := GetScopesForHost(ctx, reg2); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want2) + } + + // append scopes with de-duplication + scopes1 := []string{ + "repository:hello-world:push", + "repository:alpine:delete", + "repository:hello-world:pull", + "repository:alpine:delete", + } + want1 = []string{ + "repository:alpine:delete", + "repository:foo:pull", + "repository:hello-world:pull,push", + } + scopes2 := []string{ + "repository:goodbye-world:push", + "repository:nginx:delete", + "repository:goodbye-world:pull", + "repository:nginx:delete", + } + want2 = []string{ + "repository:foo:push", + "repository:goodbye-world:pull,push", + "repository:nginx:delete", + } + ctx = AppendScopesForHost(ctx, reg1, scopes1...) + ctx = AppendScopesForHost(ctx, reg2, scopes2...) + if got := GetScopesForHost(ctx, reg1); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want1) + } + if got := GetScopesForHost(ctx, reg2); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want2) + } + + // append empty scopes + ctx = AppendScopesForHost(ctx, reg1) + ctx = AppendScopesForHost(ctx, reg2) + if got := GetScopesForHost(ctx, reg1); !reflect.DeepEqual(got, want1) { + t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want1) + } + if got := GetScopesForHost(ctx, reg2); !reflect.DeepEqual(got, want2) { + t.Errorf("GetScopesPerRegistry(AppendScopesPerRegistry()) = %v, want %v", got, want2) + } +} + func TestCleanScopes(t *testing.T) { tests := []struct { name string @@ -449,3 +658,71 @@ func Test_cleanActions(t *testing.T) { }) } } + +func Test_getAllScopesForHost(t *testing.T) { + host := "registry.example.com" + tests := []struct { + name string + scopes []string + globalScopes []string + want []string + }{ + { + name: "Empty per-host scopes", + scopes: []string{}, + globalScopes: []string{ + "repository:hello-world:push", + "repository:alpine:delete", + "repository:hello-world:pull", + "repository:alpine:delete", + }, + want: []string{ + "repository:alpine:delete", + "repository:hello-world:pull,push", + }, + }, + { + name: "Empty global scopes", + scopes: []string{ + "repository:hello-world:push", + "repository:alpine:delete", + "repository:hello-world:pull", + "repository:alpine:delete", + }, + globalScopes: []string{}, + want: []string{ + "repository:alpine:delete", + "repository:hello-world:pull,push", + }, + }, + { + name: "Per-host scopes + global scopes", + scopes: []string{ + "repository:hello-world:push", + "repository:alpine:delete", + "repository:hello-world:pull", + "repository:alpine:delete", + }, + globalScopes: []string{ + "repository:foo:pull", + "repository:hello-world:pull", + "repository:alpine:pull", + }, + want: []string{ + "repository:alpine:delete,pull", + "repository:foo:pull", + "repository:hello-world:pull,push", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + ctx = WithScopesForHost(ctx, host, tt.scopes...) + ctx = WithScopes(ctx, tt.globalScopes...) + if got := GetAllScopesForHost(ctx, host); !reflect.DeepEqual(got, tt.want) { + t.Errorf("getAllScopesForHost() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/registry/remote/registry.go b/registry/remote/registry.go index 8ae538d9..d1334042 100644 --- a/registry/remote/registry.go +++ b/registry/remote/registry.go @@ -127,7 +127,7 @@ func (r *Registry) Ping(ctx context.Context) error { // // Reference: https://docs.docker.com/registry/spec/api/#catalog func (r *Registry) Repositories(ctx context.Context, last string, fn func(repos []string) error) error { - ctx = auth.AppendScopes(ctx, auth.ScopeRegistryCatalog) + ctx = auth.AppendScopesForHost(ctx, r.Reference.Host(), auth.ScopeRegistryCatalog) url := buildRegistryCatalogURL(r.PlainHTTP, r.Reference) var err error for err == nil { diff --git a/registry/remote/repository.go b/registry/remote/repository.go index fc4f6bf6..5373492b 100644 --- a/registry/remote/repository.go +++ b/registry/remote/repository.go @@ -37,7 +37,6 @@ import ( "oras.land/oras-go/v2/internal/cas" "oras.land/oras-go/v2/internal/httputil" "oras.land/oras-go/v2/internal/ioutil" - "oras.land/oras-go/v2/internal/registryutil" "oras.land/oras-go/v2/internal/slices" "oras.land/oras-go/v2/internal/spec" "oras.land/oras-go/v2/internal/syncutil" @@ -392,7 +391,7 @@ func (r *Repository) ParseReference(reference string) (registry.Reference, error // - https://github.com/opencontainers/distribution-spec/blob/v1.1.0-rc3/spec.md#content-discovery // - https://docs.docker.com/registry/spec/api/#tags func (r *Repository) Tags(ctx context.Context, last string, fn func(tags []string) error) error { - ctx = registryutil.WithScopeHint(ctx, r.Reference, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, r.Reference, auth.ActionPull) url := buildRepositoryTagListURL(r.PlainHTTP, r.Reference) var err error for err == nil { @@ -509,7 +508,7 @@ func (r *Repository) Referrers(ctx context.Context, desc ocispec.Descriptor, art func (r *Repository) referrersByAPI(ctx context.Context, desc ocispec.Descriptor, artifactType string, fn func(referrers []ocispec.Descriptor) error) error { ref := r.Reference ref.Reference = desc.Digest.String() - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildReferrersURL(r.PlainHTTP, ref, artifactType) var err error @@ -643,7 +642,7 @@ func (r *Repository) pingReferrers(ctx context.Context) (bool, error) { ref := r.Reference ref.Reference = zeroDigest - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildReferrersURL(r.PlainHTTP, ref, "") req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) @@ -677,7 +676,7 @@ func (r *Repository) pingReferrers(ctx context.Context) (bool, error) { func (r *Repository) delete(ctx context.Context, target ocispec.Descriptor, isManifest bool) error { ref := r.Reference ref.Reference = target.Digest.String() - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionDelete) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionDelete) buildURL := buildRepositoryBlobURL if isManifest { buildURL = buildRepositoryManifestURL @@ -713,7 +712,7 @@ type blobStore struct { func (s *blobStore) Fetch(ctx context.Context, target ocispec.Descriptor) (rc io.ReadCloser, err error) { ref := s.repo.Reference ref.Reference = target.Digest.String() - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildRepositoryBlobURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -755,12 +754,12 @@ func (s *blobStore) Fetch(ctx context.Context, target ocispec.Descriptor) (rc io func (s *blobStore) Mount(ctx context.Context, desc ocispec.Descriptor, fromRepo string, getContent func() (io.ReadCloser, error)) error { // pushing usually requires both pull and push actions. // Reference: https://github.com/distribution/distribution/blob/v2.7.1/registry/handlers/app.go#L921-L930 - ctx = registryutil.WithScopeHint(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendRepositoryScope(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) // We also need pull access to the source repo. fromRef := s.repo.Reference fromRef.Repository = fromRepo - ctx = registryutil.WithScopeHint(ctx, fromRef, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, fromRef, auth.ActionPull) url := buildRepositoryBlobMountURL(s.repo.PlainHTTP, s.repo.Reference, desc.Digest, fromRepo) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) @@ -833,7 +832,7 @@ func (s *blobStore) Push(ctx context.Context, expected ocispec.Descriptor, conte // start an upload // pushing usually requires both pull and push actions. // Reference: https://github.com/distribution/distribution/blob/v2.7.1/registry/handlers/app.go#L921-L930 - ctx = registryutil.WithScopeHint(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendRepositoryScope(ctx, s.repo.Reference, auth.ActionPull, auth.ActionPush) url := buildRepositoryBlobUploadURL(s.repo.PlainHTTP, s.repo.Reference) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) if err != nil { @@ -934,7 +933,7 @@ func (s *blobStore) Resolve(ctx context.Context, reference string) (ocispec.Desc if err != nil { return ocispec.Descriptor{}, err } - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildRepositoryBlobURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) if err != nil { @@ -969,7 +968,7 @@ func (s *blobStore) FetchReference(ctx context.Context, reference string) (desc return ocispec.Descriptor{}, nil, err } - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildRepositoryBlobURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -1044,7 +1043,7 @@ type manifestStore struct { func (s *manifestStore) Fetch(ctx context.Context, target ocispec.Descriptor) (rc io.ReadCloser, err error) { ref := s.repo.Reference ref.Reference = target.Digest.String() - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -1170,7 +1169,7 @@ func (s *manifestStore) Resolve(ctx context.Context, reference string) (ocispec. if err != nil { return ocispec.Descriptor{}, err } - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) if err != nil { @@ -1202,7 +1201,7 @@ func (s *manifestStore) FetchReference(ctx context.Context, reference string) (d return ocispec.Descriptor{}, nil, err } - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -1245,7 +1244,7 @@ func (s *manifestStore) Tag(ctx context.Context, desc ocispec.Descriptor, refere return err } - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull, auth.ActionPush) rc, err := s.Fetch(ctx, desc) if err != nil { return err @@ -1270,7 +1269,7 @@ func (s *manifestStore) push(ctx context.Context, expected ocispec.Descriptor, c ref.Reference = reference // pushing usually requires both pull and push actions. // Reference: https://github.com/distribution/distribution/blob/v2.7.1/registry/handlers/app.go#L921-L930 - ctx = registryutil.WithScopeHint(ctx, ref, auth.ActionPull, auth.ActionPush) + ctx = auth.AppendRepositoryScope(ctx, ref, auth.ActionPull, auth.ActionPush) url := buildRepositoryManifestURL(s.repo.PlainHTTP, ref) // unwrap the content for optimizations of built-in types. body := ioutil.UnwrapNopCloser(content)