diff --git a/core/services/keystore/eth.go b/core/services/keystore/eth.go index a2b207044c0..cea7b6762e0 100644 --- a/core/services/keystore/eth.go +++ b/core/services/keystore/eth.go @@ -221,13 +221,18 @@ func (ks *eth) Enable(address common.Address, chainID *big.Int, qopts ...pg.QOpt func (ks *eth) enable(address common.Address, chainID *big.Int, qopts ...pg.QOpt) error { state := new(ethkey.State) q := ks.q.WithOpts(qopts...) - sql := `UPDATE evm.key_states SET disabled = false, updated_at = NOW() WHERE address = $1 AND evm_chain_id = $2 - RETURNING *;` + sql := `INSERT INTO evm.key_states as key_states ("address", "evm_chain_id", "disabled", "created_at", "updated_at") VALUES ($1, $2, false, NOW(), NOW()) + ON CONFLICT ("address", "evm_chain_id") DO UPDATE SET "disabled" = false, "updated_at" = NOW() WHERE key_states."address" = $1 AND key_states."evm_chain_id" = $2 + RETURNING *;` if err := q.Get(state, sql, address, chainID.String()); err != nil { return errors.Wrap(err, "failed to enable state") } - ks.keyStates.enable(address, chainID, state.UpdatedAt) + if state.CreatedAt.Equal(state.UpdatedAt) { + ks.keyStates.add(state) + } else { + ks.keyStates.enable(address, chainID, state.UpdatedAt) + } ks.notify() return nil } diff --git a/core/services/keystore/eth_test.go b/core/services/keystore/eth_test.go index 4165350300f..d9c5cf5cb89 100644 --- a/core/services/keystore/eth_test.go +++ b/core/services/keystore/eth_test.go @@ -541,6 +541,51 @@ func Test_EthKeyStore_SubscribeToKeyChanges(t *testing.T) { assertCountAtLeast(1) } +func Test_EthKeyStore_Enable(t *testing.T) { + t.Parallel() + + db := pgtest.NewSqlxDB(t) + cfg := configtest.NewTestGeneralConfig(t) + keyStore := cltest.NewKeyStore(t, db, cfg.Database()) + ks := keyStore.Eth() + + t.Run("already existing disabled key gets enabled", func(t *testing.T) { + k, _ := cltest.MustInsertRandomKeyNoChains(t, ks) + require.NoError(t, ks.Add(k.Address, testutils.SimulatedChainID)) + require.NoError(t, ks.Disable(k.Address, testutils.SimulatedChainID)) + require.NoError(t, ks.Enable(k.Address, testutils.SimulatedChainID)) + key, err := ks.GetState(k.Address.Hex(), testutils.SimulatedChainID) + require.NoError(t, err) + require.Equal(t, key.Disabled, false) + }) + + t.Run("creates key, deletes it unsafely and then enable creates it again", func(t *testing.T) { + k, _ := cltest.MustInsertRandomKeyNoChains(t, ks) + require.NoError(t, ks.Add(k.Address, testutils.SimulatedChainID)) + _, err := db.Exec("DELETE FROM evm.key_states WHERE address = $1", k.Address) + require.NoError(t, err) + require.NoError(t, ks.Enable(k.Address, testutils.SimulatedChainID)) + key, err := ks.GetState(k.Address.Hex(), testutils.SimulatedChainID) + require.NoError(t, err) + require.Equal(t, key.Disabled, false) + }) + + t.Run("creates key and enables it if it exists in the keystore, but is missing from key states db table", func(t *testing.T) { + k, _ := cltest.MustInsertRandomKeyNoChains(t, ks) + require.NoError(t, ks.Enable(k.Address, testutils.SimulatedChainID)) + key, err := ks.GetState(k.Address.Hex(), testutils.SimulatedChainID) + require.NoError(t, err) + require.Equal(t, key.Disabled, false) + }) + + t.Run("errors if key is not present in keystore", func(t *testing.T) { + addrNotInKs := testutils.NewAddress() + require.Error(t, ks.Enable(addrNotInKs, testutils.SimulatedChainID)) + _, err := ks.GetState(addrNotInKs.Hex(), testutils.SimulatedChainID) + require.Error(t, err) + }) +} + func Test_EthKeyStore_EnsureKeys(t *testing.T) { t.Parallel()