Skip to content

Commit

Permalink
fixup! WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
gbjk committed Jan 8, 2025
1 parent 79e7f47 commit 91fc3d0
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 94 deletions.
82 changes: 25 additions & 57 deletions cmd/config/config.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package main

import (
"errors"
"context"
"flag"
"fmt"
"os"
Expand All @@ -23,15 +23,15 @@ func main() {

var in, out, keyStr string
var inplace bool
var version uint
var version int

fs := flag.NewFlagSet("config", flag.ExitOnError)
fs.Usage = func() { usage(fs) }
fs.StringVar(&in, "in", defaultCfgFile, "The config input file to process")
fs.StringVar(&out, "out", "[in].out", "The config output file")
fs.BoolVar(&inplace, "edit", false, "Edit; Save result to the original file")
fs.StringVar(&keyStr, "key", "", "The key to use for AES encryption")
fs.UintVar(&version, "version", 0, "The version to downgrade to")
fs.IntVar(&version, "version", 0, "The version to downgrade to")

cmd, args := parseCommand(os.Args[1:])
if cmd == "" {
Expand Down Expand Up @@ -70,71 +70,39 @@ func main() {
}

switch cmd {
case "downgrade":
case "decrypt":
if data, err = jsonparser.Set(data, []byte("-1"), "encryptConfig"); err != nil {
fatal("Unable to decrypt config data; Error: " + err.Error())
}
case "downgrade", "upgrade":
if version == 0 {
fmt.Fprintln(os.Stderr, "Error: downgrade requires a version")
usage(fs)
os.Exit(3)
if cmd == "downgrade" {
fmt.Fprintln(os.Stderr, "Error: downgrade requires a version")
usage(fs)
os.Exit(3)
}
version = -1
}
versions.Manager.DowngradeTarget = version
//err = deployVersion(in, out, key)
case "upgrade":
if version != 0 {
fmt.Fprintln(os.Stderr, "Error: upgrade does not accept a version")
usage(fs)
os.Exit(3)
if data, err = versions.Manager.Deploy(context.Background(), data, version); err != nil {
fatal("Unable to " + cmd + " config; Error: " + err.Error())
}
//err = deployVersion(in, out, key)
case "decrypt":
err = encryptWrapper(in, out, key, false, decryptFile)
if !isEncrypted {
break
}
fallthrough
case "encrypt":
err = encryptWrapper(in, out, key, true, encryptFile)
if data, err = config.EncryptConfigData(data, key); err != nil {
fatal("Unable to encrypt config data; Error: " + err.Error())
}
}

if err != nil {
fatal(err.Error())
if err := file.Write(out, data); err != nil {
fatal("Unable to write output file `" + out + "`; Error: " + err.Error())
}

fmt.Println("Success! File written to " + out)
}

type encryptFunc func(string, []byte) ([]byte, error)

func encryptWrapper(in, out string, key []byte, confirmKey bool, fn encryptFunc) error {
if len(key) == 0 {
}
outData, err := fn(in, key)
if err != nil {
return err
}
if err := file.Write(out, outData); err != nil {
return fmt.Errorf("unable to write output file %s; Error: %w", out, err)
}
return nil
}

func encryptFile(in string, key []byte) ([]byte, error) {
outData, err := config.EncryptConfigFile(readFile(in), key)
if err != nil {
return nil, fmt.Errorf("unable to encrypt config data. Error: %w", err)
}
return outData, nil
}

func decryptFile(in string, key []byte) ([]byte, error) {
if !config.IsFileEncrypted(in) {
return nil, errors.New("file is already decrypted")
}
outData, err := config.DecryptConfigFile(readFile(in), key)
if err != nil {
return nil, fmt.Errorf("unable to decrypt config data. Error: %w", err)
}
if outData, err = jsonparser.Set(outData, []byte("-1"), "encryptConfig"); err != nil {
return nil, fmt.Errorf("unable to decrypt config data. Error: %w", err)
}
return outData, nil
}

func readFile(in string) []byte {
fileData, err := os.ReadFile(in)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -1502,7 +1502,7 @@ func (c *Config) readConfig(d io.Reader) error {
}
}

if j, err = versions.Manager.Deploy(context.Background(), j); err != nil {
if j, err = versions.Manager.Deploy(context.Background(), j, -1); err != nil {
return err
}

Expand Down Expand Up @@ -1593,7 +1593,7 @@ func (c *Config) Save(writerProvider func() (io.Writer, error)) error {
}
c.sessionDK, c.storedSalt = sessionDK, storedSalt
}
payload, err = c.encryptConfigFile(payload)
payload, err = c.encryptConfigData(payload)
if err != nil {
return err
}
Expand Down
14 changes: 7 additions & 7 deletions config/config_encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ func getSensitiveInput(prompt string) (resp []byte, err error) {
return bytes.TrimRight(resp, "\r\n"), err
}

// EncryptConfigFile encrypts json config data with a key
func EncryptConfigFile(configData, key []byte) ([]byte, error) {
// EncryptConfigData encrypts json config data with a key
func EncryptConfigData(configData, key []byte) ([]byte, error) {
sessionDK, salt, err := makeNewSessionDK(key)
if err != nil {
return nil, err
Expand All @@ -105,12 +105,12 @@ func EncryptConfigFile(configData, key []byte) ([]byte, error) {
sessionDK: sessionDK,
storedSalt: salt,
}
return c.encryptConfigFile(configData)
return c.encryptConfigData(configData)
}

// encryptConfigFile encrypts json config data with a key
// encryptConfigData encrypts json config data with a key
// The EncryptConfig field is set to config enabled (1)
func (c *Config) encryptConfigFile(configData []byte) ([]byte, error) {
func (c *Config) encryptConfigData(configData []byte) ([]byte, error) {
configData, err := jsonparser.Set(configData, []byte("1"), "encryptConfig")
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrSettingEncryptConfig, err)
Expand All @@ -135,8 +135,8 @@ func (c *Config) encryptConfigFile(configData []byte) ([]byte, error) {
return appendedFile, nil
}

// DecryptConfigFile decrypts config data with a key
func DecryptConfigFile(d, key []byte) ([]byte, error) {
// DecryptConfigData decrypts config data with a key
func DecryptConfigData(d, key []byte) ([]byte, error) {
return (&Config{}).DecryptConfigData(d, key)
}

Expand Down
18 changes: 9 additions & 9 deletions config/config_encryption_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,16 @@ func TestPromptForConfigKey(t *testing.T) {

func TestEncryptConfigFile(t *testing.T) {
t.Parallel()
_, err := EncryptConfigFile([]byte("test"), nil)
_, err := EncryptConfigData([]byte("test"), nil)
require.ErrorIs(t, err, errKeyIsEmpty)

c := &Config{
sessionDK: []byte("a"),
}
_, err = c.encryptConfigFile([]byte(`test`))
_, err = c.encryptConfigData([]byte(`test`))
require.ErrorIs(t, err, ErrSettingEncryptConfig)

_, err = c.encryptConfigFile([]byte(`{"test":1}`))
_, err = c.encryptConfigData([]byte(`{"test":1}`))
require.Error(t, err)
require.IsType(t, aes.KeySizeError(1), err)

Expand All @@ -79,26 +79,26 @@ func TestEncryptConfigFile(t *testing.T) {
sessionDK: sessDk,
storedSalt: salt,
}
_, err = c.encryptConfigFile([]byte(`{"test":1}`))
_, err = c.encryptConfigData([]byte(`{"test":1}`))
require.NoError(t, err)
}

func TestDecryptConfigFile(t *testing.T) {
t.Parallel()
e, err := EncryptConfigFile([]byte(`{"test":1}`), []byte("key"))
e, err := EncryptConfigData([]byte(`{"test":1}`), []byte("key"))
require.NoError(t, err)

d, err := DecryptConfigFile(e, []byte("key"))
d, err := DecryptConfigData(e, []byte("key"))
require.NoError(t, err)
assert.Equal(t, `{"test":1,"encryptConfig":1}`, string(d), "encryptConfig should be set to 1 after first encryption")

_, err = DecryptConfigFile(e, nil)
_, err = DecryptConfigData(e, nil)
require.ErrorIs(t, err, errKeyIsEmpty)

_, err = DecryptConfigFile([]byte("test"), nil)
_, err = DecryptConfigData([]byte("test"), nil)
require.ErrorIs(t, err, errNoPrefix)

_, err = DecryptConfigFile(encryptionPrefix, []byte("AAAAAAAAAAAAAAAA"))
_, err = DecryptConfigData(encryptionPrefix, []byte("AAAAAAAAAAAAAAAA"))
require.ErrorIs(t, err, errAESBlockSize)
}

Expand Down
15 changes: 7 additions & 8 deletions config/versions/versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,17 @@ type ExchangeVersion interface {

// manager contains versions registerVersioned during import init
type manager struct {
m sync.RWMutex
versions []any
DowngradeTarget uint
m sync.RWMutex
versions []any
}

// Manager is a public instance of the config version manager
var Manager = &manager{}

// Deploy upgrades or downgrades the config between versions
// Prints a warning and exits if the config version is ahead of the latest version unless DowngradeTarget is set
// Downgrades may only be invoked through cmd/config downgrade -version
func (m *manager) Deploy(ctx context.Context, j []byte) ([]byte, error) {
// version param -1 defaults to the latest version
// Prints an error an exits if the config file version or version param is not registered
func (m *manager) Deploy(ctx context.Context, j []byte, version int) ([]byte, error) {
if err := m.checkVersions(); err != nil {
return j, err
}
Expand All @@ -71,8 +70,8 @@ func (m *manager) Deploy(ctx context.Context, j []byte) ([]byte, error) {

target := latest

if m.DowngradeTarget != 0 {
target = int(m.DowngradeTarget)
if version != 0 {
target = int(version)

Check failure on line 74 in config/versions/versions.go

View workflow job for this annotation

GitHub Actions / lint

unnecessary conversion (unconvert)
}

m.m.RLock()
Expand Down
22 changes: 11 additions & 11 deletions config/versions/versions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,48 +10,48 @@ import (
"github.com/thrasher-corp/gocryptotrader/common"
)

func TestUpgrade(t *testing.T) {
func TestDeploy(t *testing.T) {
t.Parallel()
m := manager{}
_, err := m.Upgrade(context.Background(), []byte(``))
_, err := m.Deploy(context.Background(), []byte(``), -1)
assert.ErrorIs(t, err, errNoVersions)

m.registerVersion(1, &TestVersion1{})
_, err = m.Upgrade(context.Background(), []byte(``))
_, err = m.Deploy(context.Background(), []byte(``), -1)
require.ErrorIs(t, err, errVersionIncompatible)

m = manager{}

m.registerVersion(0, &Version0{})
_, err = m.Upgrade(context.Background(), []byte(`not an object`))
_, err = m.Deploy(context.Background(), []byte(`not an object`), -1)
require.ErrorIs(t, err, jsonparser.KeyPathNotFoundError, "Must throw the correct error trying to add version to bad json")
require.ErrorIs(t, err, common.ErrSettingField, "Must throw the correct error trying to add version to bad json")
require.ErrorContains(t, err, "version", "Must throw the correct error trying to add version to bad json")

_, err = m.Upgrade(context.Background(), []byte(`{"version":"not an int"}`))
_, err = m.Deploy(context.Background(), []byte(`{"version":"not an int"}`), -1)
require.ErrorIs(t, err, common.ErrGettingField, "Must throw the correct error trying to get version from bad json")

in := []byte(`{"version":0,"exchanges":[{"name":"Juan"}]}`)
j, err := m.Upgrade(context.Background(), in)
j, err := m.Deploy(context.Background(), in, -1)
require.NoError(t, err)
require.Equal(t, string(in), string(j))

m.registerVersion(1, &Version1{})
j, err = m.Upgrade(context.Background(), in)
j, err = m.Deploy(context.Background(), in, -1)
require.NoError(t, err)
require.Contains(t, string(j), `"version":1`)

m.versions = m.versions[:1]
j, err = m.Upgrade(context.Background(), j)
j, err = m.Deploy(context.Background(), j, -1)
require.NoError(t, err)
require.Contains(t, string(j), `"version":0`)

m.versions = append(m.versions, &TestVersion2{ConfigErr: true, ExchErr: false}) // Bit hacky, but this will actually work
_, err = m.Upgrade(context.Background(), j)
_, err = m.Deploy(context.Background(), j, -1)
require.ErrorIs(t, err, errUpgrade)

m.versions[1] = &TestVersion2{ConfigErr: false, ExchErr: true}
_, err = m.Upgrade(context.Background(), in)
_, err = m.Deploy(context.Background(), in, -1)
require.Implements(t, (*ExchangeVersion)(nil), m.versions[1])
require.ErrorIs(t, err, errUpgrade)
}
Expand All @@ -61,7 +61,7 @@ func TestUpgrade(t *testing.T) {
func TestExchangeDeploy(t *testing.T) {
t.Parallel()
m := manager{}
_, err := m.Upgrade(context.Background(), []byte(``))
_, err := m.Deploy(context.Background(), []byte(``), -1)
assert.ErrorIs(t, err, errNoVersions)

v := &TestVersion2{}
Expand Down

0 comments on commit 91fc3d0

Please sign in to comment.