diff --git a/cmd/store/import.go b/cmd/store/import.go index f187e5a..9baacf7 100644 --- a/cmd/store/import.go +++ b/cmd/store/import.go @@ -26,6 +26,7 @@ import ( "github.com/schollz/progressbar/v3" + openfga "github.com/openfga/go-sdk" "github.com/openfga/go-sdk/client" "github.com/spf13/cobra" @@ -116,30 +117,74 @@ func importStore( maxTuplesPerWrite, maxParallelRequests int32, fileName string, ) (*CreateStoreAndModelResponse, error) { - var ( - response *CreateStoreAndModelResponse - err error - ) + response, err := createOrUpdateStore(clientConfig, fgaClient, storeData, format, storeID, fileName) + if err != nil { + return nil, err + } + + if len(storeData.Tuples) == 0 { + return response, nil + } + + err = importTuples(fgaClient, storeData.Tuples, maxTuplesPerWrite, maxParallelRequests) + if err != nil { + return nil, err + } + + return response, nil +} +func createOrUpdateStore( + clientConfig *fga.ClientConfig, + fgaClient client.SdkClient, + storeData *storetest.StoreData, + format authorizationmodel.ModelFormat, + storeID string, + fileName string, +) (*CreateStoreAndModelResponse, error) { if storeID == "" { - response, err = createStore(clientConfig, storeData, format, fileName) - if err != nil { - return nil, fmt.Errorf("failed to create store: %w", err) + return createStore(clientConfig, storeData, format, fileName) + } + + return updateStore(clientConfig, fgaClient, storeData, format, storeID) +} + +func importTuples( + fgaClient client.SdkClient, + tuples []openfga.TupleKey, + maxTuplesPerWrite, maxParallelRequests int32, +) error { + bar := createProgressBar(len(tuples)) + + for index := 0; index < len(tuples); index += int(maxTuplesPerWrite) { + end := index + int(maxTuplesPerWrite) + if end > len(tuples) { + end = len(tuples) } - } else { - response, err = updateStore(clientConfig, fgaClient, storeData, format, storeID) - if err != nil { - return nil, fmt.Errorf("failed to update store: %w", err) + + writeRequest := client.ClientWriteRequest{ + Writes: tuples[index:end], + } + if _, err := tuple.ImportTuples(fgaClient, writeRequest, maxTuplesPerWrite, maxParallelRequests); err != nil { + return fmt.Errorf("failed to import tuples: %w", err) + } + + if err := bar.Add(end - index); err != nil { + return fmt.Errorf("failed to update progress bar: %w", err) } + + time.Sleep(progressBarUpdateDelay) } - fgaClient, err = clientConfig.GetFgaClient() - if err != nil { - return nil, fmt.Errorf("failed to initialize FGA Client: %w", err) + if err := bar.Finish(); err != nil { + return fmt.Errorf("failed to finish progress bar: %w", err) } - // Initialize progress bar - bar := progressbar.NewOptions(len(storeData.Tuples), + return nil +} + +func createProgressBar(total int) *progressbar.ProgressBar { + return progressbar.NewOptions(total, progressbar.OptionSetWriter(os.Stderr), progressbar.OptionSetDescription("Importing tuples"), progressbar.OptionShowCount(), @@ -158,34 +203,6 @@ func importStore( BarEnd: "]", }), ) - - for index := 0; index < len(storeData.Tuples); index += int(maxTuplesPerWrite) { - end := index + int(maxTuplesPerWrite) - if end > len(storeData.Tuples) { - end = len(storeData.Tuples) - } - - writeRequest := client.ClientWriteRequest{ - Writes: storeData.Tuples[index:end], - } - if _, err := tuple.ImportTuples(fgaClient, writeRequest, maxTuplesPerWrite, maxParallelRequests); err != nil { - return nil, fmt.Errorf("failed to import tuples: %w", err) - } - - if err := bar.Add(end - index); err != nil { - return nil, fmt.Errorf("failed to update progress bar: %w", err) - } - - // Introduce a small delay to smooth out the progress bar rendering - time.Sleep(progressBarUpdateDelay) - } - - // Ensure progress bar is completed and cleared - if err := bar.Finish(); err != nil { - return nil, fmt.Errorf("failed to finish progress bar: %w", err) - } - - return response, nil } // importCmd represents the get command.