Skip to content

Commit

Permalink
Config: Fix config version downgrade
Browse files Browse the repository at this point in the history
  • Loading branch information
gbjk committed Jan 9, 2025
1 parent ef4790f commit 8457cb1
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 91 deletions.
114 changes: 47 additions & 67 deletions cmd/config/config.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package main

import (
"errors"
"context"
"flag"
"fmt"
"os"
Expand All @@ -11,9 +11,10 @@ import (
"github.com/buger/jsonparser"
"github.com/thrasher-corp/gocryptotrader/common/file"
"github.com/thrasher-corp/gocryptotrader/config"
"github.com/thrasher-corp/gocryptotrader/config/versions"
)

var commands = []string{"upgrade", "encrypt", "decrypt"}
var commands = []string{"upgrade", "downgrade", "encrypt", "decrypt"}

func main() {
fmt.Println("GoCryptoTrader: config-helper tool")
Expand All @@ -22,13 +23,15 @@ func main() {

var in, out, keyStr string
var inplace bool
var version int

fs := flag.NewFlagSet("config", flag.ExitOnError)
fs.Usage = func() { usage(fs) }
fs.StringVar(&in, "in", defaultCfgFile, "The config input file to process")
fs.StringVar(&out, "out", "[in].out", "The config output file")
fs.BoolVar(&inplace, "edit", false, "Edit; Save result to the original file")
fs.StringVar(&keyStr, "key", "", "The key to use for AES encryption")
fs.IntVar(&version, "version", 0, "The version to downgrade to")

cmd, args := parseCommand(os.Args[1:])
if cmd == "" {
Expand All @@ -46,83 +49,59 @@ func main() {
out = in + ".out"
}

key := []byte(keyStr)
var err error
switch cmd {
case "upgrade":
err = upgradeFile(in, out, key)
case "decrypt":
err = encryptWrapper(in, out, key, false, decryptFile)
case "encrypt":
err = encryptWrapper(in, out, key, true, encryptFile)
}
key := []byte(keyStr)
data := readFile(in)
isEncrypted := config.IsEncrypted(data)

if err != nil {
fatal(err.Error())
if cmd == "encrypt" && isEncrypted {
fatal("Error: File is already encrypted")
}

fmt.Println("Success! File written to " + out)
}

func upgradeFile(in, out string, key []byte) error {
c := &config.Config{
EncryptionKeyProvider: func(_ bool) ([]byte, error) {
if len(key) != 0 {
return key, nil
}
return config.PromptForConfigKey(false)
},
if len(key) == 0 && (isEncrypted || cmd == "encrypt") {
if key, err = config.PromptForConfigKey(cmd == "encrypt"); err != nil {
fatal(err.Error())
}
}

if err := c.ReadConfigFromFile(in, true); err != nil {
return err
if isEncrypted {
if data, err = config.DecryptConfigData(data, key); err != nil {
fatal(err.Error())
}
}

return c.SaveConfigToFile(out)
}

type encryptFunc func(string, []byte) ([]byte, error)

func encryptWrapper(in, out string, key []byte, confirmKey bool, fn encryptFunc) error {
if len(key) == 0 {
var err error
if key, err = config.PromptForConfigKey(confirmKey); err != nil {
return err
switch cmd {
case "decrypt":
if data, err = jsonparser.Set(data, []byte("-1"), "encryptConfig"); err != nil {
fatal("Unable to decrypt config data; Error: " + err.Error())
}
case "downgrade", "upgrade":
if version == 0 {
if cmd == "downgrade" {
fmt.Fprintln(os.Stderr, "Error: downgrade requires a version")
usage(fs)
os.Exit(3)
}
version = -1
}
if data, err = versions.Manager.Deploy(context.Background(), data, version); err != nil {
fatal("Unable to " + cmd + " config; Error: " + err.Error())
}
if !isEncrypted {
break
}
fallthrough
case "encrypt":
if data, err = config.EncryptConfigData(data, key); err != nil {
fatal("Unable to encrypt config data; Error: " + err.Error())
}
}
outData, err := fn(in, key)
if err != nil {
return err
}
if err := file.Write(out, outData); err != nil {
return fmt.Errorf("unable to write output file %s; Error: %w", out, err)
}
return nil
}

func encryptFile(in string, key []byte) ([]byte, error) {
if config.IsFileEncrypted(in) {
return nil, errors.New("file is already encrypted")
}
outData, err := config.EncryptConfigFile(readFile(in), key)
if err != nil {
return nil, fmt.Errorf("unable to encrypt config data. Error: %w", err)
if err := file.Write(out, data); err != nil {
fatal("Unable to write output file `" + out + "`; Error: " + err.Error())
}
return outData, nil
}

func decryptFile(in string, key []byte) ([]byte, error) {
if !config.IsFileEncrypted(in) {
return nil, errors.New("file is already decrypted")
}
outData, err := config.DecryptConfigFile(readFile(in), key)
if err != nil {
return nil, fmt.Errorf("unable to decrypt config data. Error: %w", err)
}
if outData, err = jsonparser.Set(outData, []byte("-1"), "encryptConfig"); err != nil {
return nil, fmt.Errorf("unable to decrypt config data. Error: %w", err)
}
return outData, nil
fmt.Println("Success! File written to " + out)
}

func readFile(in string) []byte {
Expand Down Expand Up @@ -152,7 +131,7 @@ func parseCommand(a []string) (cmd string, args []string) {
switch len(cmds) {
case 0:
fmt.Fprintln(os.Stderr, "No command provided")
case 1: //
case 1:
return cmds[0], rem
default:
fmt.Fprintln(os.Stderr, "Too many commands provided: "+strings.Join(cmds, ", "))
Expand All @@ -171,6 +150,7 @@ The commands are:
encrypt encrypt infile and write to outfile
decrypt decrypt infile and write to outfile
upgrade upgrade the version of a decrypted config file
downgrade downgrade the version of a decrypted config file to a specific version
The arguments are:`)
fs.PrintDefaults()
Expand Down
2 changes: 1 addition & 1 deletion config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -1512,7 +1512,7 @@ func (c *Config) readConfig(d io.Reader) error {
}
}

if j, err = versions.Manager.Deploy(context.Background(), j); err != nil {
if j, err = versions.Manager.Deploy(context.Background(), j, -1); err != nil {
return err
}

Expand Down
56 changes: 43 additions & 13 deletions config/versions/versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ package versions
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log"
"os"
"slices"
"strconv"
"sync"
Expand Down Expand Up @@ -55,16 +57,23 @@ type manager struct {
var Manager = &manager{}

// Deploy upgrades or downgrades the config between versions
func (m *manager) Deploy(ctx context.Context, j []byte) ([]byte, error) {
// version param -1 defaults to the latest version
// Prints an error an exits if the config file version or version param is not registered
func (m *manager) Deploy(ctx context.Context, j []byte, version int) ([]byte, error) {
if err := m.checkVersions(); err != nil {
return j, err
}

target, err := m.latest()
latest, err := m.latest()
if err != nil {
return j, err
}

target := latest
if version != -1 {
target = version
}

m.m.RLock()
defer m.m.RUnlock()

Expand All @@ -77,47 +86,59 @@ func (m *manager) Deploy(ctx context.Context, j []byte) ([]byte, error) {
return j, fmt.Errorf("%w `version`: %w", common.ErrGettingField, err)
case target == current:
return j, nil
case latest < current:
errVersionNotRegistered(current, latest, "Version in config file")
case target > latest:
errVersionNotRegistered(target, latest, "Target downgrade version")
}

for current != target {
next := current + 1
action := "upgrade"
patchVersion := current + 1
action := "upgrade to"
configMethod := ConfigVersion.UpgradeConfig
exchMethod := ExchangeVersion.UpgradeExchange

if target < current {
next = current - 1
action = "downgrade"
patchVersion = current
action = "downgrade from"
configMethod = ConfigVersion.DowngradeConfig
exchMethod = ExchangeVersion.DowngradeExchange
}

log.Printf("Running %s to config version %v\n", action, next)
log.Printf("Running %s config version %v\n", action, patchVersion)

patch := m.versions[next]
patch := m.versions[patchVersion]

if cPatch, ok := patch.(ConfigVersion); ok {
if j, err = configMethod(cPatch, ctx, j); err != nil {
return j, fmt.Errorf("%w %s to %v: %w", errApplyingVersion, action, next, err)
return j, fmt.Errorf("%w %s %v: %w", errApplyingVersion, action, patchVersion, err)
}
}

if ePatch, ok := patch.(ExchangeVersion); ok {
if j, err = exchangeDeploy(ctx, ePatch, exchMethod, j); err != nil {
return j, fmt.Errorf("%w %s to %v: %w", errApplyingVersion, action, next, err)
return j, fmt.Errorf("%w %s %v: %w", errApplyingVersion, action, patchVersion, err)
}
}

current = next
current = patchVersion
if target < current {
current = patchVersion - 1
}

if j, err = jsonparser.Set(j, []byte(strconv.Itoa(current)), "version"); err != nil {
return j, fmt.Errorf("%w `version` during %s to %v: %w", common.ErrSettingField, action, next, err)
return j, fmt.Errorf("%w `version` during %s %v: %w", common.ErrSettingField, action, patchVersion, err)
}
}

var out bytes.Buffer
if err = json.Indent(&out, j, "", " "); err != nil {
return j, fmt.Errorf("error formatting json: %w", err)
}

log.Println("Version management finished")

return j, nil
return out.Bytes(), nil
}

func exchangeDeploy(ctx context.Context, patch ExchangeVersion, method func(ExchangeVersion, context.Context, []byte) ([]byte, error), j []byte) ([]byte, error) {
Expand Down Expand Up @@ -196,3 +217,12 @@ func (m *manager) checkVersions() error {
}
return nil
}

func errVersionNotRegistered(current, latest int, msg string) {
fmt.Fprintf(os.Stderr, `
%s '%d' is higher than latest available version '%d'
Switch back to the version of GoCryptoTrader containing config version '%d' and run:
$ cmd/config downgrade %d
`, msg, current, latest, current, latest)
os.Exit(1)
}
20 changes: 10 additions & 10 deletions config/versions/versions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,45 +13,45 @@ import (
func TestDeploy(t *testing.T) {
t.Parallel()
m := manager{}
_, err := m.Deploy(context.Background(), []byte(``))
_, err := m.Deploy(context.Background(), []byte(``), -1)
assert.ErrorIs(t, err, errNoVersions)

m.registerVersion(1, &TestVersion1{})
_, err = m.Deploy(context.Background(), []byte(``))
_, err = m.Deploy(context.Background(), []byte(``), -1)
require.ErrorIs(t, err, errVersionIncompatible)

m = manager{}

m.registerVersion(0, &Version0{})
_, err = m.Deploy(context.Background(), []byte(`not an object`))
_, err = m.Deploy(context.Background(), []byte(`not an object`), -1)
require.ErrorIs(t, err, jsonparser.KeyPathNotFoundError, "Must throw the correct error trying to add version to bad json")
require.ErrorIs(t, err, common.ErrSettingField, "Must throw the correct error trying to add version to bad json")
require.ErrorContains(t, err, "version", "Must throw the correct error trying to add version to bad json")

_, err = m.Deploy(context.Background(), []byte(`{"version":"not an int"}`))
_, err = m.Deploy(context.Background(), []byte(`{"version":"not an int"}`), -1)
require.ErrorIs(t, err, common.ErrGettingField, "Must throw the correct error trying to get version from bad json")

in := []byte(`{"version":0,"exchanges":[{"name":"Juan"}]}`)
j, err := m.Deploy(context.Background(), in)
j, err := m.Deploy(context.Background(), in, -1)
require.NoError(t, err)
require.Equal(t, string(in), string(j))

m.registerVersion(1, &Version1{})
j, err = m.Deploy(context.Background(), in)
j, err = m.Deploy(context.Background(), in, -1)
require.NoError(t, err)
require.Contains(t, string(j), `"version":1`)

m.versions = m.versions[:1]
j, err = m.Deploy(context.Background(), j)
j, err = m.Deploy(context.Background(), j, -1)
require.NoError(t, err)
require.Contains(t, string(j), `"version":0`)

m.versions = append(m.versions, &TestVersion2{ConfigErr: true, ExchErr: false}) // Bit hacky, but this will actually work
_, err = m.Deploy(context.Background(), j)
_, err = m.Deploy(context.Background(), j, -1)
require.ErrorIs(t, err, errUpgrade)

m.versions[1] = &TestVersion2{ConfigErr: false, ExchErr: true}
_, err = m.Deploy(context.Background(), in)
_, err = m.Deploy(context.Background(), in, -1)
require.Implements(t, (*ExchangeVersion)(nil), m.versions[1])
require.ErrorIs(t, err, errUpgrade)
}
Expand All @@ -61,7 +61,7 @@ func TestDeploy(t *testing.T) {
func TestExchangeDeploy(t *testing.T) {
t.Parallel()
m := manager{}
_, err := m.Deploy(context.Background(), []byte(``))
_, err := m.Deploy(context.Background(), []byte(``), -1)
assert.ErrorIs(t, err, errNoVersions)

v := &TestVersion2{}
Expand Down

0 comments on commit 8457cb1

Please sign in to comment.