From bc614c68be70812b87b64f0150c689e656a09d5d Mon Sep 17 00:00:00 2001 From: XingQiang Bai Date: Tue, 9 Jan 2024 11:01:18 +0800 Subject: [PATCH] fix cmd/console load specific private key (#249) --- v3/client/go_client.go | 32 +++++++------ v3/client/go_client_test.go | 6 +-- v3/cmd/commandline/auth_manager.go | 20 +++++++-- v3/cmd/commandline/commands.go | 6 +-- v3/cmd/commandline/root.go | 45 +++++++++++++------ v3/cmd/commandline/system_config.go | 2 +- v3/examples/parallelok/main.go | 4 +- v3/precompiled/consensus/consensus_service.go | 41 ++++++----------- v3/types/receipt.go | 2 +- 9 files changed, 90 insertions(+), 68 deletions(-) diff --git a/v3/client/go_client.go b/v3/client/go_client.go index 0648f79f..69937ed5 100644 --- a/v3/client/go_client.go +++ b/v3/client/go_client.go @@ -327,6 +327,10 @@ func (c *Client) SetPrivateKey(privateKey []byte) error { return c.conn.GetCSDK().SetPrivateKey(privateKey) } +func (c *Client) PrivateKeyBytes() []byte { + return c.conn.GetCSDK().PrivateKeyBytes() +} + // TransactionReceipt returns the receipt of a transaction by transaction hash. // Note that the receipt is not available for pending transactions. func (c *Client) TransactionReceipt(ctx context.Context, txHash common.Hash) (*types.Receipt, error) { @@ -428,26 +432,29 @@ func (c *Client) GetPBFTView(ctx context.Context) ([]byte, error) { // Raft consensus } +type ConsensusNodeInfo struct { + ID string `json:"nodeID"` + Weight uint `json:"weight"` +} + // GetSealerList returns the list of consensus nodes' ID according to the groupID -func (c *Client) GetSealerList(ctx context.Context) ([]byte, error) { - var raw interface{} +func (c *Client) GetSealerList(ctx context.Context) ([]ConsensusNodeInfo, error) { + var raw []ConsensusNodeInfo err := c.conn.CallContext(ctx, &raw, "getSealerList") if err != nil { return nil, err } - js, err := json.MarshalIndent(raw, "", indent) - return js, err + return raw, err } // GetObserverList returns the list of observer nodes' ID according to the groupID -func (c *Client) GetObserverList(ctx context.Context) ([]byte, error) { - var raw interface{} +func (c *Client) GetObserverList(ctx context.Context) ([]string, error) { + var raw []string err := c.conn.CallContext(ctx, &raw, "getObserverList") if err != nil { return nil, err } - js, err := json.MarshalIndent(raw, "", indent) - return js, err + return raw, err } // GetConsensusStatus returns the status information about the consensus algorithm on a specific groupID @@ -625,14 +632,13 @@ func (c *Client) GetContractAddress(ctx context.Context, txHash common.Hash) (co } // GetPendingTxSize returns amount of the pending transactions -func (c *Client) GetPendingTxSize(ctx context.Context) ([]byte, error) { - var raw interface{} +func (c *Client) GetPendingTxSize(ctx context.Context) (int64, error) { + var raw int64 err := c.conn.CallContext(ctx, &raw, "getPendingTxSize") if err != nil { - return nil, err + return 0, err } - js, err := json.MarshalIndent(raw, "", indent) - return js, err + return raw, err } // GetCode returns the contract code according to the contract address diff --git a/v3/client/go_client_test.go b/v3/client/go_client_test.go index 8794107f..436fe932 100644 --- a/v3/client/go_client_test.go +++ b/v3/client/go_client_test.go @@ -161,7 +161,7 @@ func TestSealerList(t *testing.T) { t.Fatalf("sealer list not found: %v", err) } - t.Logf("sealer list:\n%s", sl) + t.Logf("sealer list:\n%+v", sl) } func TestObserverList(t *testing.T) { @@ -264,7 +264,7 @@ func TestPendingTxSize(t *testing.T) { t.Fatalf("pending transactions not found: %v", err) } - t.Logf("the amount of the pending transactions:\n%s", raw) + t.Logf("the amount of the pending transactions:\n%d", raw) } func deployHelloWorld(t *testing.T) (*common.Address, *common.Hash) { @@ -463,7 +463,7 @@ func TestAsnycHelloWorldSet(t *testing.T) { t.Fatalf("parsed.Pack error: %v", err) } var wg sync.WaitGroup - count := 100 + count := 50 for i := 0; i < count; i++ { tx := types.NewSimpleTx(&address, input, HelloWorldABI, "", "", c.SMCrypto()) err = c.AsyncSendTransaction(context.Background(), tx, func(receipt *types.Receipt, err error) { diff --git a/v3/cmd/commandline/auth_manager.go b/v3/cmd/commandline/auth_manager.go index e0fcc52a..df23d3a5 100644 --- a/v3/cmd/commandline/auth_manager.go +++ b/v3/cmd/commandline/auth_manager.go @@ -8,7 +8,9 @@ import ( "strconv" "github.com/FISCO-BCOS/go-sdk/v3/precompiled/auth" + "github.com/FISCO-BCOS/go-sdk/v3/smcrypto" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" "github.com/spf13/cobra" ) @@ -183,7 +185,7 @@ Arguments: For example: - [resetAdmin] [0x112fb844934c794a9e425dd6b4e57eff1b519f17][0x112fb844934c794a9e425dd6b4e57eff1b519f17] + resetAdmin 0x112fb844934c794a9e425dd6b4e57eff1b519f17 0x112fb844934c794a9e425dd6b4e57eff1b519f17 For more information please refer: @@ -200,15 +202,25 @@ For more information please refer: fmt.Printf("the format of contractAddr %v is invalid\n", contractAddr) return } - + var currentAddress string + if RPC.SMCrypto() { + currentAddress = smcrypto.SM2KeyToAddress(RPC.PrivateKeyBytes()).Hex() + } else { + private, err := crypto.ToECDSA(RPC.PrivateKeyBytes()) + if err != nil { + fmt.Printf("resetAdmin get current private failed, err:%v\n", err) + return + } + currentAddress = crypto.PubkeyToAddress(private.PublicKey).Hex() + } authManagerService, err := auth.NewAuthManagerService(RPC) if err != nil { - fmt.Printf("resetAdmin failed, err:%v\n", err) + fmt.Printf("resetAdmin failed, currentAccount: %s, err:%v\n", currentAddress, err) return } result, err := authManagerService.ResetAdmin(common.HexToAddress(newAdmin), common.HexToAddress(contractAddr)) if err != nil { - fmt.Printf("resetAdmin failed, err: %v\n", err) + fmt.Printf("resetAdmin failed, currentAccount: %s, err: %v\n", currentAddress, err) return } fmt.Println(result) diff --git a/v3/cmd/commandline/commands.go b/v3/cmd/commandline/commands.go index 1c2b5518..086b2d49 100644 --- a/v3/cmd/commandline/commands.go +++ b/v3/cmd/commandline/commands.go @@ -116,7 +116,7 @@ var getSealerListCmd = &cobra.Command{ fmt.Printf("sealer list not found: %v\n", err) return } - fmt.Printf("Sealer List: \n%s\n", sealerList) + fmt.Printf("Sealer List: \n%+v\n", sealerList) }, } @@ -500,7 +500,7 @@ var getPendingTxSizeCmd = &cobra.Command{ fmt.Printf("transactions not found: %v\n", err) return } - fmt.Printf("Pending Transactions Count: \n hex: %s\n", tx) + fmt.Printf("Pending Transactions Count: %d\n", tx) }, } @@ -708,7 +708,7 @@ func init() { // will be global for your application. // rootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", "config file (default is the project directory ./config.ini)") - rootCmd.PersistentFlags().StringVarP(&cfgFile, "privateKeyPath", "p", "", "private key file path of pem format") + rootCmd.PersistentFlags().StringVarP(&privateKeyFilePath, "privateKeyPath", "p", "", "private key file path of pem format") rootCmd.PersistentFlags().BoolVarP(&smCrypto, "smCrypto", "s", false, "use smCrypto or not, default is false") rootCmd.PersistentFlags().BoolVarP(&disableSsl, "disableSsl", "d", false, "switch off ssl or not, default use ssl") rootCmd.PersistentFlags().StringVarP(&groupID, "groupID", "g", "group0", "groupID of FISCO BCOS chain") diff --git a/v3/cmd/commandline/root.go b/v3/cmd/commandline/root.go index 82d15b35..6b628d07 100644 --- a/v3/cmd/commandline/root.go +++ b/v3/cmd/commandline/root.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/FISCO-BCOS/go-sdk/v3/client" + "github.com/FISCO-BCOS/go-sdk/v3/smcrypto" "github.com/spf13/cobra" ) @@ -75,27 +76,43 @@ func initConfig() { _, err := os.Stat(privateKeyFilePath) if err != nil && os.IsNotExist(err) { fmt.Println("private key file set but not exist, use default private key") - } - key, curve, err := client.LoadECPrivateKeyFromPEM(privateKeyFilePath) - if err != nil { - fmt.Printf("parse private key failed, err: %v\n", err) - return - } - if smCrypto && curve != client.Sm2p256v1 { - fmt.Printf("smCrypto should use sm2p256v1 private key, but found %s\n", curve) + } else if err != nil { + fmt.Printf("check private key file failed, err: %v\n", err) return + } else { + key, curve, err := client.LoadECPrivateKeyFromPEM(privateKeyFilePath) + if err != nil { + fmt.Printf("parse private key failed, err: %v\n", err) + return + } + if smCrypto && curve != client.Sm2p256v1 { + fmt.Printf("smCrypto should use sm2p256v1 private key, but found %s\n", curve) + return + } + if !smCrypto && curve != client.Secp256k1 { + fmt.Printf("should use secp256k1 private key, but found %s\n", curve) + return + } + privateKey = key } - if !smCrypto && curve != client.Secp256k1 { - fmt.Printf("should use secp256k1 private key, but found %s\n", curve) - } - privateKey = key } else { + address := "0xFbb18d54e9Ee57529cda8c7c52242EFE879f064F" privateKey, _ = hex.DecodeString("145e247e170ba3afd6ae97e88f00dbc976c2345d511b0f6713355d19d8b80b58") + if smCrypto { + address = smcrypto.SM2KeyToAddress(privateKey).Hex() + } + fmt.Println("use default private key, address: ", address) } ret := strings.Split(nodeEndpoint, ":") host := ret[0] port, _ := strconv.Atoi(ret[1]) - config := &client.Config{IsSMCrypto: smCrypto, GroupID: groupID, DisableSsl: disableSsl, - PrivateKey: privateKey, Host: host, Port: port, TLSCaFile: certPath + "/ca.crt", TLSKeyFile: certPath + "/sdk.key", TLSCertFile: certPath + "/sdk.crt"} + var config *client.Config + if !smCrypto { + config = &client.Config{IsSMCrypto: smCrypto, GroupID: groupID, DisableSsl: disableSsl, + PrivateKey: privateKey, Host: host, Port: port, TLSCaFile: certPath + "/ca.crt", TLSKeyFile: certPath + "/sdk.key", TLSCertFile: certPath + "/sdk.crt"} + } else { + config = &client.Config{IsSMCrypto: smCrypto, GroupID: groupID, DisableSsl: disableSsl, + PrivateKey: privateKey, Host: host, Port: port, TLSCaFile: certPath + "/sm_ca.crt", TLSKeyFile: certPath + "/sm_sdk.key", TLSCertFile: certPath + "/sm_sdk.crt", TLSSmEnKeyFile: certPath + "/sm_ensdk.key", TLSSmEnCertFile: certPath + "/sm_ensdk.crt"} + } RPC = getClient(config) } diff --git a/v3/cmd/commandline/system_config.go b/v3/cmd/commandline/system_config.go index 9c8d58b8..28dbc96b 100644 --- a/v3/cmd/commandline/system_config.go +++ b/v3/cmd/commandline/system_config.go @@ -22,7 +22,7 @@ For example: setSystemConfigByKey tx_count_limit 10000`, configSet.String()), Args: cobra.ExactArgs(2), Run: func(cmd *cobra.Command, args []string) { - if configSet.Contains(args[0]) { + if !configSet.Contains(args[0]) { fmt.Printf("The key not found: %s, currently only support %v", args[0], configSet) return } diff --git a/v3/examples/parallelok/main.go b/v3/examples/parallelok/main.go index 4f37640e..ac491111 100644 --- a/v3/examples/parallelok/main.go +++ b/v3/examples/parallelok/main.go @@ -39,7 +39,7 @@ func main() { fmt.Println("start perf groupID:", groupID, "userCount:", userCount, "total:", total, "qps:", qps) privateKey, _ := hex.DecodeString("145e247e170ba3afd6ae97e88f00dbc976c2345d511b0f6713355d19d8b80b58") - config := &client.Config{IsSMCrypto: false, GroupID: groupID, DisableSsl: true, + config := &client.Config{IsSMCrypto: false, GroupID: groupID, DisableSsl: false, PrivateKey: privateKey, Host: "127.0.0.1", Port: 20200, TLSCaFile: "./conf/ca.crt", TLSKeyFile: "./conf/sdk.key", TLSCertFile: "./conf/sdk.crt"} client, err := client.DialContext(context.Background(), config) // client, err := client.Dial("./config.ini", groupID, privateKey) @@ -185,7 +185,7 @@ func main() { wg2.Wait() // check balance - fmt.Println("check balance") + fmt.Println("check balance...") var wg3 sync.WaitGroup for i := 0; i < userCount; i++ { wg3.Add(1) diff --git a/v3/precompiled/consensus/consensus_service.go b/v3/precompiled/consensus/consensus_service.go index 2091caad..2c7c0864 100644 --- a/v3/precompiled/consensus/consensus_service.go +++ b/v3/precompiled/consensus/consensus_service.go @@ -19,11 +19,6 @@ const ( invalidNodeID int64 = -51100 ) -type nodeIdList struct { - nodeID string `json:"nodeID"` - weight uint `json:"weight"` -} - // getErrorMessage returns the message of error code func getErrorMessage(errorCode int64) string { var message string @@ -82,17 +77,11 @@ func (service *Service) AddObserver(nodeID string) (int64, error) { // return precompiled.DefaultErrorCode, fmt.Errorf("the node is not reachable") //} - observerRaw, err := service.client.GetObserverList(context.Background()) + nodeIDs, err := service.client.GetObserverList(context.Background()) if err != nil { return precompiled.DefaultErrorCode, fmt.Errorf("get the observer list failed: %v", err) } - var nodeIDs []string - err = json.Unmarshal(observerRaw, &nodeIDs) - if err != nil { - return precompiled.DefaultErrorCode, fmt.Errorf("unmarshal the observer list failed: %v", err) - } - for _, nID := range nodeIDs { if nID == nodeID { return precompiled.DefaultErrorCode, fmt.Errorf("the node is already in the observer list") @@ -115,19 +104,13 @@ func (service *Service) AddSealer(nodeID string, weight int64) (int64, error) { // return precompiled.DefaultErrorCode, fmt.Errorf("the node is not reachable") //} - sealerRaw, err := service.client.GetSealerList(context.Background()) + nodes, err := service.client.GetSealerList(context.Background()) if err != nil { return precompiled.DefaultErrorCode, fmt.Errorf("get the sealer list failed: %v", err) } - var nodeIDs []nodeIdList - err = json.Unmarshal(sealerRaw, &nodeIDs) - if err != nil { - return precompiled.DefaultErrorCode, fmt.Errorf("unmarshal the sealer list failed: %v", err) - } - - for _, nID := range nodeIDs { - if nID.nodeID == nodeID { + for _, node := range nodes { + if node.ID == nodeID { return precompiled.DefaultErrorCode, fmt.Errorf("the node is already in the sealer list") } } @@ -194,17 +177,21 @@ func (service *Service) isValidNodeID(nodeID string) (bool, error) { } func (service *Service) SetWeight(nodeID string, weight int64) (int64, error) { - sealerRaw, err := service.client.GetSealerList(context.Background()) + nodes, err := service.client.GetSealerList(context.Background()) if err != nil { return precompiled.DefaultErrorCode, fmt.Errorf("get the sealer list failed: %v", err) } - var nodeIDs []nodeIdList - err = json.Unmarshal(sealerRaw, &nodeIDs) - if err != nil { - return precompiled.DefaultErrorCode, fmt.Errorf("unmarshal the sealer list failed: %v", err) + find := false + for _, node := range nodes { + if node.ID == nodeID { + find = true + break + } + } + if !find { + return precompiled.DefaultErrorCode, fmt.Errorf("the node is not in the sealer list") } - _, _, receipt, err := service.consensus.SetWeight(service.consensusAuth, nodeID, big.NewInt(weight)) if err != nil { return precompiled.DefaultErrorCode, fmt.Errorf("ConsensusService setWeight failed: %+v", err) diff --git a/v3/types/receipt.go b/v3/types/receipt.go index a1c37c39..284d637a 100644 --- a/v3/types/receipt.go +++ b/v3/types/receipt.go @@ -197,7 +197,7 @@ func (r *Receipt) GetErrorMessage() string { return fmt.Sprintf("receipt error code: %v, receipt error message: %v", r.Status, errorMessage) } -// String returns the string representation of Receipt sturct. +// String returns the string representation of Receipt struct. func (r *Receipt) String() string { out, err := json.MarshalIndent(r, "", "\t") if err != nil {