Skip to content

Commit 407821a

Browse files
committed
add tests to backupParseRelsCmdFunc
also does some minor refactoring on the function to make it testable and easier to follow
1 parent 68a7a80 commit 407821a

File tree

3 files changed

+166
-10
lines changed

3 files changed

+166
-10
lines changed

internal/cmd/backup.go

+22-10
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ var (
6767
Use: "parse-relationships <filename>",
6868
Short: "Extract the relationships from a backup file",
6969
Args: cobra.ExactArgs(1),
70-
RunE: backupParseRelsCmdFunc,
70+
RunE: func(cmd *cobra.Command, args []string) error {
71+
return backupParseRelsCmdFunc(cmd, os.Stdout, args)
72+
},
7173
}
7274
)
7375

@@ -556,11 +558,12 @@ func backupParseRevisionCmdFunc(_ *cobra.Command, args []string) error {
556558
return nil
557559
}
558560

559-
func backupParseRelsCmdFunc(cmd *cobra.Command, args []string) error {
561+
func backupParseRelsCmdFunc(cmd *cobra.Command, out *os.File, args []string) error {
560562
filename := "" // Default to stdin.
561563
if len(args) > 0 {
562564
filename = args[0]
563565
}
566+
prefix := cobrautil.MustGetString(cmd, "prefix-filter")
564567

565568
f, _, err := openRestoreFile(filename)
566569
if err != nil {
@@ -573,15 +576,24 @@ func backupParseRelsCmdFunc(cmd *cobra.Command, args []string) error {
573576
}
574577

575578
for rel, err := decoder.Next(); rel != nil && err == nil; rel, err = decoder.Next() {
576-
if hasRelPrefix(rel, cobrautil.MustGetString(cmd, "prefix-filter")) {
577-
relString, err := tuple.StringRelationship(rel)
578-
if err != nil {
579-
return err
580-
}
581-
relString = strings.Replace(relString, "@", " ", 1)
582-
relString = strings.Replace(relString, "#", " ", 1)
583-
fmt.Println(relString)
579+
if !hasRelPrefix(rel, prefix) {
580+
continue
581+
}
582+
583+
relString, err := tuple.StringRelationship(rel)
584+
if err != nil {
585+
return err
586+
}
587+
588+
if _, err = fmt.Fprintln(out, replaceRelString(relString)); err != nil {
589+
return err
584590
}
585591
}
592+
586593
return nil
587594
}
595+
596+
func replaceRelString(rel string) string {
597+
rel = strings.Replace(rel, "@", " ", 1)
598+
return strings.Replace(rel, "#", " ", 1)
599+
}

internal/cmd/backup_test.go

+69
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,25 @@
11
package cmd
22

33
import (
4+
"os"
45
"testing"
56

7+
"github.com/rs/zerolog"
68
"github.com/stretchr/testify/require"
79
)
810

11+
func init() {
12+
zerolog.SetGlobalLevel(zerolog.Disabled)
13+
}
14+
15+
const testSchema = `definition test/user {}\ndefinition resource {relation reader: test/user}\n`
16+
17+
var testRelationships = []string{
18+
`test/user:1#reader@test/resource:1`,
19+
`test/user:2#reader@test/resource:2`,
20+
`test/user:3#reader@test/resource:3`,
21+
}
22+
923
func TestFilterSchemaDefs(t *testing.T) {
1024
for _, tt := range []struct {
1125
name string
@@ -99,3 +113,58 @@ func TestFilterSchemaDefs(t *testing.T) {
99113
})
100114
}
101115
}
116+
117+
func TestBackupParseRelsCmdFunc(t *testing.T) {
118+
for _, tt := range []struct {
119+
name string
120+
filter string
121+
schema string
122+
relationships []string
123+
output []string
124+
err string
125+
}{
126+
{
127+
name: "basic test",
128+
filter: "test",
129+
schema: testSchema,
130+
relationships: testRelationships,
131+
output: mapRelationshipTuplesToCLIOutput(t, testRelationships),
132+
},
133+
{
134+
name: "filters out",
135+
filter: "test",
136+
schema: testSchema,
137+
relationships: append([]string{"foo/user:0#reader@foo/resource:0"}, testRelationships...),
138+
output: mapRelationshipTuplesToCLIOutput(t, testRelationships),
139+
},
140+
{
141+
name: "allows empty backup",
142+
filter: "test",
143+
schema: testSchema,
144+
relationships: nil,
145+
output: nil,
146+
},
147+
} {
148+
t.Run(tt.name, func(t *testing.T) {
149+
tt := tt
150+
t.Parallel()
151+
152+
cmd := createTestCobraCommandWithFlagValue(t, "prefix-filter", tt.filter)
153+
backupName := createTestBackup(t, tt.schema, tt.relationships)
154+
f, err := os.CreateTemp("", "parse-output")
155+
require.NoError(t, err)
156+
defer func() {
157+
_ = f.Close()
158+
}()
159+
t.Cleanup(func() {
160+
_ = os.Remove(f.Name())
161+
})
162+
163+
err = backupParseRelsCmdFunc(cmd, f, []string{backupName})
164+
require.NoError(t, err)
165+
166+
lines := readLines(t, f.Name())
167+
require.Equal(t, tt.output, lines)
168+
})
169+
}
170+
}

internal/cmd/helpers_test.go

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package cmd
2+
3+
import (
4+
"bufio"
5+
"os"
6+
"testing"
7+
8+
v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
9+
"github.com/authzed/spicedb/pkg/tuple"
10+
"github.com/samber/lo"
11+
"github.com/spf13/cobra"
12+
"github.com/stretchr/testify/require"
13+
14+
"github.com/authzed/zed/pkg/backupformat"
15+
)
16+
17+
func mapRelationshipTuplesToCLIOutput(t *testing.T, input []string) []string {
18+
t.Helper()
19+
20+
return lo.Map[string, string](input, func(item string, _ int) string {
21+
return replaceRelString(item)
22+
})
23+
}
24+
25+
func readLines(t *testing.T, fileName string) []string {
26+
t.Helper()
27+
28+
f, err := os.Open(fileName)
29+
require.NoError(t, err)
30+
defer func() {
31+
_ = f.Close()
32+
}()
33+
34+
var lines []string
35+
scanner := bufio.NewScanner(f)
36+
for scanner.Scan() {
37+
lines = append(lines, scanner.Text())
38+
}
39+
40+
return lines
41+
}
42+
43+
func createTestCobraCommandWithFlagValue(t *testing.T, flagName, flagValue string) *cobra.Command {
44+
t.Helper()
45+
46+
c := cobra.Command{}
47+
c.Flags().String(flagName, flagValue, "")
48+
49+
return &c
50+
}
51+
52+
func createTestBackup(t *testing.T, schema string, relationships []string) string {
53+
t.Helper()
54+
55+
f, err := os.CreateTemp("", "test-backup")
56+
require.NoError(t, err)
57+
defer f.Close()
58+
t.Cleanup(func() {
59+
_ = os.Remove(f.Name())
60+
})
61+
62+
avroWriter, err := backupformat.NewEncoder(f, schema, &v1.ZedToken{Token: "test"})
63+
require.NoError(t, err)
64+
defer func() {
65+
require.NoError(t, avroWriter.Close())
66+
}()
67+
68+
for _, rel := range relationships {
69+
r := tuple.ParseRel(rel)
70+
require.NotNil(t, r)
71+
require.NoError(t, avroWriter.Append(r))
72+
}
73+
74+
return f.Name()
75+
}

0 commit comments

Comments
 (0)