diff --git a/.golangci.yaml b/.golangci.yaml index e9a010d..a1e5b5e 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -61,6 +61,7 @@ linters-settings: - github.com/openfga/go-sdk - github.com/openfga/openfga - github.com/stretchr + - github.com/openfga/api/proto tagliatelle: case: diff --git a/example/model.fga.yaml b/example/model.fga.yaml index 231b027..a92ecf7 100644 --- a/example/model.fga.yaml +++ b/example/model.fga.yaml @@ -47,7 +47,7 @@ tests: can_view: true can_write: true can_share: false - list_objects: # Each check test is made of: a user, an object type and the expected result for one or more relations + list_objects: # Each list objects test is made of: a user, an object type and the expected result for one or more relations - user: user:anne type: folder assertions: @@ -68,4 +68,18 @@ tests: - folder:product-2021 - folder:product-2021Q1 can_write: [] - can_share: [] \ No newline at end of file + can_share: [] + list_users: # Each list user test is made of: an object, a user filter and the expected result for one or more relations + - object: folder:product-2021 + user_filter: + - type: user + assertions: + can_view: + users: + - user:anne + - user:beth + excluded_users: [] + can_write: + users: + - user:anne + excluded_users: [] diff --git a/example/store_abac.fga.yaml b/example/store_abac.fga.yaml index 728f11b..5fa3487 100644 --- a/example/store_abac.fga.yaml +++ b/example/store_abac.fga.yaml @@ -40,7 +40,7 @@ tests: user_ip: "192.168.1.0" assertions: viewer: false # current time is within granted time interval but the user's ip address is outside the CIDR range - list_objects: # Each check test is made of: a user, an object type and the expected result for one or more relations + list_objects: # Each list objects test is made of: a user, an object type and the expected result for one or more relations - user: user:anne type: document context: @@ -63,3 +63,23 @@ tests: user_ip: "192.168.1.0" assertions: viewer: [] + list_users: # Each list user test is made of: an object, a user filter and the expected result for one or more relations + - object: document:1 + user_filter: + - type: user + context: + current_timestamp: "2023-05-03T21:25:23+00:00" + user_ip: "192.168.0.0" + assertions: + viewer: + users: + - user:anne + - object: document:1 + user_filter: + - type: user + context: + current_timestamp: "2023-05-03T21:25:31+00:00" + user_ip: "192.168.0.0" + assertions: + viewer: + users: [] \ No newline at end of file diff --git a/go.mod b/go.mod index bc20ed4..6fb2d7b 100644 --- a/go.mod +++ b/go.mod @@ -10,10 +10,10 @@ require ( github.com/muesli/roff v0.1.0 github.com/nwidger/jsoncolor v0.3.2 github.com/oklog/ulid/v2 v2.1.0 - github.com/openfga/api/proto v0.0.0-20240425220334-619029c1d3d3 + github.com/openfga/api/proto v0.0.0-20240430203311-36050418a284 github.com/openfga/go-sdk v0.3.6-0.20240430041914-d27ef8fa20b8 github.com/openfga/language/pkg/go v0.0.0-20240429103126-f3e71ca3287d - github.com/openfga/openfga v1.5.3 + github.com/openfga/openfga v1.5.4-0.20240430205231-c4953b813b89 github.com/spf13/cobra v1.8.0 github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.18.2 diff --git a/go.sum b/go.sum index ca5e557..8eb8f19 100644 --- a/go.sum +++ b/go.sum @@ -37,8 +37,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/distribution/reference v0.5.0 h1:/FUIFXtfc/x2gpa5/VGfiGLuOIdYa1t65IKK2OFGvA0= github.com/distribution/reference v0.5.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= -github.com/docker/docker v26.0.0+incompatible h1:Ng2qi+gdKADUa/VM+6b6YaY2nlZhk/lVJiKR/2bMudU= -github.com/docker/docker v26.0.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/docker v26.0.2+incompatible h1:yGVmKUFGgcxA6PXWAokO0sQL22BrQ67cgVjko8tGdXE= +github.com/docker/docker v26.0.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= @@ -183,16 +183,14 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8 github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug= github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= -github.com/openfga/api/proto v0.0.0-20240425220334-619029c1d3d3 h1:d1Kpom9kYMCJQlImCVHAEUfTcL4WeiEGQHcc5+KYQnk= -github.com/openfga/api/proto v0.0.0-20240425220334-619029c1d3d3/go.mod h1:5LtWOArDX4FlbcfDvBoJAzDEYJKLz/OEUoi+0S2tyM8= -github.com/openfga/go-sdk v0.3.6-0.20240425164425-0ac2df2acad5 h1:hjhR7oRVvpKQ9kbFCcel6vAbO2nlr5nwORn80sVzuFU= -github.com/openfga/go-sdk v0.3.6-0.20240425164425-0ac2df2acad5/go.mod h1:AoMnFlPw65sU/7O4xOPpCb2vXA8ZD9K9xp2hZjcvt4g= +github.com/openfga/api/proto v0.0.0-20240430203311-36050418a284 h1:gA09kLBB/voWQCtCH745P2Lx6ZoI7DDs+XlvLFm9S3M= +github.com/openfga/api/proto v0.0.0-20240430203311-36050418a284/go.mod h1:5LtWOArDX4FlbcfDvBoJAzDEYJKLz/OEUoi+0S2tyM8= github.com/openfga/go-sdk v0.3.6-0.20240430041914-d27ef8fa20b8 h1:P2S/gfxpoHBW0tDCxu/aC67MniCcVMR6nQtNzuLLelE= github.com/openfga/go-sdk v0.3.6-0.20240430041914-d27ef8fa20b8/go.mod h1:t2iDiGuJtdyjvMHzkxDsbWbCWOLlVMViPKrU7Sau4K8= github.com/openfga/language/pkg/go v0.0.0-20240429103126-f3e71ca3287d h1:a8YgRx1RfofFiDbMj/Azwh0UeMbdfUf7OiMqSom/smQ= github.com/openfga/language/pkg/go v0.0.0-20240429103126-f3e71ca3287d/go.mod h1:wkI4GcY3yNNuFMU2ncHPWqBaF7XylQTkJYfBi2pIpK8= -github.com/openfga/openfga v1.5.3 h1:Uynmlsx3iz/eiP7wW9n2dbwxDh/kiS/27W6C24Y7oyY= -github.com/openfga/openfga v1.5.3/go.mod h1:IcQBDtytjhBxjfJ+1zCzUxQvQa1BB/4Ed+POpZKojTI= +github.com/openfga/openfga v1.5.4-0.20240430205231-c4953b813b89 h1:hPSTLw4uhdpbcM8h9qsU2WjM6ypF8ktwhAepSbouRfE= +github.com/openfga/openfga v1.5.4-0.20240430205231-c4953b813b89/go.mod h1:k6hBuz6L6tJcqZFogrx0p5qRQF1V3i9qh92i2GJbpNE= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= github.com/pelletier/go-toml/v2 v2.2.1 h1:9TA9+T8+8CUCO2+WYnDLCgrYi9+omqKXyjDtosvtEhg= @@ -205,8 +203,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= -github.com/pressly/goose/v3 v3.19.2 h1:z1yuD41jS4iaqLkyjkzGkKBz4rgyz/BYtCyMMGHlgzQ= -github.com/pressly/goose/v3 v3.19.2/go.mod h1:BHkf3LzSBmO8E5FTMPupUYIpMTIh/ZuQVy+YTfhZLD4= +github.com/pressly/goose/v3 v3.20.0 h1:uPJdOxF/Ipj7ABVNOAMJXSxwFXZGwMGHNqjC8e61VA0= +github.com/pressly/goose/v3 v3.20.0/go.mod h1:BRfF2GcG4FTG12QfdBVy3q1yveaf4ckL9vWwEcIO3lA= github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU= github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= diff --git a/internal/storetest/conversion.go b/internal/storetest/conversion.go new file mode 100644 index 0000000..929b28a --- /dev/null +++ b/internal/storetest/conversion.go @@ -0,0 +1,91 @@ +package storetest + +import ( + "fmt" + "strings" + + pb "github.com/openfga/api/proto/openfga/v1" + openfga "github.com/openfga/go-sdk" + "github.com/openfga/go-sdk/client" + "google.golang.org/protobuf/types/known/structpb" +) + +func convertClientTupleKeysToProtoTupleKeys( + tuples []client.ClientContextualTupleKey, +) ([]*pb.TupleKey, error) { + pbTuples := []*pb.TupleKey{} + + for index := 0; index < len(tuples); index++ { + tuple := tuples[index] + tpl := pb.TupleKey{ + User: tuple.User, + Relation: tuple.Relation, + Object: tuple.Object, + } + + if tuple.Condition != nil { + conditionContext, err := structpb.NewStruct(tuple.Condition.GetContext()) + if err != nil { + return nil, fmt.Errorf("failed to construct a proto struct: %w", err) + } + + tpl.Condition = &pb.RelationshipCondition{ + Name: tuple.Condition.Name, + Context: conditionContext, + } + } + + pbTuples = append(pbTuples, &tpl) + } + + return pbTuples, nil +} + +func convertStoreObjectToObject(object string) (openfga.FgaObject, *pb.Object) { + splitObject := strings.Split(object, ":") + + return openfga.FgaObject{ + Type: splitObject[0], + Id: splitObject[1], + }, &pb.Object{ + Type: splitObject[0], + Id: splitObject[1], + } +} + +func convertPbUsersToStrings(users []*pb.User) []string { + simpleUsers := []string{} + + for _, user := range users { + switch typedUser := user.GetUser().(type) { + case *pb.User_Object: + simpleUsers = append(simpleUsers, typedUser.Object.GetType()+":"+typedUser.Object.GetId()) + case *pb.User_Userset: + simpleUsers = append( + simpleUsers, + typedUser.Userset.GetType()+":"+typedUser.Userset.GetId()+"#"+typedUser.Userset.GetRelation(), + ) + case *pb.User_Wildcard: + simpleUsers = append(simpleUsers, typedUser.Wildcard.GetType()+":*") + } + } + + return simpleUsers +} + +func convertOpenfgaUsers(users []openfga.User) []string { + simpleUsers := []string{} + + for _, user := range users { + switch { + case user.Object != nil: + simpleUsers = append(simpleUsers, user.Object.Type+":"+user.Object.Id) + case user.Userset != nil: + simpleUsers = append(simpleUsers, user.Userset.Type+":"+user.Userset.Id+"#"+user.Userset.Relation) + case user.Wildcard != nil: + simpleUsers = append(simpleUsers, user.Wildcard.Type+":*") + } + } + + return simpleUsers +} diff --git a/internal/storetest/conversion_test.go b/internal/storetest/conversion_test.go new file mode 100644 index 0000000..aa1f7ff --- /dev/null +++ b/internal/storetest/conversion_test.go @@ -0,0 +1,139 @@ +package storetest + +import ( + "testing" + + pb "github.com/openfga/api/proto/openfga/v1" + openfga "github.com/openfga/go-sdk" + "github.com/openfga/go-sdk/client" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConvertPbUsersToStrings(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + input *pb.User + expected string + }{ + "User_Object": { + input: &pb.User{User: &pb.User_Object{Object: &pb.Object{Type: "user", Id: "anne"}}}, + expected: "user:anne", + }, + "User_Userset": { + input: &pb.User{User: &pb.User_Userset{Userset: &pb.UsersetUser{Type: "group", Id: "fga", Relation: "member"}}}, + expected: "group:fga#member", + }, + "User_Wildcard": { + input: &pb.User{User: &pb.User_Wildcard{Wildcard: &pb.TypedWildcard{Type: "user"}}}, + expected: "user:*", + }, + } + + for name, testcase := range tests { + testcase := testcase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := convertPbUsersToStrings([]*pb.User{testcase.input}) + + assert.Equal(t, []string{testcase.expected}, got) + }) + } +} + +func TestConvertOpenfgaUsersToStrings(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + input openfga.User + expected string + }{ + "User_Object": { + input: openfga.User{Object: &openfga.FgaObject{Type: "user", Id: "anne"}}, + expected: "user:anne", + }, + "User_Userset": { + input: openfga.User{Userset: &openfga.UsersetUser{Type: "group", Id: "fga", Relation: "member"}}, + expected: "group:fga#member", + }, + "User_Wildcard": { + input: openfga.User{Wildcard: &openfga.TypedWildcard{Type: "user"}}, + expected: "user:*", + }, + } + + for name, testcase := range tests { + testcase := testcase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := convertOpenfgaUsers([]openfga.User{testcase.input}) + + assert.Equal(t, []string{testcase.expected}, got) + }) + } +} + +func TestConvertStoreObjectToObject(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + input string + expectedFGAObject openfga.FgaObject + expectedPBObject *pb.Object + }{ + "Converts object": { + input: "document:roadmap", + expectedFGAObject: openfga.FgaObject{Type: "document", Id: "roadmap"}, + expectedPBObject: &pb.Object{Type: "document", Id: "roadmap"}, + }, + } + + for name, testcase := range tests { + testcase := testcase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + fgaObject, pbObject := convertStoreObjectToObject(testcase.input) + + assert.Equal(t, testcase.expectedFGAObject, fgaObject) + assert.Equal(t, testcase.expectedPBObject, pbObject) + }) + } +} + +func TestConvertClientTupleKeysToProtoTupleKeys(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + input []client.ClientContextualTupleKey + expected []*pb.TupleKey + }{ + "User_Object": { + input: []client.ClientContextualTupleKey{ + {User: "user:anne", Relation: "owner", Object: "folder:product"}, + }, + expected: []*pb.TupleKey{ + {User: "user:anne", Relation: "owner", Object: "folder:product"}, + }, + }, + } + + for name, testcase := range tests { + testcase := testcase + + t.Run(name, func(t *testing.T) { + t.Parallel() + + tuples, err := convertClientTupleKeysToProtoTupleKeys(testcase.input) + + require.NoError(t, err) + assert.Equal(t, testcase.expected, tuples) + }) + } +} diff --git a/internal/storetest/localstore.go b/internal/storetest/localstore.go index 73d86ec..bcc128e 100644 --- a/internal/storetest/localstore.go +++ b/internal/storetest/localstore.go @@ -89,6 +89,7 @@ func getLocalServerModelAndTuples( fgaServer, err := server.NewServerWithOpts( server.WithDatastore(datastore), + server.WithExperimentals(server.ExperimentalEnableListUsers), ) if err != nil { return nil, nil, stopServerFn, err //nolint:wrapcheck diff --git a/internal/storetest/localtest.go b/internal/storetest/localtest.go index e6fcd5d..7c2046a 100644 --- a/internal/storetest/localtest.go +++ b/internal/storetest/localtest.go @@ -143,6 +143,81 @@ func RunLocalListObjectsTest( return results } +func RunSingleLocalListUsersTest( + fgaServer *server.Server, + listUsersRequest *pb.ListUsersRequest, +) (*pb.ListUsersResponse, error) { + return fgaServer.ListUsers(context.Background(), listUsersRequest) //nolint:wrapcheck +} + +func RunLocalListUsersTest( + fgaServer *server.Server, + listUsersTest ModelTestListUsers, + tuples []client.ClientContextualTupleKey, + options ModelTestOptions, +) []ModelTestListUsersSingleResult { + results := []ModelTestListUsersSingleResult{} + + object, pbObject := convertStoreObjectToObject(listUsersTest.Object) + + userFilter := &pb.UserTypeFilter{ + Type: listUsersTest.UserFilter[0].GetType(), + Relation: listUsersTest.UserFilter[0].GetRelation(), + } + + for relation, expectation := range listUsersTest.Assertions { + result := ModelTestListUsersSingleResult{ + Request: client.ClientListUsersRequest{ + Object: object, + Relation: relation, + UserFilters: listUsersTest.UserFilter, + ContextualTuples: tuples, + Context: listUsersTest.Context, + }, + Expected: expectation, + } + + var ( + ctx *structpb.Struct + err error + ) + + if listUsersTest.Context != nil { + ctx, err = structpb.NewStruct(*listUsersTest.Context) + } + + if err != nil { + result.Error = err + } else { + response, err := RunSingleLocalListUsersTest(fgaServer, + &pb.ListUsersRequest{ + StoreId: *options.StoreID, + AuthorizationModelId: *options.ModelID, + Object: pbObject, + Relation: relation, + UserFilters: []*pb.UserTypeFilter{userFilter}, + Context: ctx, + }, + ) + if err != nil { + result.Error = err + } + + if response != nil { + result.Got = ModelTestListUsersAssertion{ + Users: convertPbUsersToStrings(response.GetUsers()), + ExcludedUsers: []string{}, + } + result.TestResult = result.IsPassing() + } + } + + results = append(results, result) + } + + return results +} + func RunLocalTest( fgaServer *server.Server, test ModelTest, @@ -151,6 +226,7 @@ func RunLocalTest( ) (TestResult, error) { checkResults := []ModelTestCheckSingleResult{} listObjectResults := []ModelTestListObjectsSingleResult{} + listUsersResults := []ModelTestListUsersSingleResult{} storeID, modelID, err := initLocalStore(fgaServer, model.GetProtoModel(), tuples) if err != nil { @@ -172,10 +248,16 @@ func RunLocalTest( listObjectResults = append(listObjectResults, results...) } + for index := 0; index < len(test.ListUsers); index++ { + results := RunLocalListUsersTest(fgaServer, test.ListUsers[index], tuples, testOptions) + listUsersResults = append(listUsersResults, results...) + } + return TestResult{ Name: test.Name, Description: test.Description, CheckResults: checkResults, ListObjectsResults: listObjectResults, + ListUsersResults: listUsersResults, }, nil } diff --git a/internal/storetest/remotetest.go b/internal/storetest/remotetest.go index c981e88..a560cae 100644 --- a/internal/storetest/remotetest.go +++ b/internal/storetest/remotetest.go @@ -97,6 +97,56 @@ func RunRemoteListObjectsTest( return results } +func RunSingleRemoteListUsersTest( + fgaClient *client.OpenFgaClient, + listUsersRequest client.ClientListUsersRequest, + expectation ModelTestListUsersAssertion, +) ModelTestListUsersSingleResult { + response, err := fgaClient.ListUsers(context.Background()).Body(listUsersRequest).Execute() + + result := ModelTestListUsersSingleResult{ + Request: listUsersRequest, + Expected: expectation, + Error: err, + } + + if response != nil { + result.Got = ModelTestListUsersAssertion{ + Users: convertOpenfgaUsers(response.GetUsers()), + ExcludedUsers: []string{}, + } + result.TestResult = result.IsPassing() + } + + return result +} + +func RunRemoteListUsersTest( + fgaClient *client.OpenFgaClient, + listUsersTest ModelTestListUsers, + tuples []client.ClientContextualTupleKey, +) []ModelTestListUsersSingleResult { + results := []ModelTestListUsersSingleResult{} + + object, _ := convertStoreObjectToObject(listUsersTest.Object) + for relation, expectation := range listUsersTest.Assertions { + result := RunSingleRemoteListUsersTest(fgaClient, + client.ClientListUsersRequest{ + Object: object, + Relation: relation, + UserFilters: listUsersTest.UserFilter, + Context: listUsersTest.Context, + ContextualTuples: tuples, + }, + expectation, + ) + + results = append(results, result) + } + + return results +} + func RunRemoteTest( fgaClient *client.OpenFgaClient, test ModelTest, @@ -116,10 +166,18 @@ func RunRemoteTest( listObjectResults = append(listObjectResults, results...) } + listUserResults := []ModelTestListUsersSingleResult{} + + for index := 0; index < len(test.ListUsers); index++ { + results := RunRemoteListUsersTest(fgaClient, test.ListUsers[index], testTuples) + listUserResults = append(listUserResults, results...) + } + return TestResult{ Name: test.Name, Description: test.Description, CheckResults: checkResults, ListObjectsResults: listObjectResults, + ListUsersResults: listUserResults, } } diff --git a/internal/storetest/storedata.go b/internal/storetest/storedata.go index 86fdf45..53068e1 100644 --- a/internal/storetest/storedata.go +++ b/internal/storetest/storedata.go @@ -22,6 +22,7 @@ import ( "fmt" "path" + openfga "github.com/openfga/go-sdk" "github.com/openfga/go-sdk/client" "github.com/openfga/cli/internal/authorizationmodel" @@ -42,6 +43,18 @@ type ModelTestListObjects struct { Assertions map[string][]string `json:"assertions" yaml:"assertions"` } +type ModelTestListUsers struct { + Object string `json:"object" yaml:"object"` + UserFilter []openfga.UserTypeFilter `json:"user_filter" yaml:"user_filter"` //nolint:tagliatelle + Context *map[string]interface{} `json:"context" yaml:"context,omitempty"` + Assertions map[string]ModelTestListUsersAssertion `json:"assertions" yaml:"assertions"` +} + +type ModelTestListUsersAssertion struct { + Users []string `json:"users" yaml:"users"` + ExcludedUsers []string `json:"excluded_users" yaml:"excluded_users"` //nolint:tagliatelle +} + type ModelTest struct { Name string `json:"name" yaml:"name"` Description string `json:"description" yaml:"description,omitempty"` @@ -49,6 +62,7 @@ type ModelTest struct { TupleFile string `json:"tuple_file" yaml:"tuple_file,omitempty"` //nolint:tagliatelle Check []ModelTestCheck `json:"check" yaml:"check"` ListObjects []ModelTestListObjects `json:"list_objects" yaml:"list_objects,omitempty"` //nolint:tagliatelle + ListUsers []ModelTestListUsers `json:"list_users" yaml:"list_users,omitempty"` //nolint:tagliatelle } type StoreData struct { diff --git a/internal/storetest/testresult.go b/internal/storetest/testresult.go index 1999ab2..7b7faa7 100644 --- a/internal/storetest/testresult.go +++ b/internal/storetest/testresult.go @@ -10,6 +10,8 @@ import ( "github.com/openfga/cli/internal/comparison" ) +const NoValueString = "N/A" + type ModelTestCheckSingleResult struct { Request client.ClientCheckRequest `json:"request"` Expected bool `json:"expected"` @@ -34,11 +36,26 @@ func (result ModelTestListObjectsSingleResult) IsPassing() bool { return result.Error == nil && result.Got != nil && comparison.CheckStringArraysEqual(result.Got, result.Expected) } +type ModelTestListUsersSingleResult struct { + Request client.ClientListUsersRequest `json:"request"` + Expected ModelTestListUsersAssertion `json:"expected"` + Got ModelTestListUsersAssertion `json:"got"` + Error error `json:"error"` + TestResult bool `json:"test_result"` +} + +func (result ModelTestListUsersSingleResult) IsPassing() bool { + return result.Error == nil && + comparison.CheckStringArraysEqual(result.Got.Users, result.Expected.Users) && + comparison.CheckStringArraysEqual(result.Got.ExcludedUsers, result.Expected.ExcludedUsers) +} + type TestResult struct { Name string `json:"name"` Description string `json:"description"` CheckResults []ModelTestCheckSingleResult `json:"check_results"` ListObjectsResults []ModelTestListObjectsSingleResult `json:"list_objects_results"` + ListUsersResults []ModelTestListUsersSingleResult `json:"list_users_results"` } // IsPassing - indicates whether a Test has succeeded completely or has any failing parts. @@ -55,6 +72,12 @@ func (result TestResult) IsPassing() bool { } } + for index := 0; index < len(result.ListUsersResults); index++ { + if !result.ListUsersResults[index].IsPassing() { + return false + } + } + return true } @@ -63,8 +86,11 @@ func (result TestResult) FriendlyFailuresDisplay() string { failedCheckCount := 0 totalListObjectsCount := len(result.ListObjectsResults) failedListObjectsCount := 0 + totalListUsersCount := len(result.ListUsersResults) + failedListUsersCount := 0 checkResultsOutput := "" listObjectsResultsOutput := "" + listUsersResultsOutput := "" if totalCheckCount > 0 { failedCheckCount, checkResultsOutput = buildCheckTestResults( @@ -76,12 +102,18 @@ func (result TestResult) FriendlyFailuresDisplay() string { totalListObjectsCount, result, failedListObjectsCount, listObjectsResultsOutput) } - if failedCheckCount+failedListObjectsCount != 0 { + if totalListUsersCount > 0 { + failedListUsersCount, listUsersResultsOutput = buildListUsersTestResults( + totalListUsersCount, result, failedListUsersCount, listUsersResultsOutput) + } + + if failedCheckCount+failedListObjectsCount+failedListUsersCount != 0 { return buildTestResultOutput( - result, totalCheckCount, - failedCheckCount, totalListObjectsCount, - failedListObjectsCount, checkResultsOutput, - listObjectsResultsOutput) + result, + totalCheckCount, failedCheckCount, + totalListObjectsCount, failedListObjectsCount, + totalListUsersCount, failedListUsersCount, + checkResultsOutput, listObjectsResultsOutput, listUsersResultsOutput) } return "" @@ -97,7 +129,7 @@ func buildCheckTestResults(totalCheckCount int, if !checkResult.IsPassing() { failedCheckCount++ - got := "N/A" + got := NoValueString if checkResult.Got != nil { got = strconv.FormatBool(*checkResult.Got) } @@ -133,7 +165,7 @@ func buildListObjectsTestResults( if !listObjectsResult.IsPassing() { failedListObjectsCount++ - got := "N/A" + got := NoValueString if listObjectsResult.Got != nil { got = fmt.Sprintf("%s", listObjectsResult.Got) } @@ -159,9 +191,48 @@ func buildListObjectsTestResults( return failedListObjectsCount, listObjectsResultsOutput } -func buildTestResultOutput(result TestResult, totalCheckCount int, failedCheckCount int, +func buildListUsersTestResults( + totalListUsersCount int, result TestResult, + failedListUsersCount int, listUsersResultsOutput string, +) (int, string) { + for index := 0; index < totalListUsersCount; index++ { + listUsersResult := result.ListUsersResults[index] + + if !listUsersResult.IsPassing() { + failedListUsersCount++ + + got := NoValueString + if listUsersResult.Got.Users != nil || listUsersResult.Got.ExcludedUsers != nil { + got = fmt.Sprintf("%+v", listUsersResult.Got) + } + + userFilter := listUsersResult.Request.UserFilters[0] + + listUsersResultsOutput += fmt.Sprintf( + "\nⅹ ListUsers(object=%+v,relation=%s,user_filter=%+v", + listUsersResult.Request.Object, + listUsersResult.Request.Relation, + userFilter) + + if listUsersResult.Request.Context != nil { + listUsersResultsOutput += fmt.Sprintf(", context:%v", listUsersResult.Request.Context) + } + + listUsersResultsOutput += fmt.Sprintf("): expected=%+v, got=%+v", listUsersResult.Expected, got) + + if listUsersResult.Error != nil { + listUsersResultsOutput += fmt.Sprintf(", error=%v", listUsersResult.Error) + } + } + } + + return failedListUsersCount, listUsersResultsOutput +} + +func buildTestResultOutput(result TestResult, totalCheckCount int, failedCheckCount int, //nolint:cyclop totalListObjectsCount int, failedListObjectsCount int, - checkResultsOutput string, listObjectsResultsOutput string, + totalListUsersCount int, failedListUsersCount int, + checkResultsOutput string, listObjectsResultsOutput string, listUsersOutput string, ) string { testStatus := "FAILING" output := fmt.Sprintf("(%s) %s: ", testStatus, result.Name) @@ -179,6 +250,15 @@ func buildTestResultOutput(result TestResult, totalCheckCount int, failedCheckCo totalListObjectsCount-failedListObjectsCount, totalListObjectsCount) } + if totalListObjectsCount > 0 && totalListUsersCount > 0 { + output += " | " + } + + if totalListUsersCount > 0 { + output += fmt.Sprintf("ListUsers(%d/%d passing)", + totalListUsersCount-failedListUsersCount, totalListUsersCount) + } + if failedCheckCount > 0 { output = fmt.Sprintf("%s%s", output, checkResultsOutput) } @@ -187,6 +267,10 @@ func buildTestResultOutput(result TestResult, totalCheckCount int, failedCheckCo output = fmt.Sprintf("%s%s", output, listObjectsResultsOutput) } + if failedListUsersCount > 0 { + output = fmt.Sprintf("%s%s", output, listUsersOutput) + } + return output } @@ -205,7 +289,7 @@ func (test TestResults) IsPassing() bool { return true } -func (test TestResults) FriendlyDisplay() string { +func (test TestResults) FriendlyDisplay() string { //nolint:cyclop friendlyResults := []string{} for index := 0; index < len(test.Results); index++ { @@ -222,6 +306,8 @@ func (test TestResults) FriendlyDisplay() string { failedCheckCount := 0 totalListObjectsCount := 0 failedListObjectsCount := 0 + totalListUsersCount := 0 + failedListUsersCount := 0 for _, testResult := range test.Results { if !testResult.IsPassing() { @@ -243,21 +329,34 @@ func (test TestResults) FriendlyDisplay() string { failedListObjectsCount++ } } + + totalListUsersCount += len(testResult.ListUsersResults) + + for _, listUsersResult := range testResult.ListUsersResults { + if !listUsersResult.IsPassing() { + failedListUsersCount++ + } + } } summary := failuresText if totalTestCount > 0 { summary = buildTestSummary( - failedTestCount, summary, totalTestCount, totalCheckCount, failedCheckCount, - totalListObjectsCount, failedListObjectsCount) + failedTestCount, summary, totalTestCount, + totalCheckCount, failedCheckCount, + totalListObjectsCount, failedListObjectsCount, + totalListUsersCount, failedListUsersCount, + ) } return summary } func buildTestSummary(failedTestCount int, summary string, totalTestCount int, - totalCheckCount int, failedCheckCount int, totalListObjectsCount int, failedListObjectsCount int, + totalCheckCount int, failedCheckCount int, + totalListObjectsCount int, failedListObjectsCount int, + totalListUsersCount int, failedListUsersCount int, ) string { if failedTestCount > 0 { summary += "\n---\n" @@ -276,5 +375,10 @@ func buildTestSummary(failedTestCount int, summary string, totalTestCount int, totalListObjectsCount-failedListObjectsCount, totalListObjectsCount) } + if totalListUsersCount > 0 { + summary += fmt.Sprintf("\nListUsers %d/%d passing", + totalListUsersCount-failedListUsersCount, totalListUsersCount) + } + return summary } diff --git a/internal/storetest/tuplekey.go b/internal/storetest/tuplekey.go deleted file mode 100644 index 186243c..0000000 --- a/internal/storetest/tuplekey.go +++ /dev/null @@ -1,40 +0,0 @@ -package storetest - -import ( - "fmt" - - pb "github.com/openfga/api/proto/openfga/v1" - "github.com/openfga/go-sdk/client" - "google.golang.org/protobuf/types/known/structpb" -) - -func convertClientTupleKeysToProtoTupleKeys( - tuples []client.ClientContextualTupleKey, -) ([]*pb.TupleKey, error) { - pbTuples := []*pb.TupleKey{} - - for index := 0; index < len(tuples); index++ { - tuple := tuples[index] - tpl := pb.TupleKey{ - User: tuple.User, - Relation: tuple.Relation, - Object: tuple.Object, - } - - if tuple.Condition != nil { - conditionContext, err := structpb.NewStruct(tuple.Condition.GetContext()) - if err != nil { - return nil, fmt.Errorf("failed to construct a proto struct: %w", err) - } - - tpl.Condition = &pb.RelationshipCondition{ - Name: tuple.Condition.Name, - Context: conditionContext, - } - } - - pbTuples = append(pbTuples, &tpl) - } - - return pbTuples, nil -}