Skip to content

Commit

Permalink
fix cmd/console load specific private key (#249)
Browse files Browse the repository at this point in the history
  • Loading branch information
bxq2011hust authored Jan 9, 2024
1 parent 670b1fe commit bc614c6
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 68 deletions.
32 changes: 19 additions & 13 deletions v3/client/go_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,10 @@ func (c *Client) SetPrivateKey(privateKey []byte) error {
return c.conn.GetCSDK().SetPrivateKey(privateKey)
}

func (c *Client) PrivateKeyBytes() []byte {
return c.conn.GetCSDK().PrivateKeyBytes()
}

// TransactionReceipt returns the receipt of a transaction by transaction hash.
// Note that the receipt is not available for pending transactions.
func (c *Client) TransactionReceipt(ctx context.Context, txHash common.Hash) (*types.Receipt, error) {
Expand Down Expand Up @@ -428,26 +432,29 @@ func (c *Client) GetPBFTView(ctx context.Context) ([]byte, error) {
// Raft consensus
}

type ConsensusNodeInfo struct {
ID string `json:"nodeID"`
Weight uint `json:"weight"`
}

// GetSealerList returns the list of consensus nodes' ID according to the groupID
func (c *Client) GetSealerList(ctx context.Context) ([]byte, error) {
var raw interface{}
func (c *Client) GetSealerList(ctx context.Context) ([]ConsensusNodeInfo, error) {
var raw []ConsensusNodeInfo
err := c.conn.CallContext(ctx, &raw, "getSealerList")
if err != nil {
return nil, err
}
js, err := json.MarshalIndent(raw, "", indent)
return js, err
return raw, err
}

// GetObserverList returns the list of observer nodes' ID according to the groupID
func (c *Client) GetObserverList(ctx context.Context) ([]byte, error) {
var raw interface{}
func (c *Client) GetObserverList(ctx context.Context) ([]string, error) {
var raw []string
err := c.conn.CallContext(ctx, &raw, "getObserverList")
if err != nil {
return nil, err
}
js, err := json.MarshalIndent(raw, "", indent)
return js, err
return raw, err
}

// GetConsensusStatus returns the status information about the consensus algorithm on a specific groupID
Expand Down Expand Up @@ -625,14 +632,13 @@ func (c *Client) GetContractAddress(ctx context.Context, txHash common.Hash) (co
}

// GetPendingTxSize returns amount of the pending transactions
func (c *Client) GetPendingTxSize(ctx context.Context) ([]byte, error) {
var raw interface{}
func (c *Client) GetPendingTxSize(ctx context.Context) (int64, error) {
var raw int64
err := c.conn.CallContext(ctx, &raw, "getPendingTxSize")
if err != nil {
return nil, err
return 0, err
}
js, err := json.MarshalIndent(raw, "", indent)
return js, err
return raw, err
}

// GetCode returns the contract code according to the contract address
Expand Down
6 changes: 3 additions & 3 deletions v3/client/go_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func TestSealerList(t *testing.T) {
t.Fatalf("sealer list not found: %v", err)
}

t.Logf("sealer list:\n%s", sl)
t.Logf("sealer list:\n%+v", sl)
}

func TestObserverList(t *testing.T) {
Expand Down Expand Up @@ -264,7 +264,7 @@ func TestPendingTxSize(t *testing.T) {
t.Fatalf("pending transactions not found: %v", err)
}

t.Logf("the amount of the pending transactions:\n%s", raw)
t.Logf("the amount of the pending transactions:\n%d", raw)
}

func deployHelloWorld(t *testing.T) (*common.Address, *common.Hash) {
Expand Down Expand Up @@ -463,7 +463,7 @@ func TestAsnycHelloWorldSet(t *testing.T) {
t.Fatalf("parsed.Pack error: %v", err)
}
var wg sync.WaitGroup
count := 100
count := 50
for i := 0; i < count; i++ {
tx := types.NewSimpleTx(&address, input, HelloWorldABI, "", "", c.SMCrypto())
err = c.AsyncSendTransaction(context.Background(), tx, func(receipt *types.Receipt, err error) {
Expand Down
20 changes: 16 additions & 4 deletions v3/cmd/commandline/auth_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
"strconv"

"github.com/FISCO-BCOS/go-sdk/v3/precompiled/auth"
"github.com/FISCO-BCOS/go-sdk/v3/smcrypto"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -183,7 +185,7 @@ Arguments:
For example:
[resetAdmin] [0x112fb844934c794a9e425dd6b4e57eff1b519f17][0x112fb844934c794a9e425dd6b4e57eff1b519f17]
resetAdmin 0x112fb844934c794a9e425dd6b4e57eff1b519f17 0x112fb844934c794a9e425dd6b4e57eff1b519f17
For more information please refer:
Expand All @@ -200,15 +202,25 @@ For more information please refer:
fmt.Printf("the format of contractAddr %v is invalid\n", contractAddr)
return
}

var currentAddress string
if RPC.SMCrypto() {
currentAddress = smcrypto.SM2KeyToAddress(RPC.PrivateKeyBytes()).Hex()
} else {
private, err := crypto.ToECDSA(RPC.PrivateKeyBytes())
if err != nil {
fmt.Printf("resetAdmin get current private failed, err:%v\n", err)
return
}
currentAddress = crypto.PubkeyToAddress(private.PublicKey).Hex()
}
authManagerService, err := auth.NewAuthManagerService(RPC)
if err != nil {
fmt.Printf("resetAdmin failed, err:%v\n", err)
fmt.Printf("resetAdmin failed, currentAccount: %s, err:%v\n", currentAddress, err)
return
}
result, err := authManagerService.ResetAdmin(common.HexToAddress(newAdmin), common.HexToAddress(contractAddr))
if err != nil {
fmt.Printf("resetAdmin failed, err: %v\n", err)
fmt.Printf("resetAdmin failed, currentAccount: %s, err: %v\n", currentAddress, err)
return
}
fmt.Println(result)
Expand Down
6 changes: 3 additions & 3 deletions v3/cmd/commandline/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ var getSealerListCmd = &cobra.Command{
fmt.Printf("sealer list not found: %v\n", err)
return
}
fmt.Printf("Sealer List: \n%s\n", sealerList)
fmt.Printf("Sealer List: \n%+v\n", sealerList)
},
}

Expand Down Expand Up @@ -500,7 +500,7 @@ var getPendingTxSizeCmd = &cobra.Command{
fmt.Printf("transactions not found: %v\n", err)
return
}
fmt.Printf("Pending Transactions Count: \n hex: %s\n", tx)
fmt.Printf("Pending Transactions Count: %d\n", tx)
},
}

Expand Down Expand Up @@ -708,7 +708,7 @@ func init() {
// will be global for your application.

// rootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", "config file (default is the project directory ./config.ini)")
rootCmd.PersistentFlags().StringVarP(&cfgFile, "privateKeyPath", "p", "", "private key file path of pem format")
rootCmd.PersistentFlags().StringVarP(&privateKeyFilePath, "privateKeyPath", "p", "", "private key file path of pem format")
rootCmd.PersistentFlags().BoolVarP(&smCrypto, "smCrypto", "s", false, "use smCrypto or not, default is false")
rootCmd.PersistentFlags().BoolVarP(&disableSsl, "disableSsl", "d", false, "switch off ssl or not, default use ssl")
rootCmd.PersistentFlags().StringVarP(&groupID, "groupID", "g", "group0", "groupID of FISCO BCOS chain")
Expand Down
45 changes: 31 additions & 14 deletions v3/cmd/commandline/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"strings"

"github.com/FISCO-BCOS/go-sdk/v3/client"
"github.com/FISCO-BCOS/go-sdk/v3/smcrypto"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -75,27 +76,43 @@ func initConfig() {
_, err := os.Stat(privateKeyFilePath)
if err != nil && os.IsNotExist(err) {
fmt.Println("private key file set but not exist, use default private key")
}
key, curve, err := client.LoadECPrivateKeyFromPEM(privateKeyFilePath)
if err != nil {
fmt.Printf("parse private key failed, err: %v\n", err)
return
}
if smCrypto && curve != client.Sm2p256v1 {
fmt.Printf("smCrypto should use sm2p256v1 private key, but found %s\n", curve)
} else if err != nil {
fmt.Printf("check private key file failed, err: %v\n", err)
return
} else {
key, curve, err := client.LoadECPrivateKeyFromPEM(privateKeyFilePath)
if err != nil {
fmt.Printf("parse private key failed, err: %v\n", err)
return
}
if smCrypto && curve != client.Sm2p256v1 {
fmt.Printf("smCrypto should use sm2p256v1 private key, but found %s\n", curve)
return
}
if !smCrypto && curve != client.Secp256k1 {
fmt.Printf("should use secp256k1 private key, but found %s\n", curve)
return
}
privateKey = key
}
if !smCrypto && curve != client.Secp256k1 {
fmt.Printf("should use secp256k1 private key, but found %s\n", curve)
}
privateKey = key
} else {
address := "0xFbb18d54e9Ee57529cda8c7c52242EFE879f064F"
privateKey, _ = hex.DecodeString("145e247e170ba3afd6ae97e88f00dbc976c2345d511b0f6713355d19d8b80b58")
if smCrypto {
address = smcrypto.SM2KeyToAddress(privateKey).Hex()
}
fmt.Println("use default private key, address: ", address)
}
ret := strings.Split(nodeEndpoint, ":")
host := ret[0]
port, _ := strconv.Atoi(ret[1])
config := &client.Config{IsSMCrypto: smCrypto, GroupID: groupID, DisableSsl: disableSsl,
PrivateKey: privateKey, Host: host, Port: port, TLSCaFile: certPath + "/ca.crt", TLSKeyFile: certPath + "/sdk.key", TLSCertFile: certPath + "/sdk.crt"}
var config *client.Config
if !smCrypto {
config = &client.Config{IsSMCrypto: smCrypto, GroupID: groupID, DisableSsl: disableSsl,
PrivateKey: privateKey, Host: host, Port: port, TLSCaFile: certPath + "/ca.crt", TLSKeyFile: certPath + "/sdk.key", TLSCertFile: certPath + "/sdk.crt"}
} else {
config = &client.Config{IsSMCrypto: smCrypto, GroupID: groupID, DisableSsl: disableSsl,
PrivateKey: privateKey, Host: host, Port: port, TLSCaFile: certPath + "/sm_ca.crt", TLSKeyFile: certPath + "/sm_sdk.key", TLSCertFile: certPath + "/sm_sdk.crt", TLSSmEnKeyFile: certPath + "/sm_ensdk.key", TLSSmEnCertFile: certPath + "/sm_ensdk.crt"}
}
RPC = getClient(config)
}
2 changes: 1 addition & 1 deletion v3/cmd/commandline/system_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ For example:
setSystemConfigByKey tx_count_limit 10000`, configSet.String()),
Args: cobra.ExactArgs(2),
Run: func(cmd *cobra.Command, args []string) {
if configSet.Contains(args[0]) {
if !configSet.Contains(args[0]) {
fmt.Printf("The key not found: %s, currently only support %v", args[0], configSet)
return
}
Expand Down
4 changes: 2 additions & 2 deletions v3/examples/parallelok/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func main() {
fmt.Println("start perf groupID:", groupID, "userCount:", userCount, "total:", total, "qps:", qps)

privateKey, _ := hex.DecodeString("145e247e170ba3afd6ae97e88f00dbc976c2345d511b0f6713355d19d8b80b58")
config := &client.Config{IsSMCrypto: false, GroupID: groupID, DisableSsl: true,
config := &client.Config{IsSMCrypto: false, GroupID: groupID, DisableSsl: false,
PrivateKey: privateKey, Host: "127.0.0.1", Port: 20200, TLSCaFile: "./conf/ca.crt", TLSKeyFile: "./conf/sdk.key", TLSCertFile: "./conf/sdk.crt"}
client, err := client.DialContext(context.Background(), config)
// client, err := client.Dial("./config.ini", groupID, privateKey)
Expand Down Expand Up @@ -185,7 +185,7 @@ func main() {
wg2.Wait()

// check balance
fmt.Println("check balance")
fmt.Println("check balance...")
var wg3 sync.WaitGroup
for i := 0; i < userCount; i++ {
wg3.Add(1)
Expand Down
41 changes: 14 additions & 27 deletions v3/precompiled/consensus/consensus_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,6 @@ const (
invalidNodeID int64 = -51100
)

type nodeIdList struct {
nodeID string `json:"nodeID"`
weight uint `json:"weight"`
}

// getErrorMessage returns the message of error code
func getErrorMessage(errorCode int64) string {
var message string
Expand Down Expand Up @@ -82,17 +77,11 @@ func (service *Service) AddObserver(nodeID string) (int64, error) {
// return precompiled.DefaultErrorCode, fmt.Errorf("the node is not reachable")
//}

observerRaw, err := service.client.GetObserverList(context.Background())
nodeIDs, err := service.client.GetObserverList(context.Background())
if err != nil {
return precompiled.DefaultErrorCode, fmt.Errorf("get the observer list failed: %v", err)
}

var nodeIDs []string
err = json.Unmarshal(observerRaw, &nodeIDs)
if err != nil {
return precompiled.DefaultErrorCode, fmt.Errorf("unmarshal the observer list failed: %v", err)
}

for _, nID := range nodeIDs {
if nID == nodeID {
return precompiled.DefaultErrorCode, fmt.Errorf("the node is already in the observer list")
Expand All @@ -115,19 +104,13 @@ func (service *Service) AddSealer(nodeID string, weight int64) (int64, error) {
// return precompiled.DefaultErrorCode, fmt.Errorf("the node is not reachable")
//}

sealerRaw, err := service.client.GetSealerList(context.Background())
nodes, err := service.client.GetSealerList(context.Background())
if err != nil {
return precompiled.DefaultErrorCode, fmt.Errorf("get the sealer list failed: %v", err)
}

var nodeIDs []nodeIdList
err = json.Unmarshal(sealerRaw, &nodeIDs)
if err != nil {
return precompiled.DefaultErrorCode, fmt.Errorf("unmarshal the sealer list failed: %v", err)
}

for _, nID := range nodeIDs {
if nID.nodeID == nodeID {
for _, node := range nodes {
if node.ID == nodeID {
return precompiled.DefaultErrorCode, fmt.Errorf("the node is already in the sealer list")
}
}
Expand Down Expand Up @@ -194,17 +177,21 @@ func (service *Service) isValidNodeID(nodeID string) (bool, error) {
}

func (service *Service) SetWeight(nodeID string, weight int64) (int64, error) {
sealerRaw, err := service.client.GetSealerList(context.Background())
nodes, err := service.client.GetSealerList(context.Background())
if err != nil {
return precompiled.DefaultErrorCode, fmt.Errorf("get the sealer list failed: %v", err)
}

var nodeIDs []nodeIdList
err = json.Unmarshal(sealerRaw, &nodeIDs)
if err != nil {
return precompiled.DefaultErrorCode, fmt.Errorf("unmarshal the sealer list failed: %v", err)
find := false
for _, node := range nodes {
if node.ID == nodeID {
find = true
break
}
}
if !find {
return precompiled.DefaultErrorCode, fmt.Errorf("the node is not in the sealer list")
}

_, _, receipt, err := service.consensus.SetWeight(service.consensusAuth, nodeID, big.NewInt(weight))
if err != nil {
return precompiled.DefaultErrorCode, fmt.Errorf("ConsensusService setWeight failed: %+v", err)
Expand Down
2 changes: 1 addition & 1 deletion v3/types/receipt.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func (r *Receipt) GetErrorMessage() string {
return fmt.Sprintf("receipt error code: %v, receipt error message: %v", r.Status, errorMessage)
}

// String returns the string representation of Receipt sturct.
// String returns the string representation of Receipt struct.
func (r *Receipt) String() string {
out, err := json.MarshalIndent(r, "", "\t")
if err != nil {
Expand Down

0 comments on commit bc614c6

Please sign in to comment.