Skip to content

Commit

Permalink
forward context to commands
Browse files Browse the repository at this point in the history
  • Loading branch information
maxlaverse committed Aug 25, 2024
1 parent 209b97a commit cebf8b2
Show file tree
Hide file tree
Showing 13 changed files with 109 additions and 98 deletions.
103 changes: 52 additions & 51 deletions internal/bitwarden/bw/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package bw

import (
"context"
"encoding/json"
"fmt"
"os"
Expand All @@ -9,24 +10,24 @@ import (
)

type Client interface {
CreateAttachment(itemId, filePath string) (*Object, error)
CreateObject(Object) (*Object, error)
EditObject(Object) (*Object, error)
GetAttachment(itemId, attachmentId string) ([]byte, error)
GetObject(Object) (*Object, error)
CreateAttachment(ctx context.Context, itemId, filePath string) (*Object, error)
CreateObject(context.Context, Object) (*Object, error)
EditObject(context.Context, Object) (*Object, error)
GetAttachment(ctx context.Context, itemId, attachmentId string) ([]byte, error)
GetObject(context.Context, Object) (*Object, error)
GetSessionKey() string
HasSessionKey() bool
ListObjects(objType string, options ...ListObjectsOption) ([]Object, error)
LoginWithAPIKey(password, clientId, clientSecret string) error
LoginWithPassword(username, password string) error
Logout() error
DeleteAttachment(itemId, attachmentId string) error
DeleteObject(Object) error
SetServer(string) error
ListObjects(ctx context.Context, objType string, options ...ListObjectsOption) ([]Object, error)
LoginWithAPIKey(ctx context.Context, password, clientId, clientSecret string) error
LoginWithPassword(ctx context.Context, username, password string) error
Logout(context.Context) error
DeleteAttachment(ctx context.Context, itemId, attachmentId string) error
DeleteObject(context.Context, Object) error
SetServer(context.Context, string) error
SetSessionKey(string)
Status() (*Status, error)
Sync() error
Unlock(password string) error
Status(context.Context) (*Status, error)
Sync(context.Context) error
Unlock(ctx context.Context, password string) error
}

func NewClient(execPath string, opts ...Options) Client {
Expand Down Expand Up @@ -79,8 +80,8 @@ func DisableRetryBackoff() Options {
}
}

func (c *client) CreateObject(obj Object) (*Object, error) {
objEncoded, err := c.encode(obj)
func (c *client) CreateObject(ctx context.Context, obj Object) (*Object, error) {
objEncoded, err := c.encode(ctx, obj)
if err != nil {
return nil, err
}
Expand All @@ -95,7 +96,7 @@ func (c *client) CreateObject(obj Object) (*Object, error) {
args = append(args, "--organizationid", obj.OrganizationID)
}

out, err := c.cmdWithSession(args...).Run()
out, err := c.cmdWithSession(args...).Run(ctx)
if err != nil {
return nil, err
}
Expand All @@ -109,8 +110,8 @@ func (c *client) CreateObject(obj Object) (*Object, error) {
return &obj, nil
}

func (c *client) CreateAttachment(itemId string, filePath string) (*Object, error) {
out, err := c.cmdWithSession("create", string(ObjectTypeAttachment), "--itemid", itemId, "--file", filePath).Run()
func (c *client) CreateAttachment(ctx context.Context, itemId string, filePath string) (*Object, error) {
out, err := c.cmdWithSession("create", string(ObjectTypeAttachment), "--itemid", itemId, "--file", filePath).Run(ctx)
if err != nil {
return nil, err
}
Expand All @@ -126,29 +127,29 @@ func (c *client) CreateAttachment(itemId string, filePath string) (*Object, erro
return &obj, nil
}

func (c *client) EditObject(obj Object) (*Object, error) {
objEncoded, err := c.encode(obj)
func (c *client) EditObject(ctx context.Context, obj Object) (*Object, error) {
objEncoded, err := c.encode(ctx, obj)
if err != nil {
return nil, err
}

out, err := c.cmdWithSession("edit", string(obj.Object), obj.ID, objEncoded).Run()
out, err := c.cmdWithSession("edit", string(obj.Object), obj.ID, objEncoded).Run(ctx)
if err != nil {
return nil, err
}
err = json.Unmarshal(out, &obj)
if err != nil {
return nil, newUnmarshallError(err, "edit object", out)
}
err = c.Sync()
err = c.Sync(ctx)
if err != nil {
return nil, fmt.Errorf("error syncing: %v, %v", err, string(out))
}

return &obj, nil
}

func (c *client) GetObject(obj Object) (*Object, error) {
func (c *client) GetObject(ctx context.Context, obj Object) (*Object, error) {
args := []string{
"get",
string(obj.Object),
Expand All @@ -159,7 +160,7 @@ func (c *client) GetObject(obj Object) (*Object, error) {
args = append(args, "--organizationid", obj.OrganizationID)
}

out, err := c.cmdWithSession(args...).Run()
out, err := c.cmdWithSession(args...).Run(ctx)
if err != nil {
return nil, remapError(err)
}
Expand All @@ -172,8 +173,8 @@ func (c *client) GetObject(obj Object) (*Object, error) {
return &obj, nil
}

func (c *client) GetAttachment(itemId, attachmentId string) ([]byte, error) {
out, err := c.cmdWithSession("get", string(ObjectTypeAttachment), attachmentId, "--itemid", itemId, "--raw").Run()
func (c *client) GetAttachment(ctx context.Context, itemId, attachmentId string) ([]byte, error) {
out, err := c.cmdWithSession("get", string(ObjectTypeAttachment), attachmentId, "--itemid", itemId, "--raw").Run(ctx)
if err != nil {
return nil, remapError(err)
}
Expand All @@ -186,7 +187,7 @@ func (c *client) GetSessionKey() string {
}

// ListObjects returns objects of a given type matching given filters.
func (c *client) ListObjects(objType string, options ...ListObjectsOption) ([]Object, error) {
func (c *client) ListObjects(ctx context.Context, objType string, options ...ListObjectsOption) ([]Object, error) {
args := []string{
"list",
objType,
Expand All @@ -196,7 +197,7 @@ func (c *client) ListObjects(objType string, options ...ListObjectsOption) ([]Ob
applyOption(&args)
}

out, err := c.cmdWithSession(args...).Run()
out, err := c.cmdWithSession(args...).Run(ctx)
if err != nil {
return nil, remapError(err)
}
Expand All @@ -212,8 +213,8 @@ func (c *client) ListObjects(objType string, options ...ListObjectsOption) ([]Ob

// LoginWithPassword logs in using a password and retrieves the session key,
// allowing authenticated requests using the client.
func (c *client) LoginWithPassword(username, password string) error {
out, err := c.cmd("login", username, "--raw", "--passwordenv", "BW_PASSWORD").AppendEnv([]string{fmt.Sprintf("BW_PASSWORD=%s", password)}).Run()
func (c *client) LoginWithPassword(ctx context.Context, username, password string) error {
out, err := c.cmd("login", username, "--raw", "--passwordenv", "BW_PASSWORD").AppendEnv([]string{fmt.Sprintf("BW_PASSWORD=%s", password)}).Run(ctx)
if err != nil {
return err
}
Expand All @@ -223,20 +224,20 @@ func (c *client) LoginWithPassword(username, password string) error {

// LoginWithPassword logs in using an API key and unlock the Vault in order to retrieve a session key,
// allowing authenticated requests using the client.
func (c *client) LoginWithAPIKey(password, clientId, clientSecret string) error {
_, err := c.cmd("login", "--apikey").AppendEnv([]string{fmt.Sprintf("BW_CLIENTID=%s", clientId), fmt.Sprintf("BW_CLIENTSECRET=%s", clientSecret)}).Run()
func (c *client) LoginWithAPIKey(ctx context.Context, password, clientId, clientSecret string) error {
_, err := c.cmd("login", "--apikey").AppendEnv([]string{fmt.Sprintf("BW_CLIENTID=%s", clientId), fmt.Sprintf("BW_CLIENTSECRET=%s", clientSecret)}).Run(ctx)
if err != nil {
return err
}
return c.Unlock(password)
return c.Unlock(ctx, password)
}

func (c *client) Logout() error {
_, err := c.cmd("logout").Run()
func (c *client) Logout(ctx context.Context) error {
_, err := c.cmd("logout").Run(ctx)
return err
}

func (c *client) DeleteObject(obj Object) error {
func (c *client) DeleteObject(ctx context.Context, obj Object) error {
args := []string{
"delete",
string(obj.Object),
Expand All @@ -247,22 +248,22 @@ func (c *client) DeleteObject(obj Object) error {
args = append(args, "--organizationid", obj.OrganizationID)
}

_, err := c.cmdWithSession(args...).Run()
_, err := c.cmdWithSession(args...).Run(ctx)
return err
}

func (c *client) DeleteAttachment(itemId, attachmentId string) error {
_, err := c.cmdWithSession("delete", string(ObjectTypeAttachment), attachmentId, "--itemid", itemId).Run()
func (c *client) DeleteAttachment(ctx context.Context, itemId, attachmentId string) error {
_, err := c.cmdWithSession("delete", string(ObjectTypeAttachment), attachmentId, "--itemid", itemId).Run(ctx)
return err
}

func (c *client) SetServer(server string) error {
_, err := c.cmd("config", "server", server).Run()
func (c *client) SetServer(ctx context.Context, server string) error {
_, err := c.cmd("config", "server", server).Run(ctx)
return err
}

func (c *client) Status() (*Status, error) {
out, err := c.cmdWithSession("status").Run()
func (c *client) Status(ctx context.Context) (*Status, error) {
out, err := c.cmdWithSession("status").Run(ctx)
if err != nil {
return nil, err
}
Expand All @@ -276,8 +277,8 @@ func (c *client) Status() (*Status, error) {
return &status, nil
}

func (c *client) Unlock(password string) error {
out, err := c.cmd("unlock", "--raw", "--passwordenv", "BW_PASSWORD").AppendEnv([]string{fmt.Sprintf("BW_PASSWORD=%s", password)}).Run()
func (c *client) Unlock(ctx context.Context, password string) error {
out, err := c.cmd("unlock", "--raw", "--passwordenv", "BW_PASSWORD").AppendEnv([]string{fmt.Sprintf("BW_PASSWORD=%s", password)}).Run(ctx)
if err != nil {
return err
}
Expand All @@ -294,11 +295,11 @@ func (c *client) SetSessionKey(sessionKey string) {
c.sessionKey = sessionKey
}

func (c *client) Sync() error {
func (c *client) Sync(ctx context.Context) error {
if c.disableSync {
return nil
}
_, err := c.cmdWithSession("sync").Run()
_, err := c.cmdWithSession("sync").Run(ctx)
return err
}

Expand All @@ -322,13 +323,13 @@ func (c *client) env() []string {
return defaultEnv
}

func (c *client) encode(item Object) (string, error) {
func (c *client) encode(ctx context.Context, item Object) (string, error) {
newOut, err := json.Marshal(item)
if err != nil {
return "", fmt.Errorf("marshalling error: %v, %v", err, string(newOut))
}

out, err := c.cmd("encode").WithStdin(string(newOut)).Run()
out, err := c.cmd("encode").WithStdin(string(newOut)).Run(ctx)
if err != nil {
return "", fmt.Errorf("encoding error: %v, %v", err, string(newOut))
}
Expand Down
9 changes: 5 additions & 4 deletions internal/bitwarden/bw/client_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package bw

import (
"context"
"testing"

test_command "github.com/maxlaverse/terraform-provider-bitwarden/internal/command/test"
Expand All @@ -15,7 +16,7 @@ func TestCreateObjectEncoding(t *testing.T) {
defer removeMocks(t)

b := NewClient("dummy")
_, err := b.CreateObject(Object{
_, err := b.CreateObject(context.Background(), Object{
Type: ItemTypeLogin,
Fields: []Field{
{
Expand All @@ -40,7 +41,7 @@ func TestListObjects(t *testing.T) {
defer removeMocks(t)

b := NewClient("dummy")
_, err := b.ListObjects("item", WithFolderID("folder-id"), WithCollectionID("collection-id"), WithSearch("search"))
_, err := b.ListObjects(context.Background(), "item", WithFolderID("folder-id"), WithCollectionID("collection-id"), WithSearch("search"))

assert.NoError(t, err)
if assert.Len(t, commandsExecuted(), 1) {
Expand All @@ -55,7 +56,7 @@ func TestGetItem(t *testing.T) {
defer removeMocks(t)

b := NewClient("dummy")
_, err := b.GetObject(Object{ID: "object-id", Object: ObjectTypeItem, Type: ItemTypeLogin})
_, err := b.GetObject(context.Background(), Object{ID: "object-id", Object: ObjectTypeItem, Type: ItemTypeLogin})

assert.NoError(t, err)
if assert.Len(t, commandsExecuted(), 1) {
Expand All @@ -70,7 +71,7 @@ func TestGetOrgCollection(t *testing.T) {
defer removeMocks(t)

b := NewClient("dummy")
_, err := b.GetObject(Object{ID: "object-id", Object: ObjectTypeOrgCollection, OrganizationID: "org-id"})
_, err := b.GetObject(context.Background(), Object{ID: "object-id", Object: ObjectTypeOrgCollection, OrganizationID: "org-id"})

assert.NoError(t, err)
if assert.Len(t, commandsExecuted(), 1) {
Expand Down
2 changes: 1 addition & 1 deletion internal/bitwarden/bw/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ var (
)

func newUnmarshallError(err error, cmd string, out []byte) error {
return fmt.Errorf("unable to parse result of '%s' command: %v, output: %v", cmd, err, string(out))
return fmt.Errorf("unable to parse result of '%s', error: '%v', output: '%v'", cmd, err, string(out))
}

func remapError(err error) error {
Expand Down
7 changes: 4 additions & 3 deletions internal/command/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package command

import (
"bytes"
"context"
"io"
"log"
"os/exec"
Expand All @@ -27,7 +28,7 @@ type command struct {
type Command interface {
AppendEnv(envs []string) Command
WithStdin(string) Command
Run() ([]byte, error)
Run(ctx context.Context) ([]byte, error)
}

func (c *command) AppendEnv(envs []string) Command {
Expand All @@ -42,11 +43,11 @@ func (c *command) WithStdin(dir string) Command {
return c
}

func (c *command) Run() ([]byte, error) {
func (c *command) Run(ctx context.Context) ([]byte, error) {
log.Printf("[DEBUG] Running command '%v'\n", c.args)
var stdOut, stdErr bytes.Buffer

cmd := exec.Command(c.binary, c.args...)
cmd := exec.CommandContext(ctx, c.binary, c.args...)
cmd.Env = c.env
cmd.Stdin = c.stdin
cmd.Stdout = &stdOut
Expand Down
5 changes: 3 additions & 2 deletions internal/command/cmd_retryable.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package command

import (
"context"
"log"
"time"
)
Expand Down Expand Up @@ -33,11 +34,11 @@ func (c *retryableCommand) WithStdin(dir string) Command {
return c
}

func (c *retryableCommand) Run() ([]byte, error) {
func (c *retryableCommand) Run(ctx context.Context) ([]byte, error) {
attempts := 0
for {
attempts = attempts + 1
out, err := c.cmd.Run()
out, err := c.cmd.Run(ctx)
if err == nil || !c.retryHandler.IsRetryable(err, attempts) {
return out, err
}
Expand Down
5 changes: 3 additions & 2 deletions internal/command/cmd_retryable_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package command

import (
"context"
"fmt"
"os"
"strings"
Expand All @@ -21,7 +22,7 @@ func TestCommandRerunOnMatchingError(t *testing.T) {
cmd := NewWithRetries(retryHandler)(os.Args[0], "-test.run=TestCommandRerunOnMatchingError")
cmd.AppendEnv([]string{"GO_WANT_HELPER_PROCESS=1"})

_, err := cmd.Run()
_, err := cmd.Run(context.Background())

assert.NotNil(t, err)
assert.Error(t, err)
Expand All @@ -39,7 +40,7 @@ func TestCommandFailsOnUnmatchedError(t *testing.T) {
cmd := NewWithRetries(retryHandler)(os.Args[0], "-test.run=TestCommandFailsOnUnmatchedError")
cmd.AppendEnv([]string{"GO_WANT_HELPER_PROCESS=1"})

_, err := cmd.Run()
_, err := cmd.Run(context.Background())

assert.NotNil(t, err)
assert.Error(t, err)
Expand Down
Loading

0 comments on commit cebf8b2

Please sign in to comment.