From 79c69d57e99725269d00961f18c3fc6311e6f6ab Mon Sep 17 00:00:00 2001 From: Pawel Kopiczko Date: Wed, 20 Nov 2024 21:58:51 +0000 Subject: [PATCH] Fix broken auth Access Request creation tests (#49258) This got exposed while working on Access Request reason required PR: https://github.com/gravitational/teleport/pull/49124 --- lib/auth/auth_with_roles_test.go | 54 ++++++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 13 deletions(-) diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index 538e4655fce7a..ab4340bc31493 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -8053,7 +8053,7 @@ func TestCreateAccessRequest(t *testing.T) { clock := srv.Clock() alice, bob, admin := createSessionTestUsers(t, srv.Auth()) - searchRole, err := types.NewRole("requestRole", types.RoleSpecV6{ + searchRole, err := types.NewRole("searchRole", types.RoleSpecV6{ Allow: types.RoleConditions{ Request: &types.AccessRequestConditions{ Roles: []string{"requestRole"}, @@ -8063,11 +8063,32 @@ func TestCreateAccessRequest(t *testing.T) { }) require.NoError(t, err) - requestRole, err := types.NewRole("requestRole", types.RoleSpecV6{}) + requestRole, err := types.NewRole("requestRole", types.RoleSpecV6{ + Allow: types.RoleConditions{ + GroupLabels: types.Labels{ + types.Wildcard: []string{types.Wildcard}, + }, + NodeLabels: types.Labels{ + types.Wildcard: []string{types.Wildcard}, + }, + }, + }) require.NoError(t, err) - srv.Auth().CreateRole(ctx, searchRole) - srv.Auth().CreateRole(ctx, requestRole) + nodeAllowedByRequestRole, err := types.NewServerWithLabels( + "test-node", + types.KindNode, + types.ServerSpecV2{}, + map[string]string{"any-key": "any-val"}, + ) + require.NoError(t, err) + + _, err = srv.Auth().UpsertNode(ctx, nodeAllowedByRequestRole) + require.NoError(t, err) + _, err = srv.Auth().CreateRole(ctx, requestRole) + require.NoError(t, err) + _, err = srv.Auth().CreateRole(ctx, searchRole) + require.NoError(t, err) user, err := srv.Auth().GetUser(ctx, alice, true) require.NoError(t, err) @@ -8110,12 +8131,12 @@ func TestCreateAccessRequest(t *testing.T) { user: alice, accessRequest: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), []string{requestRole.GetName()}, []types.ResourceID{ - mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + mustResourceID(srv.ClusterName(), nodeAllowedByRequestRole.GetKind(), nodeAllowedByRequestRole.GetName()), }), errAssertionFunc: require.NoError, expected: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), []string{requestRole.GetName()}, []types.ResourceID{ - mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + mustResourceID(srv.ClusterName(), nodeAllowedByRequestRole.GetKind(), nodeAllowedByRequestRole.GetName()), }), }, { @@ -8123,12 +8144,15 @@ func TestCreateAccessRequest(t *testing.T) { user: admin, accessRequest: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), []string{requestRole.GetName()}, []types.ResourceID{ - mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup1.GetName()), }), errAssertionFunc: require.NoError, expected: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), []string{requestRole.GetName()}, []types.ResourceID{ - mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup1.GetName()), + mustResourceID(srv.ClusterName(), types.KindApp, userGroup1.GetApplications()[0]), + mustResourceID(srv.ClusterName(), types.KindApp, userGroup1.GetApplications()[1]), + mustResourceID(srv.ClusterName(), types.KindApp, userGroup1.GetApplications()[2]), }), }, { @@ -8136,7 +8160,7 @@ func TestCreateAccessRequest(t *testing.T) { user: bob, accessRequest: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), []string{requestRole.GetName()}, []types.ResourceID{ - mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup1.GetName()), }), errAssertionFunc: require.Error, }, @@ -8145,7 +8169,7 @@ func TestCreateAccessRequest(t *testing.T) { user: alice, accessRequest: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), []string{requestRole.GetName()}, []types.ResourceID{ - mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + mustResourceID(srv.ClusterName(), nodeAllowedByRequestRole.GetKind(), nodeAllowedByRequestRole.GetName()), mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup1.GetName()), mustResourceID(srv.ClusterName(), types.KindApp, "app1"), mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup2.GetName()), @@ -8154,7 +8178,7 @@ func TestCreateAccessRequest(t *testing.T) { errAssertionFunc: require.NoError, expected: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), []string{requestRole.GetName()}, []types.ResourceID{ - mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + mustResourceID(srv.ClusterName(), nodeAllowedByRequestRole.GetKind(), nodeAllowedByRequestRole.GetName()), mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup1.GetName()), mustResourceID(srv.ClusterName(), types.KindApp, "app1"), mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup2.GetName()), @@ -8389,9 +8413,13 @@ func TestAccessRequestNonGreedyAnnotations(t *testing.T) { require.NoError(t, err) paymentsServer.SetStaticLabels(map[string]string{"service": "payments"}) - idServer, err := types.NewServer("server-identity", types.KindNode, types.ServerSpecV2{}) + idServer, err := types.NewServerWithLabels( + "server-identity", + types.KindNode, + types.ServerSpecV2{}, + map[string]string{"service": "identity"}, + ) require.NoError(t, err) - idServer.SetStaticLabels(map[string]string{"service": "payments"}) ctx := context.Background() srv := newTestTLSServer(t)