diff --git a/README.md b/README.md index e981dbdb13..dd764d63e5 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@

Ontology

-

Version 1.6.0

+

Version 1.8.0

[![GoDoc](https://godoc.org/github.com/ontio/ontology?status.svg)](https://godoc.org/github.com/ontio/ontology) [![Go Report Card](https://goreportcard.com/badge/github.com/ontio/ontology)](https://goreportcard.com/report/github.com/ontio/ontology) @@ -49,7 +49,7 @@ New features are still being rapidly developed, therefore the master branch may ## Build Development Environment The requirements to build Ontology are: -- [Golang](https://golang.org/doc/install) version 1.9 or later +- [Golang](https://golang.org/doc/install) version 1.11 or later - [Glide](https://glide.sh) (a third party package management tool for Golang) ## Download Ontology diff --git a/README_CN.md b/README_CN.md index 82a2683c9c..34ba977aa2 100644 --- a/README_CN.md +++ b/README_CN.md @@ -1,6 +1,6 @@

Ontology

-

Version 1.0

+

Version 1.8.0

[![GoDoc](https://godoc.org/github.com/ontio/ontology?status.svg)](https://godoc.org/github.com/ontio/ontology) [![Go Report Card](https://goreportcard.com/badge/github.com/ontio/ontology)](https://goreportcard.com/report/github.com/ontio/ontology) @@ -52,7 +52,7 @@ Ontology MainNet 已经在2018年6月30日成功上线。
## 构建开发环境 成功编译ontology需要以下准备: -* Golang版本在1.9及以上 +* Golang版本在1.11及以上 * 安装第三方包管理工具glide * 正确的Go语言开发环境 * Golang所支持的操作系统 diff --git a/cmd/import_cmd.go b/cmd/import_cmd.go index aebea4c100..5cc0180024 100644 --- a/cmd/import_cmd.go +++ b/cmd/import_cmd.go @@ -136,39 +136,56 @@ func importBlocks(ctx *cli.Context) error { PrintInfoMsg("Start import blocks.") - for i := uint32(startBlockHeight); i <= endBlockHeight; i++ { - size, err := serialization.ReadUint32(fReader) - if err != nil { - return fmt.Errorf("read block height:%d error:%s", i, err) - } - compressData := make([]byte, size) - _, err = io.ReadFull(fReader, compressData) - if err != nil { - return fmt.Errorf("read block data height:%d error:%s", i, err) + var readErr error + blocks := make(chan *types.Block, 10) + go func() { + defer close(blocks) + for i := uint32(startBlockHeight); i <= endBlockHeight; i++ { + size, err := serialization.ReadUint32(fReader) + if err != nil { + readErr = fmt.Errorf("read block height:%d error:%s", i, err) + break + } + compressData := make([]byte, size) + _, err = io.ReadFull(fReader, compressData) + if err != nil { + readErr = fmt.Errorf("read block data height:%d error:%s", i, err) + break + } + + if i <= currBlockHeight { + continue + } + + blockData, err := utils.DecompressBlockData(compressData, metadata.CompressType) + if err != nil { + readErr = fmt.Errorf("block height:%d decompress error:%s", i, err) + break + } + block, err := types.BlockFromRawBytes(blockData) + if err != nil { + readErr = fmt.Errorf("block height:%d deserialize error:%s", i, err) + break + } + blocks <- block } + }() - if i <= currBlockHeight { - continue - } - - blockData, err := utils.DecompressBlockData(compressData, metadata.CompressType) - if err != nil { - return fmt.Errorf("block height:%d decompress error:%s", i, err) - } - block, err := types.BlockFromRawBytes(blockData) - if err != nil { - return fmt.Errorf("block height:%d deserialize error:%s", i, err) - } + for block := range blocks { execResult, err := ledger.DefLedger.ExecuteBlock(block) if err != nil { - return fmt.Errorf("block height:%d ExecuteBlock error:%s", i, err) + return fmt.Errorf("block height:%d ExecuteBlock error:%s", block.Header.Height, err) } err = ledger.DefLedger.SubmitBlock(block, execResult) if err != nil { - return fmt.Errorf("SubmitBlock block height:%d error:%s", i, err) + return fmt.Errorf("SubmitBlock block height:%d error:%s", block.Header.Height, err) } bar.Incr() } + if readErr != nil { + return readErr + } + uiprogress.Stop() PrintInfoMsg("Import block completed, current block height:%d.", ledger.DefLedger.GetCurrentBlockHeight()) return nil diff --git a/cmd/sigsvr/handlers/sig_native_invoke_tx.go b/cmd/sigsvr/handlers/sig_native_invoke_tx.go index 83acec8779..bff493e0f5 100644 --- a/cmd/sigsvr/handlers/sig_native_invoke_tx.go +++ b/cmd/sigsvr/handlers/sig_native_invoke_tx.go @@ -19,7 +19,6 @@ package handlers import ( - "bytes" "encoding/hex" "encoding/json" "github.com/ontio/ontology/cmd/abi" @@ -102,14 +101,7 @@ func SigNativeInvokeTx(req *clisvrcom.CliRpcRequest, resp *clisvrcom.CliRpcRespo return } - buf := bytes.NewBuffer(nil) - err = immutable.Serialize(buf) - if err != nil { - log.Infof("Cli Qid:%s SigNativeInvokeTx tx Serialize error:%s", req.Qid, err) - resp.ErrorCode = clisvrcom.CLIERR_INTERNAL_ERR - return - } resp.Result = &SigNativeInvokeTxRsp{ - SignedTx: hex.EncodeToString(buf.Bytes()), + SignedTx: hex.EncodeToString(common.SerializeToBytes(immutable)), } } diff --git a/cmd/sigsvr/handlers/sig_neovm_invoke_tx_abi.go b/cmd/sigsvr/handlers/sig_neovm_invoke_tx_abi.go index 7042b9c12f..834d6ea04b 100644 --- a/cmd/sigsvr/handlers/sig_neovm_invoke_tx_abi.go +++ b/cmd/sigsvr/handlers/sig_neovm_invoke_tx_abi.go @@ -18,7 +18,6 @@ package handlers import ( - "bytes" "encoding/hex" "encoding/json" clisvrcom "github.com/ontio/ontology/cmd/sigsvr/common" @@ -107,16 +106,7 @@ func SigNeoVMInvokeAbiTx(req *clisvrcom.CliRpcRequest, resp *clisvrcom.CliRpcRes resp.ErrorCode = clisvrcom.CLIERR_INTERNAL_ERR return } - sink := common.ZeroCopySink{} - tx.Serialization(&sink) - buf := bytes.NewBuffer(nil) - err = tx.Serialize(buf) - if err != nil { - log.Infof("Cli Qid:%s SigNeoVMInvokeAbiTx tx Serialize error:%s", req.Qid, err) - resp.ErrorCode = clisvrcom.CLIERR_INTERNAL_ERR - return - } resp.Result = &SigNeoVMInvokeTxAbiRsp{ - SignedTx: hex.EncodeToString(buf.Bytes()), + SignedTx: hex.EncodeToString(common.SerializeToBytes(tx)), } } diff --git a/cmd/utils/ont.go b/cmd/utils/ont.go index fc8b2ff6dc..e3c7933f52 100644 --- a/cmd/utils/ont.go +++ b/cmd/utils/ont.go @@ -38,6 +38,7 @@ import ( "github.com/ontio/ontology/smartcontract/service/native/ont" "github.com/ontio/ontology/smartcontract/service/native/utils" cstates "github.com/ontio/ontology/smartcontract/states" + "io" "math/rand" "sort" "strconv" @@ -436,12 +437,7 @@ func Sign(data []byte, signer *account.Account) ([]byte, error) { //SendRawTransaction send a transaction to ontology network, and return hash of the transaction func SendRawTransaction(tx *types.Transaction) (string, error) { - var buffer bytes.Buffer - err := tx.Serialize(&buffer) - if err != nil { - return "", fmt.Errorf("serialize error:%s", err) - } - txData := hex.EncodeToString(buffer.Bytes()) + txData := hex.EncodeToString(common.SerializeToBytes(tx)) return SendRawTransactionData(txData) } @@ -641,13 +637,11 @@ func PrepareDeployContract( if err != nil { return nil, fmt.Errorf("NewDeployCodeTransaction error:%s", err) } - tx, _ := mutable.IntoImmutable() - var buffer bytes.Buffer - err = tx.Serialize(&buffer) + tx, err := mutable.IntoImmutable() if err != nil { - return nil, fmt.Errorf("tx serialize error:%s", err) + return nil, err } - txData := hex.EncodeToString(buffer.Bytes()) + txData := hex.EncodeToString(common.SerializeToBytes(tx)) return PrepareSendRawTransaction(txData) } @@ -710,12 +704,7 @@ func PrepareInvokeNeoVMContract( return nil, err } - var buffer bytes.Buffer - err = tx.Serialize(&buffer) - if err != nil { - return nil, fmt.Errorf("tx serialize error:%s", err) - } - txData := hex.EncodeToString(buffer.Bytes()) + txData := hex.EncodeToString(common.SerializeToBytes(tx)) return PrepareSendRawTransaction(txData) } @@ -728,12 +717,7 @@ func PrepareInvokeCodeNeoVMContract(code []byte) (*cstates.PreExecResult, error) if err != nil { return nil, err } - var buffer bytes.Buffer - err = tx.Serialize(&buffer) - if err != nil { - return nil, fmt.Errorf("tx serialize error:%s", err) - } - txData := hex.EncodeToString(buffer.Bytes()) + txData := hex.EncodeToString(common.SerializeToBytes(tx)) return PrepareSendRawTransaction(txData) } @@ -749,12 +733,7 @@ func PrepareInvokeWasmVMContract(contractAddress common.Address, params []interf return nil, err } - var buffer bytes.Buffer - err = tx.Serialize(&buffer) - if err != nil { - return nil, fmt.Errorf("tx serialize error:%s", err) - } - txData := hex.EncodeToString(buffer.Bytes()) + txData := hex.EncodeToString(common.SerializeToBytes(tx)) return PrepareSendRawTransaction(txData) } @@ -771,12 +750,7 @@ func PrepareInvokeNativeContract( if err != nil { return nil, err } - var buffer bytes.Buffer - err = tx.Serialize(&buffer) - if err != nil { - return nil, fmt.Errorf("tx serialize error:%s", err) - } - txData := hex.EncodeToString(buffer.Bytes()) + txData := hex.EncodeToString(common.SerializeToBytes(tx)) return PrepareSendRawTransaction(txData) } @@ -833,10 +807,13 @@ func ParseWasmVMContractReturnTypeByteArray(hexStr string) (string, error) { if err != nil { return "", fmt.Errorf("common.HexToBytes:%s error:%s", hexStr, err) } - bf := bytes.NewBuffer(hexbs) - bs, err := serialization.ReadVarBytes(bf) - if err != nil { - return "", fmt.Errorf("ParseWasmVMContractReturnTypeByteArray:%s error:%s", hexStr, err) + source := common.NewZeroCopySource(hexbs) + bs, _, irregular, eof := source.NextVarBytes() + if irregular { + return "", fmt.Errorf("ParseWasmVMContractReturnTypeByteArray:%s error:%s", hexStr, common.ErrIrregularData) + } + if eof { + return "", fmt.Errorf("ParseWasmVMContractReturnTypeByteArray:%s error:%s", hexStr, io.ErrUnexpectedEOF) } return common.ToHexString(bs), nil } @@ -847,8 +824,15 @@ func ParseWasmVMContractReturnTypeString(hexStr string) (string, error) { if err != nil { return "", fmt.Errorf("common.HexToBytes:%s error:%s", hexStr, err) } - bf := bytes.NewBuffer(hexbs) - return serialization.ReadString(bf) + source := common.NewZeroCopySource(hexbs) + data, _, irregular, eof := source.NextString() + if irregular { + return "", common.ErrIrregularData + } + if eof { + return "", io.ErrUnexpectedEOF + } + return data, nil } //ParseWasmVMContractReturnTypeInteger return integer value of smart contract execute code. diff --git a/common/address.go b/common/address.go index ef77529b56..e85d11144a 100644 --- a/common/address.go +++ b/common/address.go @@ -22,11 +22,11 @@ import ( "crypto/sha256" "errors" "fmt" - "io" "math/big" "github.com/itchyny/base58-go" "golang.org/x/crypto/ripemd160" + "io" ) const ADDR_LEN = 20 @@ -41,16 +41,16 @@ func (self *Address) ToHexString() string { } // Serialize serialize Address into io.Writer -func (self *Address) Serialize(w io.Writer) error { - _, err := w.Write(self[:]) - return err +func (self *Address) Serialization(sink *ZeroCopySink) { + sink.WriteAddress(*self) } // Deserialize deserialize Address from io.Reader -func (self *Address) Deserialize(r io.Reader) error { - _, err := io.ReadFull(r, self[:]) - if err != nil { - return errors.New("deserialize Address error") +func (self *Address) Deserialization(source *ZeroCopySource) error { + var eof bool + *self, eof = source.NextAddress() + if eof { + return io.ErrUnexpectedEOF } return nil } diff --git a/common/address_test.go b/common/address_test.go index 5e2db9666f..7324086001 100644 --- a/common/address_test.go +++ b/common/address_test.go @@ -19,7 +19,6 @@ package common import ( - "bytes" "crypto/rand" "testing" @@ -55,10 +54,11 @@ func TestAddress_Serialize(t *testing.T) { var addr Address rand.Read(addr[:]) - buf := bytes.NewBuffer(nil) - addr.Serialize(buf) + sink := NewZeroCopySink(nil) + addr.Serialization(sink) var addr2 Address - addr2.Deserialize(buf) + source := NewZeroCopySource(sink.Bytes()) + addr2.Deserialization(source) assert.Equal(t, addr, addr2) } diff --git a/common/codec.go b/common/codec.go new file mode 100644 index 0000000000..4cb9923f59 --- /dev/null +++ b/common/codec.go @@ -0,0 +1,31 @@ +/* + * Copyright (C) 2018 The ontology Authors + * This file is part of The ontology library. + * + * The ontology is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * The ontology is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with The ontology. If not, see . + */ +package common + +type Serializable interface { + Serialization(sink *ZeroCopySink) +} + +func SerializeToBytes(values ...Serializable) []byte { + sink := NewZeroCopySink(nil) + for _, val := range values { + val.Serialization(sink) + } + + return sink.Bytes() +} diff --git a/common/config/config.go b/common/config/config.go index e0c9a1c2c2..47a9b81deb 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -30,7 +30,6 @@ import ( "github.com/ontio/ontology/common" "github.com/ontio/ontology/common/constants" "github.com/ontio/ontology/common/log" - "github.com/ontio/ontology/common/serialization" "github.com/ontio/ontology/errors" ) @@ -309,111 +308,98 @@ type VBFTConfig struct { Peers []*VBFTPeerStakeInfo `json:"peers"` } -func (this *VBFTConfig) Serialize(w io.Writer) error { - if err := serialization.WriteUint32(w, this.N); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.WriteUint32, serialize n error!") - } - if err := serialization.WriteUint32(w, this.C); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.WriteUint32, serialize c error!") - } - if err := serialization.WriteUint32(w, this.K); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.WriteUint32, serialize k error!") - } - if err := serialization.WriteUint32(w, this.L); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.WriteUint32, serialize l error!") - } - if err := serialization.WriteUint32(w, this.BlockMsgDelay); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.WriteUint32, serialize block_msg_delay error!") - } - if err := serialization.WriteUint32(w, this.HashMsgDelay); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.WriteUint32, serialize hash_msg_delay error!") - } - if err := serialization.WriteUint32(w, this.PeerHandshakeTimeout); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.WriteUint32, serialize peer_handshake_timeout error!") - } - if err := serialization.WriteUint32(w, this.MaxBlockChangeView); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.WriteUint32, serialize max_block_change_view error!") - } - if err := serialization.WriteUint32(w, this.MinInitStake); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.WriteUint32, serialize min_init_stake error!") - } - if err := serialization.WriteString(w, this.AdminOntID); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.WriteString, serialize admin_ont_id error!") - } - if err := serialization.WriteString(w, this.VrfValue); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.WriteString, serialize vrf_value error!") - } - if err := serialization.WriteString(w, this.VrfProof); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.WriteString, serialize vrf_proof error!") - } - if err := serialization.WriteVarUint(w, uint64(len(this.Peers))); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.WriteVarUint, serialize peer length error!") - } - for _, peer := range this.Peers { - if err := peer.Serialize(w); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialize peer error!") +func (self *VBFTConfig) Serialization(sink *common.ZeroCopySink) error { + sink.WriteUint32(self.N) + sink.WriteUint32(self.C) + sink.WriteUint32(self.K) + sink.WriteUint32(self.L) + sink.WriteUint32(self.BlockMsgDelay) + sink.WriteUint32(self.HashMsgDelay) + sink.WriteUint32(self.PeerHandshakeTimeout) + sink.WriteUint32(self.MaxBlockChangeView) + sink.WriteUint32(self.MinInitStake) + sink.WriteString(self.AdminOntID) + sink.WriteString(self.VrfValue) + sink.WriteString(self.VrfProof) + sink.WriteVarUint(uint64(len(self.Peers))) + for _, peer := range self.Peers { + if err := peer.Serialization(sink); err != nil { + return err } } + return nil } -func (this *VBFTConfig) Deserialize(r io.Reader) error { - n, err := serialization.ReadUint32(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.ReadUint32, deserialize n error!") +func (this *VBFTConfig) Deserialization(source *common.ZeroCopySource) error { + n, eof := source.NextUint32() + if eof { + return errors.NewDetailErr(io.ErrUnexpectedEOF, errors.ErrNoCode, "serialization.ReadUint32, deserialize n error!") } - c, err := serialization.ReadUint32(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.ReadUint32, deserialize c error!") + c, eof := source.NextUint32() + if eof { + return errors.NewDetailErr(io.ErrUnexpectedEOF, errors.ErrNoCode, "serialization.ReadUint32, deserialize c error!") } - k, err := serialization.ReadUint32(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.ReadUint32, deserialize k error!") + k, eof := source.NextUint32() + if eof { + return errors.NewDetailErr(io.ErrUnexpectedEOF, errors.ErrNoCode, "serialization.ReadUint32, deserialize k error!") } - l, err := serialization.ReadUint32(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.ReadUint32, deserialize l error!") + l, eof := source.NextUint32() + if eof { + return errors.NewDetailErr(io.ErrUnexpectedEOF, errors.ErrNoCode, "serialization.ReadUint32, deserialize l error!") } - blockMsgDelay, err := serialization.ReadUint32(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.ReadUint32, deserialize blockMsgDelay error!") + blockMsgDelay, eof := source.NextUint32() + if eof { + return errors.NewDetailErr(io.ErrUnexpectedEOF, errors.ErrNoCode, "serialization.ReadUint32, deserialize blockMsgDelay error!") } - hashMsgDelay, err := serialization.ReadUint32(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.ReadUint32, deserialize hashMsgDelay error!") + hashMsgDelay, eof := source.NextUint32() + if eof { + return errors.NewDetailErr(io.ErrUnexpectedEOF, errors.ErrNoCode, "serialization.ReadUint32, deserialize hashMsgDelay error!") } - peerHandshakeTimeout, err := serialization.ReadUint32(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.ReadUint32, deserialize peerHandshakeTimeout error!") + peerHandshakeTimeout, eof := source.NextUint32() + if eof { + return errors.NewDetailErr(io.ErrUnexpectedEOF, errors.ErrNoCode, "serialization.ReadUint32, deserialize peerHandshakeTimeout error!") } - maxBlockChangeView, err := serialization.ReadUint32(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.ReadUint32, deserialize maxBlockChangeView error!") + maxBlockChangeView, eof := source.NextUint32() + if eof { + return errors.NewDetailErr(io.ErrUnexpectedEOF, errors.ErrNoCode, "serialization.ReadUint32, deserialize maxBlockChangeView error!") } - minInitStake, err := serialization.ReadUint32(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.ReadUint32, deserialize minInitStake error!") + minInitStake, eof := source.NextUint32() + if eof { + return errors.NewDetailErr(io.ErrUnexpectedEOF, errors.ErrNoCode, "serialization.ReadUint32, deserialize minInitStake error!") } - adminOntID, err := serialization.ReadString(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.ReadString, deserialize adminOntID error!") + adminOntID, _, irregular, eof := source.NextString() + if irregular { + return errors.NewDetailErr(common.ErrIrregularData, errors.ErrNoCode, "serialization.ReadString, deserialize adminOntID error!") } - vrfValue, err := serialization.ReadString(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.ReadString, deserialize vrfValue error!") + if eof { + return errors.NewDetailErr(io.ErrUnexpectedEOF, errors.ErrNoCode, "serialization.ReadString, deserialize adminOntID error!") } - vrfProof, err := serialization.ReadString(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.ReadString, deserialize vrfProof error!") + vrfValue, _, irregular, eof := source.NextString() + if irregular { + return errors.NewDetailErr(common.ErrIrregularData, errors.ErrNoCode, "serialization.ReadString, deserialize vrfValue error!") } - length, err := serialization.ReadVarUint(r, 0) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.ReadVarUint, deserialize peer length error!") + if eof { + return errors.NewDetailErr(io.ErrUnexpectedEOF, errors.ErrNoCode, "serialization.ReadString, deserialize vrfValue error!") + } + vrfProof, _, irregular, eof := source.NextString() + if irregular { + return errors.NewDetailErr(common.ErrIrregularData, errors.ErrNoCode, "serialization.ReadString, deserialize vrfProof error!") + } + if eof { + return errors.NewDetailErr(io.ErrUnexpectedEOF, errors.ErrNoCode, "serialization.ReadString, deserialize vrfProof error!") + } + length, _, irregular, eof := source.NextVarUint() + if irregular { + return errors.NewDetailErr(common.ErrIrregularData, errors.ErrNoCode, "serialization.ReadVarUint, deserialize peer length error!") + } + if eof { + return errors.NewDetailErr(io.ErrUnexpectedEOF, errors.ErrNoCode, "serialization.ReadVarUint, deserialize peer length error!") } peers := make([]*VBFTPeerStakeInfo, 0) for i := 0; uint64(i) < length; i++ { peer := new(VBFTPeerStakeInfo) - err = peer.Deserialize(r) + err := peer.Deserialization(source) if err != nil { return errors.NewDetailErr(err, errors.ErrNoCode, "deserialize peer error!") } @@ -442,43 +428,39 @@ type VBFTPeerStakeInfo struct { InitPos uint64 `json:"initPos"` } -func (this *VBFTPeerStakeInfo) Serialize(w io.Writer) error { - if err := serialization.WriteUint32(w, this.Index); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.WriteUint32, serialize index error!") - } - if err := serialization.WriteString(w, this.PeerPubkey); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.WriteUint32, serialize peerPubkey error!") - } +func (this *VBFTPeerStakeInfo) Serialization(sink *common.ZeroCopySink) error { + sink.WriteUint32(this.Index) + sink.WriteString(this.PeerPubkey) + address, err := common.AddressFromBase58(this.Address) if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "common.AddressFromBase58, address format error!") - } - if err := address.Serialize(w); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.WriteUint32, serialize address error!") - } - if err := serialization.WriteUint64(w, this.InitPos); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.WriteUint32, serialize initPos error!") + return fmt.Errorf("serialize VBFTPeerStackInfo error: %v", err) } + address.Serialization(sink) + sink.WriteUint64(this.InitPos) return nil } -func (this *VBFTPeerStakeInfo) Deserialize(r io.Reader) error { - index, err := serialization.ReadUint32(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.ReadUint32, deserialize index error!") +func (this *VBFTPeerStakeInfo) Deserialization(source *common.ZeroCopySource) error { + index, eof := source.NextUint32() + if eof { + return errors.NewDetailErr(io.ErrUnexpectedEOF, errors.ErrNoCode, "serialization.ReadUint32, deserialize index error!") } - peerPubkey, err := serialization.ReadString(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.ReadUint32, deserialize peerPubkey error!") + peerPubkey, _, irregular, eof := source.NextString() + if irregular { + return errors.NewDetailErr(common.ErrIrregularData, errors.ErrNoCode, "serialization.ReadUint32, deserialize peerPubkey error!") + } + if eof { + return errors.NewDetailErr(io.ErrUnexpectedEOF, errors.ErrNoCode, "serialization.ReadUint32, deserialize peerPubkey error!") } address := new(common.Address) - err = address.Deserialize(r) + err := address.Deserialization(source) if err != nil { return errors.NewDetailErr(err, errors.ErrNoCode, "address.Deserialize, deserialize address error!") } - initPos, err := serialization.ReadUint64(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "serialization.ReadUint32, deserialize initPos error!") + initPos, eof := source.NextUint64() + if eof { + return errors.NewDetailErr(io.ErrUnexpectedEOF, errors.ErrNoCode, "serialization.ReadUint32, deserialize initPos error!") } this.Index = index this.PeerPubkey = peerPubkey diff --git a/common/serialization/serialize.go b/common/serialization/serialize.go index 54793ceeff..932f211647 100644 --- a/common/serialization/serialize.go +++ b/common/serialization/serialize.go @@ -249,12 +249,6 @@ func WriteUint64(writer io.Writer, val uint64) error { return err } -func ToArray(data SerializableData) []byte { - buf := new(bytes.Buffer) - data.Serialize(buf) - return buf.Bytes() -} - //************************************************************************** //** internal func *** //************************************************************************** diff --git a/consensus/dbft/dbft_service.go b/consensus/dbft/dbft_service.go index cf846b13c1..7ab3a237c1 100644 --- a/consensus/dbft/dbft_service.go +++ b/consensus/dbft/dbft_service.go @@ -19,7 +19,6 @@ package dbft import ( - "bytes" "fmt" "reflect" "time" @@ -609,9 +608,9 @@ func (ds *DbftService) RequestChangeView() { } func (ds *DbftService) SignAndRelay(payload *p2pmsg.ConsensusPayload) { - buf := new(bytes.Buffer) - payload.SerializeUnsigned(buf) - payload.Signature, _ = signature.Sign(ds.Account, buf.Bytes()) + sink := common.NewZeroCopySink(nil) + payload.SerializationUnsigned(sink) + payload.Signature, _ = signature.Sign(ds.Account, sink.Bytes()) ds.p2p.Broadcast(payload) } diff --git a/consensus/vbft/config/config.go b/consensus/vbft/config/config.go index 55d0d08d3e..a6f6d7fa5c 100644 --- a/consensus/vbft/config/config.go +++ b/consensus/vbft/config/config.go @@ -27,7 +27,6 @@ import ( "time" "github.com/ontio/ontology/common" - "github.com/ontio/ontology/common/serialization" ) var ( @@ -101,26 +100,21 @@ func (cc *ChainConfig) Serialize(w io.Writer) error { return nil } -func (pc *PeerConfig) Serialize(w io.Writer) error { - if err := serialization.WriteUint32(w, pc.Index); err != nil { - return fmt.Errorf("ChainConfig peer index length serialization failed %s", err) - } - if err := serialization.WriteString(w, pc.ID); err != nil { - return fmt.Errorf("ChainConfig peer ID length serialization failed %s", err) - } - return nil +func (pc *PeerConfig) Serialization(sink *common.ZeroCopySink) { + sink.WriteUint32(pc.Index) + sink.WriteString(pc.ID) } -func (pc *PeerConfig) Deserialize(r io.Reader) error { - index, err := serialization.ReadUint32(r) - if err != nil { - return fmt.Errorf("serialization PeerConfig index err:%s", err) +func (pc *PeerConfig) Deserialization(source *common.ZeroCopySource) error { + index, eof := source.NextUint32() + if eof { + return fmt.Errorf("Deserialization PeerConfig index err:%s", io.ErrUnexpectedEOF) } pc.Index = index - nodeid, err := serialization.ReadString(r) - if err != nil { - return fmt.Errorf("serialization PeerConfig nodeid err:%s", err) + nodeid, _, irregular, eof := source.NextString() + if irregular || eof { + return fmt.Errorf("serialization PeerConfig nodeid irregular:%v, eof:%v", irregular, eof) } pc.ID = nodeid return nil diff --git a/consensus/vbft/msg_types.go b/consensus/vbft/msg_types.go index 7d3a383d9f..357d678843 100644 --- a/consensus/vbft/msg_types.go +++ b/consensus/vbft/msg_types.go @@ -104,7 +104,7 @@ func (msg *blockProposalMsg) GetBlockNum() uint32 { } func (msg *blockProposalMsg) Serialize() ([]byte, error) { - return msg.Block.Serialize() + return msg.Block.Serialize(), nil } func (msg *blockProposalMsg) UnmarshalJSON(data []byte) error { @@ -118,7 +118,7 @@ func (msg *blockProposalMsg) UnmarshalJSON(data []byte) error { } func (msg *blockProposalMsg) MarshalJSON() ([]byte, error) { - return msg.Block.Serialize() + return msg.Block.Serialize(), nil } type FaultyReport struct { @@ -337,10 +337,7 @@ func (msg *BlockFetchRespMsg) Serialize() ([]byte, error) { buffer := bytes.NewBuffer([]byte{}) serialization.WriteUint32(buffer, msg.BlockNumber) msg.BlockHash.Serialize(buffer) - blockbuff, err := msg.BlockData.Serialize() - if err != nil { - return nil, err - } + blockbuff := msg.BlockData.Serialize() buffer.Write(blockbuff) return buffer.Bytes(), nil } diff --git a/consensus/vbft/msg_types_test.go b/consensus/vbft/msg_types_test.go index ad5f765614..918b6dbf33 100644 --- a/consensus/vbft/msg_types_test.go +++ b/consensus/vbft/msg_types_test.go @@ -442,10 +442,7 @@ func TestBlockSerialization(t *testing.T) { return } - data, err := blk.Serialize() - if err != nil { - t.Fatalf("serialize blk: %s", err) - } + data := blk.Serialize() blk2 := &Block{} if err := blk2.Deserialize(data); err != nil { @@ -453,10 +450,7 @@ func TestBlockSerialization(t *testing.T) { } blk.EmptyBlock = nil - data2, err := blk.Serialize() - if err != nil { - t.Fatalf("serialize blk2: %s", err) - } + data2 := blk.Serialize() blk3 := &Block{} if err := blk3.Deserialize(data2); err != nil { t.Fatalf("deserialize blk2: %s", err) diff --git a/consensus/vbft/node_utils.go b/consensus/vbft/node_utils.go index dce8cdf27f..176fb52f02 100644 --- a/consensus/vbft/node_utils.go +++ b/consensus/vbft/node_utils.go @@ -19,14 +19,14 @@ package vbft import ( - "bytes" "fmt" "math" + "github.com/ontio/ontology/common" "github.com/ontio/ontology/common/log" - vconfig "github.com/ontio/ontology/consensus/vbft/config" + "github.com/ontio/ontology/consensus/vbft/config" "github.com/ontio/ontology/core/signature" - msgpack "github.com/ontio/ontology/p2pserver/message/msg_pack" + "github.com/ontio/ontology/p2pserver/message/msg_pack" p2pmsg "github.com/ontio/ontology/p2pserver/message/types" ) @@ -407,11 +407,9 @@ func (self *Server) sendToPeer(peerIdx uint32, data []byte) error { Owner: self.account.PublicKey, } - buf := new(bytes.Buffer) - if err := msg.SerializeUnsigned(buf); err != nil { - return fmt.Errorf("failed to serialize consensus msg: %s", err) - } - msg.Signature, _ = signature.Sign(self.account, buf.Bytes()) + sink := common.NewZeroCopySink(nil) + msg.SerializationUnsigned(sink) + msg.Signature, _ = signature.Sign(self.account, sink.Bytes()) cons := msgpack.NewConsensus(msg) p2pid, present := self.peerPool.getP2pId(peerIdx) @@ -436,11 +434,9 @@ func (self *Server) broadcastToAll(data []byte) error { Owner: self.account.PublicKey, } - buf := new(bytes.Buffer) - if err := msg.SerializeUnsigned(buf); err != nil { - return fmt.Errorf("failed to serialize consensus msg: %s", err) - } - msg.Signature, _ = signature.Sign(self.account, buf.Bytes()) + sink := common.NewZeroCopySink(nil) + msg.SerializationUnsigned(sink) + msg.Signature, _ = signature.Sign(self.account, sink.Bytes()) self.p2p.Broadcast(msg) return nil diff --git a/consensus/vbft/types.go b/consensus/vbft/types.go index 3e5a54d72d..5fa003b2e2 100644 --- a/consensus/vbft/types.go +++ b/consensus/vbft/types.go @@ -70,21 +70,16 @@ func (blk *Block) getVrfProof() []byte { return blk.Info.VrfProof } -func (blk *Block) Serialize() ([]byte, error) { - sink := common.NewZeroCopySink(nil) - blk.Block.Serialization(sink) - +func (blk *Block) Serialize() []byte { payload := common.NewZeroCopySink(nil) - payload.WriteVarBytes(sink.Bytes()) + payload.WriteVarBytes(common.SerializeToBytes(blk.Block)) payload.WriteBool(blk.EmptyBlock != nil) if blk.EmptyBlock != nil { - sink2 := common.NewZeroCopySink(nil) - blk.EmptyBlock.Serialization(sink2) - payload.WriteVarBytes(sink2.Bytes()) + payload.WriteVarBytes(common.SerializeToBytes(blk.EmptyBlock)) } payload.WriteHash(blk.PrevBlockMerkleRoot) - return payload.Bytes(), nil + return payload.Bytes() } func (blk *Block) Deserialize(data []byte) error { diff --git a/consensus/vbft/types_test.go b/consensus/vbft/types_test.go index 46d6a5e2bb..d1fe599574 100644 --- a/consensus/vbft/types_test.go +++ b/consensus/vbft/types_test.go @@ -198,11 +198,7 @@ func TestSerialize(t *testing.T) { if err != nil { t.Errorf("constructBlock failed: %v", err) } - _, err = blk.Serialize() - if err != nil { - t.Errorf("Block Serialize failed :%v", err) - return - } + blk.Serialize() t.Log("Block Serialize succ") } diff --git a/consensus/vbft/utils.go b/consensus/vbft/utils.go index a4f0475acd..1b600f805a 100644 --- a/consensus/vbft/utils.go +++ b/consensus/vbft/utils.go @@ -140,7 +140,7 @@ func GetVbftConfigInfo(memdb *overlaydb.MemDB) (*config.VBFTConfig, error) { return nil, err } if data != nil { - err = preCfg.Deserialize(bytes.NewBuffer(data)) + err = preCfg.Deserialization(common.NewZeroCopySource(data)) if err != nil { return nil, err } @@ -164,7 +164,7 @@ func GetVbftConfigInfo(memdb *overlaydb.MemDB) (*config.VBFTConfig, error) { return nil, err } cfg := new(gov.Configuration) - err = cfg.Deserialize(bytes.NewBuffer(data)) + err = cfg.Deserialization(common.NewZeroCopySource(data)) if err != nil { return nil, err } @@ -199,7 +199,7 @@ func GetPeersConfig(memdb *overlaydb.MemDB) ([]*config.VBFTPeerStakeInfo, error) peerMap := &gov.PeerPoolMap{ PeerPoolMap: make(map[string]*gov.PeerPoolItem), } - err = peerMap.Deserialize(bytes.NewBuffer(data)) + err = peerMap.Deserialization(common.NewZeroCopySource(data)) if err != nil { return nil, err } diff --git a/core/genesis/genesis.go b/core/genesis/genesis.go index f385ae87e7..75f18bfc2c 100644 --- a/core/genesis/genesis.go +++ b/core/genesis/genesis.go @@ -19,7 +19,6 @@ package genesis import ( - "bytes" "fmt" "sort" "strconv" @@ -68,9 +67,12 @@ func BuildGenesisBlock(defaultBookkeeper []keypair.PublicKey, genesisConfig *con if err != nil { return nil, fmt.Errorf("[Block],BuildGenesisBlock err with GetBookkeeperAddress: %s", err) } - conf := bytes.NewBuffer(nil) + conf := common.NewZeroCopySink(nil) if genesisConfig.VBFT != nil { - genesisConfig.VBFT.Serialize(conf) + err := genesisConfig.VBFT.Serialization(conf) + if err != nil { + return nil, err + } } govConfig := newGoverConfigInit(conf.Bytes()) consensusPayload, err := vconfig.GenesisConsensusPayload(govConfig.Hash(), 0) @@ -218,11 +220,11 @@ func newGoverningInit() *types.Transaction { value uint64 }{{addr, constants.ONT_TOTAL_SUPPLY}} - args := bytes.NewBuffer(nil) - nutils.WriteVarUint(args, uint64(len(distribute))) + args := common.NewZeroCopySink(nil) + nutils.EncodeVarUint(args, uint64(len(distribute))) for _, part := range distribute { - nutils.WriteAddress(args, part.addr) - nutils.WriteVarUint(args, part.value) + nutils.EncodeAddress(args, part.addr) + nutils.EncodeVarUint(args, part.value) } mutable := utils.BuildNativeTransaction(nutils.OntContractAddress, ont.INIT_NAME, args.Bytes()) @@ -259,8 +261,8 @@ func newParamInit() *types.Transaction { for _, v := range s { params.SetParam(global_params.Param{Key: v, Value: INIT_PARAM[v]}) } - bf := new(bytes.Buffer) - params.Serialize(bf) + sink := common.NewZeroCopySink(nil) + params.Serialization(sink) bookkeepers, _ := config.DefConfig.GetBookkeepers() var addr common.Address @@ -274,9 +276,9 @@ func newParamInit() *types.Transaction { } addr = temp } - nutils.WriteAddress(bf, addr) + nutils.EncodeAddress(sink, addr) - mutable := utils.BuildNativeTransaction(nutils.ParamContractAddress, global_params.INIT_NAME, bf.Bytes()) + mutable := utils.BuildNativeTransaction(nutils.ParamContractAddress, global_params.INIT_NAME, sink.Bytes()) tx, err := mutable.IntoImmutable() if err != nil { panic("construct genesis governing token transaction error ") diff --git a/core/payload/invoke_code.go b/core/payload/invoke_code.go index 1c663684d9..57bc9888e1 100644 --- a/core/payload/invoke_code.go +++ b/core/payload/invoke_code.go @@ -19,11 +19,9 @@ package payload import ( - "fmt" "io" "github.com/ontio/ontology/common" - "github.com/ontio/ontology/common/serialization" ) // InvokeCode is an implementation of transaction payload for invoke smartcontract @@ -31,22 +29,6 @@ type InvokeCode struct { Code []byte } -func (self *InvokeCode) Serialize(w io.Writer) error { - if err := serialization.WriteVarBytes(w, self.Code); err != nil { - return fmt.Errorf("InvokeCode Code Serialize failed: %s", err) - } - return nil -} - -func (self *InvokeCode) Deserialize(r io.Reader) error { - code, err := serialization.ReadVarBytes(r) - if err != nil { - return fmt.Errorf("InvokeCode Code Deserialize failed: %s", err) - } - self.Code = code - return nil -} - //note: InvokeCode.Code has data reference of param source func (self *InvokeCode) Deserialization(source *common.ZeroCopySource) error { code, _, irregular, eof := source.NextVarBytes() diff --git a/core/payload/invoke_code_test.go b/core/payload/invoke_code_test.go index 6cd6aa26c9..3a57dc1075 100644 --- a/core/payload/invoke_code_test.go +++ b/core/payload/invoke_code_test.go @@ -18,9 +18,9 @@ package payload import ( - "bytes" "testing" + "github.com/ontio/ontology/common" "github.com/stretchr/testify/assert" ) @@ -29,15 +29,16 @@ func TestInvokeCode_Serialize(t *testing.T) { Code: []byte{1, 2, 3}, } - buf := bytes.NewBuffer(nil) - code.Serialize(buf) - bs := buf.Bytes() + sink := common.NewZeroCopySink(nil) + code.Serialization(sink) + bs := sink.Bytes() var code2 InvokeCode - code2.Deserialize(buf) + source := common.NewZeroCopySource(bs) + code2.Deserialization(source) assert.Equal(t, code, code2) - buf = bytes.NewBuffer(bs[:len(bs)-2]) - err := code.Deserialize(buf) + source = common.NewZeroCopySource(bs[:len(bs)-2]) + err := code.Deserialization(source) assert.NotNil(t, err) } diff --git a/core/states/bookkeeper.go b/core/states/bookkeeper.go index a27f1d7f22..0103ad5ed6 100644 --- a/core/states/bookkeeper.go +++ b/core/states/bookkeeper.go @@ -19,11 +19,10 @@ package states import ( - "bytes" "io" "github.com/ontio/ontology-crypto/keypair" - "github.com/ontio/ontology/common/serialization" + "github.com/ontio/ontology/common" ) type BookkeeperState struct { @@ -32,62 +31,64 @@ type BookkeeperState struct { NextBookkeeper []keypair.PublicKey } -func (this *BookkeeperState) Serialize(w io.Writer) error { - this.StateBase.Serialize(w) - serialization.WriteUint32(w, uint32(len(this.CurrBookkeeper))) +func (this *BookkeeperState) Serialization(sink *common.ZeroCopySink) { + this.StateBase.Serialization(sink) + sink.WriteUint32(uint32(len(this.CurrBookkeeper))) for _, v := range this.CurrBookkeeper { buf := keypair.SerializePublicKey(v) - err := serialization.WriteVarBytes(w, buf) - if err != nil { - return err - } + sink.WriteVarBytes(buf) } - serialization.WriteUint32(w, uint32(len(this.NextBookkeeper))) + sink.WriteUint32(uint32(len(this.NextBookkeeper))) for _, v := range this.NextBookkeeper { buf := keypair.SerializePublicKey(v) - err := serialization.WriteVarBytes(w, buf) - if err != nil { - return err - } + sink.WriteVarBytes(buf) } - return nil } -func (this *BookkeeperState) Deserialize(r io.Reader) error { - err := this.StateBase.Deserialize(r) +func (this *BookkeeperState) Deserialization(source *common.ZeroCopySource) error { + err := this.StateBase.Deserialization(source) if err != nil { return err } - n, err := serialization.ReadUint32(r) - if err != nil { - return err + n, eof := source.NextUint32() + if eof { + return io.ErrUnexpectedEOF } for i := 0; i < int(n); i++ { - buf, err := serialization.ReadVarBytes(r) + buf, _, irregular, eof := source.NextVarBytes() + if irregular { + return common.ErrIrregularData + } + if eof { + return io.ErrUnexpectedEOF + } + key, err := keypair.DeserializePublicKey(buf) if err != nil { return err } - key, err := keypair.DeserializePublicKey(buf) this.CurrBookkeeper = append(this.CurrBookkeeper, key) } - - n, err = serialization.ReadUint32(r) - if err != nil { - return err + n, eof = source.NextUint32() + if eof { + return io.ErrUnexpectedEOF } for i := 0; i < int(n); i++ { - buf, err := serialization.ReadVarBytes(r) + buf, _, irregular, eof := source.NextVarBytes() + if irregular { + return common.ErrIrregularData + } + if eof { + return io.ErrUnexpectedEOF + } + key, err := keypair.DeserializePublicKey(buf) if err != nil { return err } - key, err := keypair.DeserializePublicKey(buf) this.NextBookkeeper = append(this.NextBookkeeper, key) } return nil } func (v *BookkeeperState) ToArray() []byte { - b := new(bytes.Buffer) - v.Serialize(b) - return b.Bytes() + return common.SerializeToBytes(v) } diff --git a/core/states/bookkeeper_test.go b/core/states/bookkeeper_test.go index fe969c6875..db02f3a6e9 100644 --- a/core/states/bookkeeper_test.go +++ b/core/states/bookkeeper_test.go @@ -20,9 +20,8 @@ package states import ( "testing" - "bytes" - "github.com/ontio/ontology-crypto/keypair" + "github.com/ontio/ontology/common" "github.com/stretchr/testify/assert" ) @@ -38,15 +37,16 @@ func TestBookkeeper_Deserialize_Serialize(t *testing.T) { NextBookkeeper: []keypair.PublicKey{pubKey3, pubKey4}, } - buf := bytes.NewBuffer(nil) - bk.Serialize(buf) - bs := buf.Bytes() + sink := common.NewZeroCopySink(nil) + bk.Serialization(sink) + bs := sink.Bytes() var bk2 BookkeeperState - bk2.Deserialize(buf) + source := common.NewZeroCopySource(bs) + bk2.Deserialization(source) assert.Equal(t, bk, bk2) - buf = bytes.NewBuffer(bs[:len(bs)-1]) - err := bk2.Deserialize(buf) + source = common.NewZeroCopySource(bs[:len(bs)-1]) + err := bk2.Deserialization(source) assert.NotNil(t, err) } diff --git a/core/states/state_base.go b/core/states/state_base.go index 94213f96b6..e5e40d23a6 100644 --- a/core/states/state_base.go +++ b/core/states/state_base.go @@ -21,23 +21,21 @@ package states import ( "io" - "github.com/ontio/ontology/common/serialization" - "github.com/ontio/ontology/errors" + "github.com/ontio/ontology/common" ) type StateBase struct { StateVersion byte } -func (this *StateBase) Serialize(w io.Writer) error { - serialization.WriteByte(w, this.StateVersion) - return nil +func (this *StateBase) Serialization(sink *common.ZeroCopySink) { + sink.WriteByte(this.StateVersion) } -func (this *StateBase) Deserialize(r io.Reader) error { - stateVersion, err := serialization.ReadByte(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "[StateBase], StateBase Deserialize failed.") +func (this *StateBase) Deserialization(source *common.ZeroCopySource) error { + stateVersion, eof := source.NextByte() + if eof { + return io.ErrUnexpectedEOF } this.StateVersion = stateVersion return nil diff --git a/core/states/state_base_test.go b/core/states/state_base_test.go index 821b2ea2fc..fbab49e818 100644 --- a/core/states/state_base_test.go +++ b/core/states/state_base_test.go @@ -18,7 +18,7 @@ package states import ( - "bytes" + "github.com/ontio/ontology/common" "testing" ) @@ -26,13 +26,12 @@ func TestStateBase_Serialize_Deserialize(t *testing.T) { st := &StateBase{byte(1)} - bf := new(bytes.Buffer) - if err := st.Serialize(bf); err != nil { - t.Fatalf("StateBase serialize error: %v", err) - } + bf := common.NewZeroCopySink(nil) + st.Serialization(bf) var st2 = new(StateBase) - if err := st2.Deserialize(bf); err != nil { + source := common.NewZeroCopySource(bf.Bytes()) + if err := st2.Deserialization(source); err != nil { t.Fatalf("StateBase deserialize error: %v", err) } } diff --git a/core/states/storage_item.go b/core/states/storage_item.go index 96f55aff8f..b808fb4529 100644 --- a/core/states/storage_item.go +++ b/core/states/storage_item.go @@ -19,10 +19,9 @@ package states import ( - "bytes" "io" - "github.com/ontio/ontology/common/serialization" + "github.com/ontio/ontology/common" "github.com/ontio/ontology/errors" ) @@ -31,34 +30,34 @@ type StorageItem struct { Value []byte } -func (this *StorageItem) Serialize(w io.Writer) error { - this.StateBase.Serialize(w) - serialization.WriteVarBytes(w, this.Value) - return nil +func (this *StorageItem) Serialization(sink *common.ZeroCopySink) { + this.StateBase.Serialization(sink) + sink.WriteVarBytes(this.Value) } -func (this *StorageItem) Deserialize(r io.Reader) error { - err := this.StateBase.Deserialize(r) +func (this *StorageItem) Deserialization(source *common.ZeroCopySource) error { + err := this.StateBase.Deserialization(source) if err != nil { return errors.NewDetailErr(err, errors.ErrNoCode, "[StorageItem], StateBase Deserialize failed.") } - value, err := serialization.ReadVarBytes(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "[StorageItem], Value Deserialize failed.") + value, _, irregular, eof := source.NextVarBytes() + if irregular { + return errors.NewDetailErr(common.ErrIrregularData, errors.ErrNoCode, "[StorageItem], Value Deserialize failed.") + } + if eof { + return errors.NewDetailErr(io.ErrUnexpectedEOF, errors.ErrNoCode, "[StorageItem], Value Deserialize failed.") } this.Value = value return nil } func (storageItem *StorageItem) ToArray() []byte { - b := new(bytes.Buffer) - storageItem.Serialize(b) - return b.Bytes() + return common.SerializeToBytes(storageItem) } func GetValueFromRawStorageItem(raw []byte) ([]byte, error) { item := StorageItem{} - err := item.Deserialize(bytes.NewBuffer(raw)) + err := item.Deserialization(common.NewZeroCopySource(raw)) if err != nil { return nil, err } diff --git a/core/states/storage_item_test.go b/core/states/storage_item_test.go index c3e5acc897..5b98c3a455 100644 --- a/core/states/storage_item_test.go +++ b/core/states/storage_item_test.go @@ -18,7 +18,7 @@ package states import ( - "bytes" + "github.com/ontio/ontology/common" "testing" ) @@ -29,13 +29,12 @@ func TestStorageItem_Serialize_Deserialize(t *testing.T) { Value: []byte{1}, } - bf := new(bytes.Buffer) - if err := item.Serialize(bf); err != nil { - t.Fatalf("StorageItem serialize error: %v", err) - } + bf := common.NewZeroCopySink(nil) + item.Serialization(bf) var storage = new(StorageItem) - if err := storage.Deserialize(bf); err != nil { + source := common.NewZeroCopySource(bf.Bytes()) + if err := storage.Deserialization(source); err != nil { t.Fatalf("StorageItem deserialize error: %v", err) } } diff --git a/core/states/storage_key.go b/core/states/storage_key.go index 60c58c8dda..17223c2504 100644 --- a/core/states/storage_key.go +++ b/core/states/storage_key.go @@ -19,11 +19,8 @@ package states import ( - "bytes" - "io" - "github.com/ontio/ontology/common" - "github.com/ontio/ontology/common/serialization" + "io" ) type StorageKey struct { @@ -31,30 +28,26 @@ type StorageKey struct { Key []byte } -func (this *StorageKey) Serialize(w io.Writer) (int, error) { - if err := this.ContractAddress.Serialize(w); err != nil { - return 0, err - } - if err := serialization.WriteVarBytes(w, this.Key); err != nil { - return 0, err - } - return 0, nil +func (this *StorageKey) Serialization(sink *common.ZeroCopySink) { + this.ContractAddress.Serialization(sink) + sink.WriteVarBytes(this.Key) } -func (this *StorageKey) Deserialize(r io.Reader) error { - if err := this.ContractAddress.Deserialize(r); err != nil { +func (this *StorageKey) Deserialization(source *common.ZeroCopySource) error { + if err := this.ContractAddress.Deserialization(source); err != nil { return err } - key, err := serialization.ReadVarBytes(r) - if err != nil { - return err + key, _, irregular, eof := source.NextVarBytes() + if irregular { + return common.ErrIrregularData + } + if eof { + return io.ErrUnexpectedEOF } this.Key = key return nil } func (this *StorageKey) ToArray() []byte { - b := new(bytes.Buffer) - this.Serialize(b) - return b.Bytes() + return common.SerializeToBytes(this) } diff --git a/core/states/storage_key_test.go b/core/states/storage_key_test.go index edb7046b5e..2330b3fd84 100644 --- a/core/states/storage_key_test.go +++ b/core/states/storage_key_test.go @@ -20,7 +20,6 @@ package states import ( "testing" - "bytes" "crypto/rand" "github.com/ontio/ontology/common" @@ -36,15 +35,16 @@ func TestStorageKey_Deserialize_Serialize(t *testing.T) { Key: []byte{1, 2, 3}, } - buf := bytes.NewBuffer(nil) - storage.Serialize(buf) - bs := buf.Bytes() + sink := common.NewZeroCopySink(nil) + storage.Serialization(sink) + bs := sink.Bytes() var storage2 StorageKey - storage2.Deserialize(buf) + source := common.NewZeroCopySource(sink.Bytes()) + storage2.Deserialization(source) assert.Equal(t, storage, storage2) - buf = bytes.NewBuffer(bs[:len(bs)-1]) - err := storage2.Deserialize(buf) + buf := common.NewZeroCopySource(bs[:len(bs)-1]) + err := storage2.Deserialization(buf) assert.NotNil(t, err) } diff --git a/core/states/validator_state.go b/core/states/validator_state.go index 27a4b5f5aa..4df4a2adad 100644 --- a/core/states/validator_state.go +++ b/core/states/validator_state.go @@ -19,11 +19,10 @@ package states import ( - "io" - "github.com/ontio/ontology-crypto/keypair" - "github.com/ontio/ontology/common/serialization" + "github.com/ontio/ontology/common" "github.com/ontio/ontology/errors" + "io" ) type ValidatorState struct { @@ -31,23 +30,23 @@ type ValidatorState struct { PublicKey keypair.PublicKey } -func (this *ValidatorState) Serialize(w io.Writer) error { - this.StateBase.Serialize(w) +func (this *ValidatorState) Serialization(sink *common.ZeroCopySink) { + this.StateBase.Serialization(sink) buf := keypair.SerializePublicKey(this.PublicKey) - if err := serialization.WriteVarBytes(w, buf); err != nil { - return err - } - return nil + sink.WriteVarBytes(buf) } -func (this *ValidatorState) Deserialize(r io.Reader) error { - err := this.StateBase.Deserialize(r) +func (this *ValidatorState) Deserialization(source *common.ZeroCopySource) error { + err := this.StateBase.Deserialization(source) if err != nil { return errors.NewDetailErr(err, errors.ErrNoCode, "[ValidatorState], StateBase Deserialize failed.") } - buf, err := serialization.ReadVarBytes(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "[ValidatorState], PublicKey Deserialize failed.") + buf, _, irregular, eof := source.NextVarBytes() + if irregular { + return errors.NewDetailErr(common.ErrIrregularData, errors.ErrNoCode, "[ValidatorState], PublicKey Deserialize failed.") + } + if eof { + return errors.NewDetailErr(io.ErrUnexpectedEOF, errors.ErrNoCode, "[ValidatorState], PublicKey Deserialize failed.") } pk, err := keypair.DeserializePublicKey(buf) if err != nil { diff --git a/core/states/validator_state_test.go b/core/states/validator_state_test.go index 76b3e8c142..1b9ed838d8 100644 --- a/core/states/validator_state_test.go +++ b/core/states/validator_state_test.go @@ -20,9 +20,8 @@ package states import ( "testing" - "bytes" - "github.com/ontio/ontology-crypto/keypair" + "github.com/ontio/ontology/common" "github.com/stretchr/testify/assert" ) @@ -34,15 +33,16 @@ func TestValidatorState_Deserialize_Serialize(t *testing.T) { PublicKey: pubKey, } - buf := bytes.NewBuffer(nil) - vs.Serialize(buf) - bs := buf.Bytes() + sink := common.NewZeroCopySink(nil) + vs.Serialization(sink) + bs := sink.Bytes() var vs2 ValidatorState - vs2.Deserialize(buf) + source := common.NewZeroCopySource(sink.Bytes()) + vs2.Deserialization(source) assert.Equal(t, vs, vs2) - buf = bytes.NewBuffer(bs[:len(bs)-1]) - err := vs2.Deserialize(buf) + source = common.NewZeroCopySource(bs[:len(bs)-1]) + err := vs2.Deserialization(source) assert.NotNil(t, err) } diff --git a/core/states/vote_state.go b/core/states/vote_state.go index cd9e44e865..b97a4da87a 100644 --- a/core/states/vote_state.go +++ b/core/states/vote_state.go @@ -19,11 +19,9 @@ package states import ( - "io" - "github.com/ontio/ontology-crypto/keypair" "github.com/ontio/ontology/common" - "github.com/ontio/ontology/common/serialization" + "io" ) type VoteState struct { @@ -32,39 +30,32 @@ type VoteState struct { Count common.Fixed64 } -func (this *VoteState) Serialize(w io.Writer) error { - err := this.StateBase.Serialize(w) - if err != nil { - return err - } - err = serialization.WriteUint32(w, uint32(len(this.PublicKeys))) - if err != nil { - return err - } +func (this *VoteState) Serialization(sink *common.ZeroCopySink) { + this.StateBase.Serialization(sink) + sink.WriteUint32(uint32(len(this.PublicKeys))) for _, v := range this.PublicKeys { buf := keypair.SerializePublicKey(v) - err := serialization.WriteVarBytes(w, buf) - if err != nil { - return err - } + sink.WriteVarBytes(buf) } - - return serialization.WriteUint64(w, uint64(this.Count)) + sink.WriteUint64(uint64(this.Count)) } -func (this *VoteState) Deserialize(r io.Reader) error { - err := this.StateBase.Deserialize(r) +func (this *VoteState) Deserialization(source *common.ZeroCopySource) error { + err := this.StateBase.Deserialization(source) if err != nil { return err } - n, err := serialization.ReadUint32(r) - if err != nil { - return err + n, eof := source.NextUint32() + if eof { + return io.ErrUnexpectedEOF } for i := 0; i < int(n); i++ { - buf, err := serialization.ReadVarBytes(r) - if err != nil { - return err + buf, _, irregular, eof := source.NextVarBytes() + if irregular { + return common.ErrIrregularData + } + if eof { + return io.ErrUnexpectedEOF } pk, err := keypair.DeserializePublicKey(buf) if err != nil { @@ -72,9 +63,9 @@ func (this *VoteState) Deserialize(r io.Reader) error { } this.PublicKeys = append(this.PublicKeys, pk) } - c, err := serialization.ReadUint64(r) - if err != nil { - return err + c, eof := source.NextUint64() + if eof { + return io.ErrUnexpectedEOF } this.Count = common.Fixed64(int64(c)) return nil diff --git a/core/states/vote_state_test.go b/core/states/vote_state_test.go index ee9b1947d7..f1a90a0815 100644 --- a/core/states/vote_state_test.go +++ b/core/states/vote_state_test.go @@ -20,9 +20,8 @@ package states import ( "testing" - "bytes" - "github.com/ontio/ontology-crypto/keypair" + "github.com/ontio/ontology/common" "github.com/stretchr/testify/assert" ) @@ -36,15 +35,16 @@ func TestVoteState_Deserialize_Serialize(t *testing.T) { Count: 10, } - buf := bytes.NewBuffer(nil) - vs.Serialize(buf) - bs := buf.Bytes() + sink := common.NewZeroCopySink(nil) + vs.Serialization(sink) + bs := sink.Bytes() var vs2 VoteState - vs2.Deserialize(buf) + source := common.NewZeroCopySource(bs) + vs2.Deserialization(source) assert.Equal(t, vs, vs2) - buf = bytes.NewBuffer(bs[:len(bs)-1]) - err := vs2.Deserialize(buf) + source = common.NewZeroCopySource(bs[:len(bs)-1]) + err := vs2.Deserialization(source) assert.NotNil(t, err) } diff --git a/core/store/common/store.go b/core/store/common/store.go index 36e6c57f6d..9c4758db76 100644 --- a/core/store/common/store.go +++ b/core/store/common/store.go @@ -87,7 +87,7 @@ type EventStore interface { //SaveEventNotifyByTx save event notify gen by smart contract execution SaveEventNotifyByTx(txHash common.Uint256, notify *event.ExecuteNotify) error //Save transaction hashes which have event notify gen - SaveEventNotifyByBlock(height uint32, txHashs []common.Uint256) error + SaveEventNotifyByBlock(height uint32, txHashs []common.Uint256) //GetEventNotifyByTx return event notify by transaction hash GetEventNotifyByTx(txHash common.Uint256) (*event.ExecuteNotify, error) //Commit event notify to store diff --git a/core/store/ledgerstore/block_store.go b/core/store/ledgerstore/block_store.go index b1b01c59b6..1e6b8bab3b 100644 --- a/core/store/ledgerstore/block_store.go +++ b/core/store/ledgerstore/block_store.go @@ -80,11 +80,7 @@ func (this *BlockStore) SaveBlock(block *types.Block) error { return fmt.Errorf("SaveHeader error %s", err) } for _, tx := range block.Transactions { - err = this.SaveTransaction(tx, blockHeight) - if err != nil { - txHash := tx.Hash() - return fmt.Errorf("SaveTransaction block height %d tx %s err %s", blockHeight, txHash.ToHexString(), err) - } + this.SaveTransaction(tx, blockHeight) } return nil } @@ -281,9 +277,9 @@ func (this *BlockStore) GetCurrentBlock() (common.Uint256, uint32, error) { //SaveCurrentBlock persist the current block height and current block hash to store func (this *BlockStore) SaveCurrentBlock(height uint32, blockHash common.Uint256) error { key := this.getCurrentBlockKey() - value := bytes.NewBuffer(nil) - blockHash.Serialize(value) - serialization.WriteUint32(value, height) + value := common.NewZeroCopySink(nil) + value.WriteHash(blockHash) + value.WriteUint32(height) this.store.BatchPut(key, value.Bytes()) return nil } @@ -320,17 +316,16 @@ func (this *BlockStore) GetHeaderIndexList() (map[uint32]common.Uint256, error) } //SaveHeaderIndexList persist header index list to store -func (this *BlockStore) SaveHeaderIndexList(startIndex uint32, indexList []common.Uint256) error { +func (this *BlockStore) SaveHeaderIndexList(startIndex uint32, indexList []common.Uint256) { indexKey := this.getHeaderIndexListKey(startIndex) indexSize := uint32(len(indexList)) - value := bytes.NewBuffer(nil) - serialization.WriteUint32(value, indexSize) + value := common.NewZeroCopySink(nil) + value.WriteUint32(indexSize) for _, hash := range indexList { - hash.Serialize(value) + value.WriteHash(hash) } this.store.BatchPut(indexKey, value.Bytes()) - return nil } //GetBlockHash return block hash by block height @@ -354,24 +349,20 @@ func (this *BlockStore) SaveBlockHash(height uint32, blockHash common.Uint256) { } //SaveTransaction persist transaction to store -func (this *BlockStore) SaveTransaction(tx *types.Transaction, height uint32) error { +func (this *BlockStore) SaveTransaction(tx *types.Transaction, height uint32) { if this.enableCache { this.cache.AddTransaction(tx, height) } - return this.putTransaction(tx, height) + this.putTransaction(tx, height) } -func (this *BlockStore) putTransaction(tx *types.Transaction, height uint32) error { +func (this *BlockStore) putTransaction(tx *types.Transaction, height uint32) { txHash := tx.Hash() key := this.getTransactionKey(txHash) - value := bytes.NewBuffer(nil) - serialization.WriteUint32(value, height) - err := tx.Serialize(value) - if err != nil { - return err - } + value := common.NewZeroCopySink(nil) + value.WriteUint32(height) + tx.Serialization(value) this.store.BatchPut(key, value.Bytes()) - return nil } //GetTransaction return transaction by transaction hash @@ -510,10 +501,10 @@ func (this *BlockStore) getVersionKey() []byte { } func (this *BlockStore) getHeaderIndexListKey(startHeight uint32) []byte { - key := bytes.NewBuffer(nil) - key.WriteByte(byte(scom.IX_HEADER_HASH_LIST)) - serialization.WriteUint32(key, startHeight) - return key.Bytes() + sink := common.NewZeroCopySink(nil) + sink.WriteByte(byte(scom.IX_HEADER_HASH_LIST)) + sink.WriteUint32(startHeight) + return sink.Bytes() } func (this *BlockStore) getStartHeightByHeaderIndexKey(key []byte) (uint32, error) { diff --git a/core/store/ledgerstore/block_store_test.go b/core/store/ledgerstore/block_store_test.go index 5e74a04258..dadf142224 100644 --- a/core/store/ledgerstore/block_store_test.go +++ b/core/store/ledgerstore/block_store_test.go @@ -19,7 +19,6 @@ package ledgerstore import ( - "bytes" "crypto/sha256" "fmt" "github.com/ontio/ontology-crypto/keypair" @@ -147,11 +146,7 @@ func TestSaveTransaction(t *testing.T) { } testBlockStore.NewBatch() - err = testBlockStore.SaveTransaction(tx, blockHeight) - if err != nil { - t.Errorf("SaveTransaction error %s", err) - return - } + testBlockStore.SaveTransaction(tx, blockHeight) err = testBlockStore.CommitTo() if err != nil { t.Errorf("CommitTo error %s", err) @@ -199,11 +194,7 @@ func TestHeaderIndexList(t *testing.T) { indexMap[i] = hash indexList = append(indexList, hash) } - err := testBlockStore.SaveHeaderIndexList(startHeight, indexList) - if err != nil { - t.Errorf("SaveHeaderIndexList error %s", err) - return - } + testBlockStore.SaveHeaderIndexList(startHeight, indexList) startHeight = uint32(100) size = uint32(100) indexMap = make(map[uint32]common.Uint256, size) @@ -212,7 +203,7 @@ func TestHeaderIndexList(t *testing.T) { indexMap[i] = hash indexList = append(indexList, hash) } - err = testBlockStore.CommitTo() + err := testBlockStore.CommitTo() if err != nil { t.Errorf("CommitTo error %s", err) return @@ -374,20 +365,12 @@ func TestBlock(t *testing.T) { } func transferTx(from, to common.Address, amount uint64) (*types.Transaction, error) { - buf := bytes.NewBuffer(nil) var sts []ont.State sts = append(sts, ont.State{ From: from, To: to, Value: amount, }) - transfers := &ont.Transfers{ - States: sts, - } - err := transfers.Serialize(buf) - if err != nil { - return nil, fmt.Errorf("transfers.Serialize error %s", err) - } var cversion byte return invokeSmartContractTx(0, 30000, cversion, nutils.OntContractAddress, "transfer", []interface{}{sts}) } diff --git a/core/store/ledgerstore/event_store.go b/core/store/ledgerstore/event_store.go index acee4c886b..af18ba45c1 100644 --- a/core/store/ledgerstore/event_store.go +++ b/core/store/ledgerstore/event_store.go @@ -60,37 +60,25 @@ func (this *EventStore) SaveEventNotifyByTx(txHash common.Uint256, notify *event if err != nil { return fmt.Errorf("json.Marshal error %s", err) } - key := this.getEventNotifyByTxKey(txHash) + key := genEventNotifyByTxKey(txHash) this.store.BatchPut(key, result) return nil } //SaveEventNotifyByBlock persist transaction hash which have event notify to store -func (this *EventStore) SaveEventNotifyByBlock(height uint32, txHashs []common.Uint256) error { - key, err := this.getEventNotifyByBlockKey(height) - if err != nil { - return err - } - - values := bytes.NewBuffer(nil) - err = serialization.WriteUint32(values, uint32(len(txHashs))) - if err != nil { - return err - } +func (this *EventStore) SaveEventNotifyByBlock(height uint32, txHashs []common.Uint256) { + key := genEventNotifyByBlockKey(height) + values := common.NewZeroCopySink(nil) + values.WriteUint32(uint32(len(txHashs))) for _, txHash := range txHashs { - err = txHash.Serialize(values) - if err != nil { - return err - } + values.WriteHash(txHash) } this.store.BatchPut(key, values.Bytes()) - - return nil } //GetEventNotifyByTx return event notify by trasanction hash func (this *EventStore) GetEventNotifyByTx(txHash common.Uint256) (*event.ExecuteNotify, error) { - key := this.getEventNotifyByTxKey(txHash) + key := genEventNotifyByTxKey(txHash) data, err := this.store.Get(key) if err != nil { return nil, err @@ -104,10 +92,7 @@ func (this *EventStore) GetEventNotifyByTx(txHash common.Uint256) (*event.Execut //GetEventNotifyByBlock return all event notify of transaction in block func (this *EventStore) GetEventNotifyByBlock(height uint32) ([]*event.ExecuteNotify, error) { - key, err := this.getEventNotifyByBlockKey(height) - if err != nil { - return nil, err - } + key := genEventNotifyByBlockKey(height) data, err := this.store.Get(key) if err != nil { return nil, err @@ -159,14 +144,12 @@ func (this *EventStore) ClearAll() error { } //SaveCurrentBlock persist current block height and block hash to event store -func (this *EventStore) SaveCurrentBlock(height uint32, blockHash common.Uint256) error { +func (this *EventStore) SaveCurrentBlock(height uint32, blockHash common.Uint256) { key := this.getCurrentBlockKey() - value := bytes.NewBuffer(nil) - blockHash.Serialize(value) - serialization.WriteUint32(value, height) + value := common.NewZeroCopySink(nil) + value.WriteHash(blockHash) + value.WriteUint32(height) this.store.BatchPut(key, value.Bytes()) - - return nil } //GetCurrentBlock return current block hash, and block height @@ -193,14 +176,14 @@ func (this *EventStore) getCurrentBlockKey() []byte { return []byte{byte(scom.SYS_CURRENT_BLOCK)} } -func (this *EventStore) getEventNotifyByBlockKey(height uint32) ([]byte, error) { +func genEventNotifyByBlockKey(height uint32) []byte { key := make([]byte, 5, 5) key[0] = byte(scom.EVENT_NOTIFY) binary.LittleEndian.PutUint32(key[1:], height) - return key, nil + return key } -func (this *EventStore) getEventNotifyByTxKey(txHash common.Uint256) []byte { +func genEventNotifyByTxKey(txHash common.Uint256) []byte { data := txHash.ToArray() key := make([]byte, 1+len(data)) key[0] = byte(scom.EVENT_NOTIFY) diff --git a/core/store/ledgerstore/ledger_store.go b/core/store/ledgerstore/ledger_store.go index 0c8b85085f..4b236340db 100644 --- a/core/store/ledgerstore/ledger_store.go +++ b/core/store/ledgerstore/ledger_store.go @@ -307,10 +307,7 @@ func (this *LedgerStoreImp) recoverStore() error { if err != nil { return fmt.Errorf("save to state store height:%d error:%s", i, err) } - err = this.saveBlockToEventStore(block) - if err != nil { - return fmt.Errorf("save to event store height:%d error:%s", i, err) - } + this.saveBlockToEventStore(block) err = this.eventStore.CommitTo() if err != nil { return fmt.Errorf("eventStore.CommitTo height:%d error %s", i, err) @@ -738,7 +735,7 @@ func (this *LedgerStoreImp) saveBlockToStateStore(block *types.Block, result sto return nil } -func (this *LedgerStoreImp) saveBlockToEventStore(block *types.Block) error { +func (this *LedgerStoreImp) saveBlockToEventStore(block *types.Block) { blockHash := block.Hash() blockHeight := block.Header.Height txs := make([]common.Uint256, 0) @@ -747,16 +744,9 @@ func (this *LedgerStoreImp) saveBlockToEventStore(block *types.Block) error { txs = append(txs, txHash) } if len(txs) > 0 { - err := this.eventStore.SaveEventNotifyByBlock(block.Header.Height, txs) - if err != nil { - return fmt.Errorf("SaveEventNotifyByBlock error %s", err) - } - } - err := this.eventStore.SaveCurrentBlock(blockHeight, blockHash) - if err != nil { - return fmt.Errorf("SaveCurrentBlock error %s", err) + this.eventStore.SaveEventNotifyByBlock(block.Header.Height, txs) } - return nil + this.eventStore.SaveCurrentBlock(blockHeight, blockHash) } func (this *LedgerStoreImp) tryGetSavingBlockLock() (hasLocked bool) { @@ -802,10 +792,7 @@ func (this *LedgerStoreImp) submitBlock(block *types.Block, result store.Execute if err != nil { return fmt.Errorf("save to state store height:%d error:%s", blockHeight, err) } - err = this.saveBlockToEventStore(block) - if err != nil { - return fmt.Errorf("save to event store height:%d error:%s", blockHeight, err) - } + this.saveBlockToEventStore(block) err = this.blockStore.CommitTo() if err != nil { return fmt.Errorf("blockStore.CommitTo height:%d error %s", blockHeight, err) @@ -851,7 +838,8 @@ func (this *LedgerStoreImp) saveBlock(block *types.Block, stateMerkleRoot common return err } - if result.MerkleRoot != stateMerkleRoot { + //empty block does not check stateMerkleRoot + if len(block.Transactions) != 0 && result.MerkleRoot != stateMerkleRoot { return fmt.Errorf("state merkle root mismatch. expected: %s, got: %s", result.MerkleRoot.ToHexString(), stateMerkleRoot.ToHexString()) } @@ -901,10 +889,7 @@ func (this *LedgerStoreImp) saveHeaderIndexList() error { } this.lock.RUnlock() - err := this.blockStore.SaveHeaderIndexList(storeCount, headerList) - if err != nil { - return fmt.Errorf("SaveHeaderIndexList start %d error %s", storeCount, err) - } + this.blockStore.SaveHeaderIndexList(storeCount, headerList) this.lock.Lock() this.storedIndexCount += HEADER_INDEX_BATCH_SIZE diff --git a/core/store/ledgerstore/state_store.go b/core/store/ledgerstore/state_store.go index c4eb4c5cf7..d03ae4bf29 100644 --- a/core/store/ledgerstore/state_store.go +++ b/core/store/ledgerstore/state_store.go @@ -271,9 +271,9 @@ func (self *StateStore) GetBookkeeperState() (*states.BookkeeperState, error) { if err != nil { return nil, err } - reader := bytes.NewReader(value) + reader := common.NewZeroCopySource(value) bookkeeperState := new(states.BookkeeperState) - err = bookkeeperState.Deserialize(reader) + err = bookkeeperState.Deserialization(reader) if err != nil { return nil, err } @@ -286,11 +286,8 @@ func (self *StateStore) SaveBookkeeperState(bookkeeperState *states.BookkeeperSt if err != nil { return err } - value := bytes.NewBuffer(nil) - err = bookkeeperState.Serialize(value) - if err != nil { - return err - } + value := common.NewZeroCopySink(nil) + bookkeeperState.Serialization(value) return self.store.Put(key, value.Bytes()) } @@ -306,9 +303,9 @@ func (self *StateStore) GetStorageState(key *states.StorageKey) (*states.Storage if err != nil { return nil, err } - reader := bytes.NewReader(data) + reader := common.NewZeroCopySource(data) storageState := new(states.StorageItem) - err = storageState.Deserialize(reader) + err = storageState.Deserialization(reader) if err != nil { return nil, err } @@ -424,8 +421,8 @@ func (self *StateStore) CheckStorage() error { val, err := db.Get(flag) if err == nil { item := &states.StorageItem{} - buf := bytes.NewBuffer(val) - err := item.Deserialize(buf) + source := common.NewZeroCopySource(val) + err := item.Deserialization(source) if err == nil && item.Value[0] == ontid.FLAG_VERSION { return nil } else if err == nil { @@ -452,8 +449,8 @@ func (self *StateStore) CheckStorage() error { tag := states.StorageItem{} tag.Value = []byte{ontid.FLAG_VERSION} - buf := bytes.NewBuffer(nil) - tag.Serialize(buf) + buf := common.NewZeroCopySink(nil) + tag.Serialization(buf) db.BatchPut(flag, buf.Bytes()) err = db.BatchCommit() diff --git a/core/store/ledgerstore/tx_handler.go b/core/store/ledgerstore/tx_handler.go index c2ca02bf67..6a39822981 100644 --- a/core/store/ledgerstore/tx_handler.go +++ b/core/store/ledgerstore/tx_handler.go @@ -27,7 +27,6 @@ import ( "github.com/ontio/ontology/common" sysconfig "github.com/ontio/ontology/common/config" "github.com/ontio/ontology/common/log" - "github.com/ontio/ontology/common/serialization" "github.com/ontio/ontology/core/payload" "github.com/ontio/ontology/core/store" scommon "github.com/ontio/ontology/core/store/common" @@ -266,10 +265,8 @@ func SaveNotify(eventStore scommon.EventStore, txHash common.Uint256, notify *ev } func genNativeTransferCode(from, to common.Address, value uint64) []byte { - transfer := ont.Transfers{States: []ont.State{{From: from, To: to, Value: value}}} - tr := new(bytes.Buffer) - transfer.Serialize(tr) - return tr.Bytes() + transfer := &ont.Transfers{States: []ont.State{{From: from, To: to, Value: value}}} + return common.SerializeToBytes(transfer) } // check whether payer ong balance sufficient @@ -305,14 +302,10 @@ func chargeCostGas(payer common.Address, gas uint64, config *smartcontract.Confi } func refreshGlobalParam(config *smartcontract.Config, cache *storage.CacheDB, store store.LedgerStore) error { - bf := new(bytes.Buffer) - if err := utils.WriteVarUint(bf, uint64(len(neovm.GAS_TABLE_KEYS))); err != nil { - return fmt.Errorf("write gas_table_keys length error:%s", err) - } + sink := common.NewZeroCopySink(nil) + utils.EncodeVarUint(sink, uint64(len(neovm.GAS_TABLE_KEYS))) for _, value := range neovm.GAS_TABLE_KEYS { - if err := serialization.WriteString(bf, value); err != nil { - return fmt.Errorf("serialize param name error:%s", value) - } + sink.WriteString(value) } sc := smartcontract.SmartContract{ @@ -323,12 +316,12 @@ func refreshGlobalParam(config *smartcontract.Config, cache *storage.CacheDB, st } service, _ := sc.NewNativeService() - result, err := service.NativeCall(utils.ParamContractAddress, "getGlobalParam", bf.Bytes()) + result, err := service.NativeCall(utils.ParamContractAddress, "getGlobalParam", sink.Bytes()) if err != nil { return err } params := new(global_params.Params) - if err := params.Deserialize(bytes.NewBuffer(result.([]byte))); err != nil { + if err := params.Deserialization(common.NewZeroCopySource(result.([]byte))); err != nil { return fmt.Errorf("deserialize global params error:%s", err) } neovm.GAS_TABLE.Range(func(key, value interface{}) bool { @@ -347,10 +340,8 @@ func refreshGlobalParam(config *smartcontract.Config, cache *storage.CacheDB, st } func getBalanceFromNative(config *smartcontract.Config, cache *storage.CacheDB, store store.LedgerStore, address common.Address) (uint64, error) { - bf := new(bytes.Buffer) - if err := utils.WriteAddress(bf, address); err != nil { - return 0, err - } + bf := common.NewZeroCopySink(nil) + utils.EncodeAddress(bf, address) sc := smartcontract.SmartContract{ Config: config, CacheDB: cache, diff --git a/core/types/transaction.go b/core/types/transaction.go index c62562e8c8..3d44a2daba 100644 --- a/core/types/transaction.go +++ b/core/types/transaction.go @@ -19,7 +19,6 @@ package types import ( - "bytes" "crypto/sha256" "errors" "fmt" @@ -28,7 +27,6 @@ import ( "github.com/ontio/ontology-crypto/keypair" "github.com/ontio/ontology/common" "github.com/ontio/ontology/common/constants" - "github.com/ontio/ontology/common/serialization" "github.com/ontio/ontology/core/payload" "github.com/ontio/ontology/core/program" ) @@ -204,34 +202,6 @@ func (self *RawSig) Serialization(sink *common.ZeroCopySink) error { return nil } -func (self *RawSig) Serialize(w io.Writer) error { - err := serialization.WriteVarBytes(w, self.Invoke) - if err != nil { - return err - } - err = serialization.WriteVarBytes(w, self.Verify) - if err != nil { - return err - } - - return nil -} - -func (self *RawSig) Deserialize(r io.Reader) error { - invoke, err := serialization.ReadVarBytes(r) - if err != nil { - return err - } - verify, err := serialization.ReadVarBytes(r) - if err != nil { - return err - } - self.Invoke = invoke - self.Verify = verify - - return nil -} - func (self *RawSig) Deserialization(source *common.ZeroCopySource) error { var eof, irregular bool self.Invoke, _, irregular, eof = source.NextVarBytes() @@ -308,57 +278,6 @@ func (self *Sig) Serialization(sink *common.ZeroCopySink) error { return nil } -func (self *Sig) Serialize(w io.Writer) error { - invocationScript := program.ProgramFromParams(self.SigData) - var verificationScript []byte - if len(self.PubKeys) == 0 { - return errors.New("no pubkeys in sig") - } else if len(self.PubKeys) == 1 { - verificationScript = program.ProgramFromPubKey(self.PubKeys[0]) - } else { - script, err := program.ProgramFromMultiPubKey(self.PubKeys, int(self.M)) - if err != nil { - return err - } - verificationScript = script - } - err := serialization.WriteVarBytes(w, invocationScript) - if err != nil { - return err - } - err = serialization.WriteVarBytes(w, verificationScript) - if err != nil { - return err - } - - return nil -} - -func (self *Sig) Deserialize(r io.Reader) error { - invocationScript, err := serialization.ReadVarBytes(r) - if err != nil { - return err - } - verificationScript, err := serialization.ReadVarBytes(r) - if err != nil { - return err - } - sigs, err := program.GetParamInfo(invocationScript) - if err != nil { - return err - } - info, err := program.GetProgramInfo(verificationScript) - if err != nil { - return err - } - - self.SigData = sigs - self.M = info.M - self.PubKeys = info.PubKeys - - return nil -} - func (self *Transaction) GetSignatureAddresses() ([]common.Address, error) { if len(self.SignedAddr) == 0 { addrs := make([]common.Address, 0, len(self.Sigs)) @@ -397,19 +316,8 @@ func (tx *Transaction) Serialization(sink *common.ZeroCopySink) { sink.WriteBytes(tx.Raw) } -// Serialize the Transaction -func (tx *Transaction) Serialize(w io.Writer) error { - if tx.nonDirectConstracted == false || len(tx.Raw) == 0 { - panic("wrong constructed transaction") - } - _, err := w.Write(tx.Raw) - return err -} - func (tx *Transaction) ToArray() []byte { - b := new(bytes.Buffer) - tx.Serialize(b) - return b.Bytes() + return common.SerializeToBytes(tx) } func (tx *Transaction) Hash() common.Uint256 { diff --git a/core/types/transaction_attribute.go b/core/types/transaction_attribute.go index d33e1d06d3..fdf2e2b62a 100644 --- a/core/types/transaction_attribute.go +++ b/core/types/transaction_attribute.go @@ -19,12 +19,11 @@ package types import ( - "bytes" "errors" "fmt" "io" - "github.com/ontio/ontology/common/serialization" + "github.com/ontio/ontology/common" ) type TransactionAttributeUsage byte @@ -60,38 +59,38 @@ func (u *TxAttribute) GetSize() uint32 { return 0 } -func (tx *TxAttribute) Serialize(w io.Writer) error { - if err := serialization.WriteUint8(w, byte(tx.Usage)); err != nil { - return fmt.Errorf("Transaction attribute Usage serialization error: %s", err) - } +func (tx *TxAttribute) Serialization(sink *common.ZeroCopySink) error { if !IsValidAttributeType(tx.Usage) { return errors.New("Unsupported attribute Description.") } - if err := serialization.WriteVarBytes(w, tx.Data); err != nil { - return fmt.Errorf("Transaction attribute Data serialization error: %s", err) - } + sink.WriteUint8(byte(tx.Usage)) + sink.WriteVarBytes(tx.Data) return nil } -func (tx *TxAttribute) Deserialize(r io.Reader) error { - val, err := serialization.ReadBytes(r, 1) - if err != nil { - return fmt.Errorf("Transaction attribute Usage deserialization error: %s", err) +func (tx *TxAttribute) Deserialization(source *common.ZeroCopySource) error { + val, eof := source.NextBytes(1) + if eof { + return fmt.Errorf("Transaction attribute Usage deserialization error: %s", io.ErrUnexpectedEOF) } tx.Usage = TransactionAttributeUsage(val[0]) if !IsValidAttributeType(tx.Usage) { return errors.New("[TxAttribute] Unsupported attribute Description.") } - tx.Data, err = serialization.ReadVarBytes(r) - if err != nil { - return fmt.Errorf("Transaction attribute Data deserialization error: %s", err) + var irregular bool + tx.Data, _, irregular, eof = source.NextVarBytes() + if irregular { + return fmt.Errorf("Transaction attribute Data deserialization error: %s", common.ErrIrregularData) + } + if eof { + return fmt.Errorf("Transaction attribute Data deserialization error: %s", io.ErrUnexpectedEOF) } return nil } func (tx *TxAttribute) ToArray() []byte { - bf := new(bytes.Buffer) - tx.Serialize(bf) - return bf.Bytes() + sink := common.NewZeroCopySink(nil) + tx.Serialization(sink) + return sink.Bytes() } diff --git a/core/utils/transaction_builder.go b/core/utils/transaction_builder.go index efe2da24d7..8a723fb2ff 100644 --- a/core/utils/transaction_builder.go +++ b/core/utils/transaction_builder.go @@ -193,9 +193,8 @@ func BuildWasmVMInvokeCode(contractAddress common.Address, params []interface{}) return nil, fmt.Errorf("build wasm contract param failed:%s", err) } contract.Args = argbytes - sink := common.NewZeroCopySink(nil) - contract.Serialization(sink) - return sink.Bytes(), nil + + return common.SerializeToBytes(contract), nil } //build param bytes for wasm contract diff --git a/http/base/common/common.go b/http/base/common/common.go index 856bfd0df5..6e573e100d 100644 --- a/http/base/common/common.go +++ b/http/base/common/common.go @@ -27,7 +27,6 @@ import ( "github.com/ontio/ontology/common" "github.com/ontio/ontology/common/constants" "github.com/ontio/ontology/common/log" - "github.com/ontio/ontology/common/serialization" "github.com/ontio/ontology/core/ledger" "github.com/ontio/ontology/core/payload" "github.com/ontio/ontology/core/types" @@ -39,6 +38,7 @@ import ( "github.com/ontio/ontology/smartcontract/service/native/utils" cstate "github.com/ontio/ontology/smartcontract/states" "github.com/ontio/ontology/vm/neovm" + "io" "strings" "time" ) @@ -299,9 +299,10 @@ func GetGrantOng(addr common.Address) (string, error) { if err != nil { value = []byte{0, 0, 0, 0} } - v, err := serialization.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return fmt.Sprintf("%v", 0), err + source := common.NewZeroCopySource(value) + v, eof := source.NextUint32() + if eof { + return fmt.Sprintf("%v", 0), io.ErrUnexpectedEOF } ont, err := GetContractBalance(0, utils.OntContractAddress, addr) if err != nil { diff --git a/http/base/rest/interfaces.go b/http/base/rest/interfaces.go index f5f21ce36a..fa13157c9e 100644 --- a/http/base/rest/interfaces.go +++ b/http/base/rest/interfaces.go @@ -19,7 +19,6 @@ package rest import ( - "bytes" "github.com/ontio/ontology/common" "github.com/ontio/ontology/common/config" "github.com/ontio/ontology/common/log" @@ -231,9 +230,7 @@ func GetTransactionByHash(cmd map[string]interface{}) map[string]interface{} { return ResponsePack(berr.UNKNOWN_TRANSACTION) } if raw, ok := cmd["Raw"].(string); ok && raw == "1" { - w := bytes.NewBuffer(nil) - tx.Serialize(w) - resp["Result"] = common.ToHexString(w.Bytes()) + resp["Result"] = common.ToHexString(common.SerializeToBytes(tx)) return resp } tran := bcomn.TransArryByteToHexString(tx) diff --git a/http/base/rpc/interfaces.go b/http/base/rpc/interfaces.go index e6c583d548..28337994ad 100644 --- a/http/base/rpc/interfaces.go +++ b/http/base/rpc/interfaces.go @@ -19,7 +19,6 @@ package rpc import ( - "bytes" "encoding/hex" "github.com/ontio/ontology/common" "github.com/ontio/ontology/common/config" @@ -209,9 +208,8 @@ func GetRawTransaction(params []interface{}) map[string]interface{} { return responsePack(berr.INVALID_PARAMS, "") } } - w := bytes.NewBuffer(nil) - tx.Serialize(w) - return responseSuccess(common.ToHexString(w.Bytes())) + + return responseSuccess(common.ToHexString(common.SerializeToBytes(tx))) } //get storage from contract diff --git a/p2pserver/block_sync.go b/p2pserver/block_sync.go index 2f9c4635e9..d6e683e3c9 100644 --- a/p2pserver/block_sync.go +++ b/p2pserver/block_sync.go @@ -224,7 +224,7 @@ type BlockInfo struct { type BlockSyncMgr struct { flightBlocks map[common.Uint256][]*SyncFlightInfo //Map BlockHash => []SyncFlightInfo, using for manager all of those block flights flightHeaders map[uint32]*SyncFlightInfo //Map HeaderHeight => SyncFlightInfo, using for manager all of those header flights - blocksCache map[uint32]*BlockInfo //Map BlockHash => BlockInfo, using for cache the blocks receive from net, and waiting for commit to ledger + blocksCache *BlockCache //Map BlockHash => BlockInfo, using for cache the blocks receive from net, and waiting for commit to ledger server *P2PServer //Pointer to the local node syncBlockLock bool //Help to avoid send block sync request duplicate syncHeaderLock bool //Help to avoid send header sync request duplicate @@ -240,7 +240,7 @@ func NewBlockSyncMgr(server *P2PServer) *BlockSyncMgr { return &BlockSyncMgr{ flightBlocks: make(map[common.Uint256][]*SyncFlightInfo, 0), flightHeaders: make(map[uint32]*SyncFlightInfo, 0), - blocksCache: make(map[uint32]*BlockInfo, 0), + blocksCache: NewBlockCache(), server: server, ledger: server.ledger, exitCh: make(chan interface{}, 1), @@ -248,6 +248,80 @@ func NewBlockSyncMgr(server *P2PServer) *BlockSyncMgr { } } +type BlockCache struct { + emptyBlockAmount int + blocksCache map[uint32]*BlockInfo //Map BlockHeight => BlockInfo, using for cache the blocks receive from net, and waiting for commit to ledger +} + +func NewBlockCache() *BlockCache { + return &BlockCache{ + emptyBlockAmount: 0, + blocksCache: make(map[uint32]*BlockInfo, 0), + } +} + +func (this *BlockCache) addBlock(nodeID uint64, block *types.Block, + merkleRoot common.Uint256) bool { + this.delBlockLocked(block.Header.Height) + blockInfo := &BlockInfo{ + nodeID: nodeID, + block: block, + merkleRoot: merkleRoot, + } + this.blocksCache[block.Header.Height] = blockInfo + if block.Header.TransactionsRoot == common.UINT256_EMPTY { + this.emptyBlockAmount += 1 + } + return true +} + +func (this *BlockSyncMgr) clearBlocks(curBlockHeight uint32) { + this.lock.Lock() + this.blocksCache.clearBlocks(curBlockHeight) + this.lock.Unlock() +} + +func (this *BlockCache) clearBlocks(curBlockHeight uint32) { + for height := range this.blocksCache { + if height < curBlockHeight { + this.delBlockLocked(height) + } + } +} + +func (this *BlockCache) getBlock(blockHeight uint32) (uint64, *types.Block, + common.Uint256) { + blockInfo, ok := this.blocksCache[blockHeight] + if !ok { + return 0, nil, common.UINT256_EMPTY + } + return blockInfo.nodeID, blockInfo.block, blockInfo.merkleRoot +} + +func (this *BlockCache) delBlockLocked(blockHeight uint32) { + blockInfo, ok := this.blocksCache[blockHeight] + if ok { + if blockInfo.block.Header.TransactionsRoot == common.UINT256_EMPTY { + this.emptyBlockAmount -= 1 + } + } + delete(this.blocksCache, blockHeight) +} + +func (this *BlockCache) isInBlockCache(blockHeight uint32) bool { + _, ok := this.blocksCache[blockHeight] + return ok +} + +func (this *BlockCache) getNonEmptyBlockCount() int { + return len(this.blocksCache) - this.emptyBlockAmount +} +func (this *BlockSyncMgr) getNonEmptyBlockCount() int { + this.lock.RLock() + defer this.lock.RUnlock() + return this.blocksCache.getNonEmptyBlockCount() +} + //Start to sync func (this *BlockSyncMgr) Start() { go this.sync() @@ -400,7 +474,7 @@ func (this *BlockSyncMgr) syncBlock() { if count > availCount { count = availCount } - cacheCap := SYNC_MAX_BLOCK_CACHE_SIZE - this.getBlockCacheSize() + cacheCap := SYNC_MAX_BLOCK_CACHE_SIZE - this.getNonEmptyBlockCount() if count > cacheCap { count = cacheCap } @@ -479,6 +553,32 @@ func (this *BlockSyncMgr) OnHeaderReceive(fromID uint64, headers []*types.Header log.Warnf("[p2p]OnHeaderReceive AddHeaders error:%s", err) return } + sort.Slice(headers, func(i, j int) bool { + return headers[i].Height < headers[j].Height + }) + curHeaderHeight = this.ledger.GetCurrentHeaderHeight() + curBlockHeight := this.ledger.GetCurrentBlockHeight() + for _, header := range headers { + //handle empty block + if header.TransactionsRoot == common.UINT256_EMPTY { + log.Trace("[p2p]OnHeaderReceive empty block Height:%d", header.Height) + height := header.Height + blockHash := header.Hash() + this.delFlightBlock(blockHash) + nextHeader := curHeaderHeight + 1 + if height > nextHeader { + break + } + if height <= curBlockHeight { + continue + } + block := &types.Block{ + Header: header, + } + this.addBlockCache(fromID, block, common.UINT256_EMPTY) + } + } + go this.saveBlock() this.syncHeader() } @@ -487,7 +587,7 @@ func (this *BlockSyncMgr) OnBlockReceive(fromID uint64, blockSize uint32, block merkleRoot common.Uint256) { height := block.Header.Height blockHash := block.Hash() - log.Trace("[p2p]OnBlockReceive Height:%d", height) + log.Tracef("[p2p]OnBlockReceive Height:%d", height) flightInfo := this.getFlightBlock(blockHash, fromID) if flightInfo != nil { t := (time.Now().UnixNano() - flightInfo.GetStartTime().UnixNano()) / int64(time.Millisecond) @@ -574,30 +674,20 @@ func (this *BlockSyncMgr) addBlockCache(nodeID uint64, block *types.Block, merkleRoot common.Uint256) bool { this.lock.Lock() defer this.lock.Unlock() - blockInfo := &BlockInfo{ - nodeID: nodeID, - block: block, - merkleRoot: merkleRoot, - } - this.blocksCache[block.Header.Height] = blockInfo - return true + return this.blocksCache.addBlock(nodeID, block, merkleRoot) } func (this *BlockSyncMgr) getBlockCache(blockHeight uint32) (uint64, *types.Block, common.Uint256) { this.lock.RLock() defer this.lock.RUnlock() - blockInfo, ok := this.blocksCache[blockHeight] - if !ok { - return 0, nil, common.UINT256_EMPTY - } - return blockInfo.nodeID, blockInfo.block, blockInfo.merkleRoot + return this.blocksCache.getBlock(blockHeight) } func (this *BlockSyncMgr) delBlockCache(blockHeight uint32) { this.lock.Lock() defer this.lock.Unlock() - delete(this.blocksCache, blockHeight) + this.blocksCache.delBlockLocked(blockHeight) } func (this *BlockSyncMgr) tryGetSaveBlockLock() bool { @@ -623,13 +713,7 @@ func (this *BlockSyncMgr) saveBlock() { defer this.releaseSaveBlockLock() curBlockHeight := this.ledger.GetCurrentBlockHeight() nextBlockHeight := curBlockHeight + 1 - this.lock.Lock() - for height := range this.blocksCache { - if height <= curBlockHeight { - delete(this.blocksCache, height) - } - } - this.lock.Unlock() + this.clearBlocks(curBlockHeight) for { fromID, nextBlock, merkleRoot := this.getBlockCache(nextBlockHeight) if nextBlock == nil { @@ -667,14 +751,7 @@ func (this *BlockSyncMgr) saveBlock() { func (this *BlockSyncMgr) isInBlockCache(blockHeight uint32) bool { this.lock.RLock() defer this.lock.RUnlock() - _, ok := this.blocksCache[blockHeight] - return ok -} - -func (this *BlockSyncMgr) getBlockCacheSize() int { - this.lock.RLock() - defer this.lock.RUnlock() - return len(this.blocksCache) + return this.blocksCache.isInBlockCache(blockHeight) } func (this *BlockSyncMgr) addFlightHeader(nodeId uint64, height uint32) { diff --git a/p2pserver/common/set/string_set.go b/p2pserver/common/set/string_set.go new file mode 100644 index 0000000000..4ea426ebfd --- /dev/null +++ b/p2pserver/common/set/string_set.go @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2018 The ontology Authors + * This file is part of The ontology library. + * + * The ontology is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * The ontology is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with The ontology. If not, see . + */ + +package set + +import ( + "reflect" +) + +type empty struct{} + +// StringSet is a set of strings, implemented via map[string]struct{} for minimal memory consumption. +type StringSet map[string]empty + +// NewStringSet creates a StringSet from a list of values. +func NewStringSet(items ...string) StringSet { + ss := StringSet{} + ss.Insert(items...) + return ss +} + +// StringKeySet creates a StringSet from a keys of a map[string](? extends interface{}). +// If the value passed in is not actually a map, this will panic. +func StringKeySet(theMap interface{}) StringSet { + v := reflect.ValueOf(theMap) + ret := StringSet{} + + for _, keyValue := range v.MapKeys() { + ret.Insert(keyValue.Interface().(string)) + } + return ret +} + +// Insert adds items to the set. +func (s StringSet) Insert(items ...string) StringSet { + for _, item := range items { + s[item] = empty{} + } + return s +} + +// Delete removes all items from the set. +func (s StringSet) Delete(items ...string) StringSet { + for _, item := range items { + delete(s, item) + } + return s +} + +// Has returns true if and only if item is contained in the set. +func (s StringSet) Has(item string) bool { + _, contained := s[item] + return contained +} + +// Len returns the size of the set. +func (s StringSet) Len() int { + return len(s) +} diff --git a/p2pserver/common/set/string_set_test.go b/p2pserver/common/set/string_set_test.go new file mode 100644 index 0000000000..96b2781892 --- /dev/null +++ b/p2pserver/common/set/string_set_test.go @@ -0,0 +1,78 @@ +/* + * Copyright (C) 2018 The ontology Authors + * This file is part of The ontology library. + * + * The ontology is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * The ontology is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with The ontology. If not, see . + */ + +package set + +import ( + "testing" +) + +func TestStringSet(t *testing.T) { + s := StringSet{} + if s.Len() != 0 { + t.Errorf("Expected len=0: %d", len(s)) + } + s.Insert("a", "b") + if s.Len() != 2 { + t.Errorf("Expected len=2: %d", len(s)) + } + s.Insert("c") + if s.Has("d") { + t.Errorf("Unexpected contents: %#v", s) + } + if !s.Has("a") { + t.Errorf("Missing contents: %#v", s) + } + s.Delete("a") + if s.Has("a") { + t.Errorf("Unexpected contents: %#v", s) + } +} + +func TestStringSetDeleteMultiples(t *testing.T) { + s := StringSet{} + s.Insert("a", "b", "c") + if len(s) != 3 { + t.Errorf("Expected len=3: %d", len(s)) + } + + s.Delete("a", "c") + if len(s) != 1 { + t.Errorf("Expected len=1: %d", len(s)) + } + if s.Has("a") { + t.Errorf("Unexpected contents: %#v", s) + } + if s.Has("c") { + t.Errorf("Unexpected contents: %#v", s) + } + if !s.Has("b") { + t.Errorf("Missing contents: %#v", s) + } + +} + +func TestNewStringSet(t *testing.T) { + s := NewStringSet("a", "b", "c") + if len(s) != 3 { + t.Errorf("Expected len=3: %d", len(s)) + } + if !s.Has("a") || !s.Has("b") || !s.Has("c") { + t.Errorf("Unexpected contents: %#v", s) + } +} diff --git a/p2pserver/message/types/consensus_payload.go b/p2pserver/message/types/consensus_payload.go index 15f35c2a18..4e89b4aae9 100644 --- a/p2pserver/message/types/consensus_payload.go +++ b/p2pserver/message/types/consensus_payload.go @@ -19,14 +19,11 @@ package types import ( - "bytes" "fmt" "io" "github.com/ontio/ontology-crypto/keypair" "github.com/ontio/ontology/common" - "github.com/ontio/ontology/common/log" - "github.com/ontio/ontology/common/serialization" "github.com/ontio/ontology/core/signature" "github.com/ontio/ontology/errors" ) @@ -51,27 +48,19 @@ func (this *ConsensusPayload) Hash() common.Uint256 { //Check whether header is correct func (this *ConsensusPayload) Verify() error { - buf := new(bytes.Buffer) - err := this.SerializeUnsigned(buf) - if err != nil { - return err - } - err = signature.Verify(this.Owner, buf.Bytes(), this.Signature) + sink := common.NewZeroCopySink(nil) + this.SerializationUnsigned(sink) + + err := signature.Verify(this.Owner, sink.Bytes(), this.Signature) if err != nil { - return errors.NewDetailErr(err, errors.ErrNetVerifyFail, fmt.Sprintf("signature verify error. buf:%v", buf)) + return errors.NewDetailErr(err, errors.ErrNetVerifyFail, fmt.Sprintf("signature verify error. buf:%v", sink.Bytes())) } return nil } //serialize the consensus payload func (this *ConsensusPayload) ToArray() []byte { - b := new(bytes.Buffer) - err := this.Serialize(b) - if err != nil { - log.Errorf("consensus payload serialize error in ToArray(). payload:%v", this) - return nil - } - return b.Bytes() + return common.SerializeToBytes(this) } //return inventory type @@ -92,35 +81,15 @@ func (this *ConsensusPayload) Type() common.InventoryType { } func (this *ConsensusPayload) Serialization(sink *common.ZeroCopySink) { - this.serializationUnsigned(sink) + this.SerializationUnsigned(sink) buf := keypair.SerializePublicKey(this.Owner) sink.WriteVarBytes(buf) sink.WriteVarBytes(this.Signature) } -//Serialize message payload -func (this *ConsensusPayload) Serialize(w io.Writer) error { - err := this.SerializeUnsigned(w) - if err != nil { - return err - } - buf := keypair.SerializePublicKey(this.Owner) - err = serialization.WriteVarBytes(w, buf) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNetPackFail, fmt.Sprintf("write publickey error. publickey buf:%v", buf)) - } - - err = serialization.WriteVarBytes(w, this.Signature) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNetPackFail, fmt.Sprintf("write Signature error. Signature:%v", this.Signature)) - } - - return nil -} - //Deserialize message payload func (this *ConsensusPayload) Deserialization(source *common.ZeroCopySource) error { - err := this.deserializationUnsigned(source) + err := this.DeserializationUnsigned(source) if err != nil { return err } @@ -148,33 +117,7 @@ func (this *ConsensusPayload) Deserialization(source *common.ZeroCopySource) err return nil } -//Deserialize message payload -func (this *ConsensusPayload) Deserialize(r io.Reader) error { - err := this.DeserializeUnsigned(r) - if err != nil { - return err - } - buf, err := serialization.ReadVarBytes(r) - if err != nil { - - return errors.NewDetailErr(err, errors.ErrNetUnPackFail, "read buf error") - } - this.Owner, err = keypair.DeserializePublicKey(buf) - if err != nil { - - return errors.NewDetailErr(err, errors.ErrNetUnPackFail, "deserialize publickey error") - } - - this.Signature, err = serialization.ReadVarBytes(r) - if err != nil { - - return errors.NewDetailErr(err, errors.ErrNetUnPackFail, "read Signature error") - } - - return err -} - -func (this *ConsensusPayload) serializationUnsigned(sink *common.ZeroCopySink) { +func (this *ConsensusPayload) SerializationUnsigned(sink *common.ZeroCopySink) { sink.WriteUint32(this.Version) sink.WriteHash(this.PrevHash) sink.WriteUint32(this.Height) @@ -183,42 +126,7 @@ func (this *ConsensusPayload) serializationUnsigned(sink *common.ZeroCopySink) { sink.WriteVarBytes(this.Data) } -//Serialize message payload -func (this *ConsensusPayload) SerializeUnsigned(w io.Writer) error { - err := serialization.WriteUint32(w, this.Version) - if err != nil { - - return errors.NewDetailErr(err, errors.ErrNetPackFail, fmt.Sprintf("write error. version:%v", this.Version)) - } - err = this.PrevHash.Serialize(w) - if err != nil { - - return errors.NewDetailErr(err, errors.ErrNetPackFail, fmt.Sprintf("serialize error. PrevHash:%v", this.PrevHash)) - } - err = serialization.WriteUint32(w, this.Height) - if err != nil { - - return errors.NewDetailErr(err, errors.ErrNetPackFail, fmt.Sprintf("write error. Height:%v", this.Height)) - } - err = serialization.WriteUint16(w, this.BookkeeperIndex) - if err != nil { - - return errors.NewDetailErr(err, errors.ErrNetPackFail, fmt.Sprintf("write error. BookkeeperIndex:%v", this.BookkeeperIndex)) - } - err = serialization.WriteUint32(w, this.Timestamp) - if err != nil { - - return errors.NewDetailErr(err, errors.ErrNetPackFail, fmt.Sprintf("write error. Timestamp:%v", this.Timestamp)) - } - err = serialization.WriteVarBytes(w, this.Data) - if err != nil { - - return errors.NewDetailErr(err, errors.ErrNetPackFail, fmt.Sprintf("write error. Data:%v", this.Data)) - } - return nil -} - -func (this *ConsensusPayload) deserializationUnsigned(source *common.ZeroCopySource) error { +func (this *ConsensusPayload) DeserializationUnsigned(source *common.ZeroCopySource) error { var irregular, eof bool this.Version, eof = source.NextUint32() this.PrevHash, eof = source.NextHash() @@ -235,47 +143,3 @@ func (this *ConsensusPayload) deserializationUnsigned(source *common.ZeroCopySou return nil } - -//Deserialize message payload -func (this *ConsensusPayload) DeserializeUnsigned(r io.Reader) error { - var err error - this.Version, err = serialization.ReadUint32(r) - if err != nil { - - return errors.NewDetailErr(err, errors.ErrNetUnPackFail, "read version error") - } - - preBlock := new(common.Uint256) - err = preBlock.Deserialize(r) - if err != nil { - - return errors.NewDetailErr(err, errors.ErrNetUnPackFail, "read preBlock error") - } - this.PrevHash = *preBlock - - this.Height, err = serialization.ReadUint32(r) - if err != nil { - - return errors.NewDetailErr(err, errors.ErrNetUnPackFail, "read Height error") - } - - this.BookkeeperIndex, err = serialization.ReadUint16(r) - if err != nil { - - return errors.NewDetailErr(err, errors.ErrNetUnPackFail, "read BookkeeperIndex error") - } - - this.Timestamp, err = serialization.ReadUint32(r) - if err != nil { - - return errors.NewDetailErr(err, errors.ErrNetUnPackFail, "read Timestamp error") - } - - this.Data, err = serialization.ReadVarBytes(r) - if err != nil { - - return errors.NewDetailErr(err, errors.ErrNetUnPackFail, "read Data error") - } - - return nil -} diff --git a/p2pserver/net/netserver/netserver.go b/p2pserver/net/netserver/netserver.go index 3c37dfecda..49052cf7b4 100644 --- a/p2pserver/net/netserver/netserver.go +++ b/p2pserver/net/netserver/netserver.go @@ -30,6 +30,7 @@ import ( "github.com/ontio/ontology/common/log" "github.com/ontio/ontology/core/ledger" "github.com/ontio/ontology/p2pserver/common" + "github.com/ontio/ontology/p2pserver/common/set" "github.com/ontio/ontology/p2pserver/message/msg_pack" "github.com/ontio/ontology/p2pserver/message/types" "github.com/ontio/ontology/p2pserver/net/protocol" @@ -53,7 +54,7 @@ type NetServer struct { base peer.PeerCom listener net.Listener NetChan chan *types.MsgPayload - ConnectingNodes + connectingNodes PeerAddrMap Np *peer.NbrPeers connectLock sync.Mutex @@ -65,19 +66,19 @@ type NetServer struct { //InConnectionRecord include all addr connected type InConnectionRecord struct { sync.RWMutex - InConnectingAddrs []string + InConnectingAddrs set.StringSet } //OutConnectionRecord include all addr accepted type OutConnectionRecord struct { sync.RWMutex - OutConnectingAddrs []string + OutConnectingAddrs set.StringSet } -//ConnectingNodes include all addr in connecting state -type ConnectingNodes struct { +//connectingNodes include all addr in connecting state +type connectingNodes struct { sync.RWMutex - ConnectingAddrs []string + ConnectingAddrs set.StringSet } //PeerAddrMap include all addr-peer list @@ -114,6 +115,10 @@ func (this *NetServer) init() error { this.Np = &peer.NbrPeers{} this.Np.Init() + this.connectingNodes.ConnectingAddrs = set.NewStringSet() + this.inConnRecord.InConnectingAddrs = set.NewStringSet() + this.outConnRecord.OutConnectingAddrs = set.NewStringSet() + return nil } @@ -415,49 +420,37 @@ func (this *NetServer) startNetAccept(listener net.Listener) { //record the peer which is going to be dialed and sent version message but not in establish state func (this *NetServer) AddOutConnectingList(addr string) (added bool) { - this.ConnectingNodes.Lock() - defer this.ConnectingNodes.Unlock() - for _, a := range this.ConnectingAddrs { - if strings.Compare(a, addr) == 0 { - return false - } + this.connectingNodes.Lock() + defer this.connectingNodes.Unlock() + if this.connectingNodes.ConnectingAddrs.Has(addr) { + return false } + log.Trace("[p2p]add to out connecting list", addr) - this.ConnectingAddrs = append(this.ConnectingAddrs, addr) + this.connectingNodes.ConnectingAddrs.Insert(addr) return true } //Remove the peer from connecting list if the connection is established func (this *NetServer) RemoveFromConnectingList(addr string) { - this.ConnectingNodes.Lock() - defer this.ConnectingNodes.Unlock() - addrs := this.ConnectingAddrs[:0] - for _, a := range this.ConnectingAddrs { - if a != addr { - addrs = append(addrs, a) - } - } + this.connectingNodes.Lock() + defer this.connectingNodes.Unlock() + this.connectingNodes.ConnectingAddrs.Delete(addr) log.Trace("[p2p]remove from out connecting list", addr) - this.ConnectingAddrs = addrs } //record the peer which is going to be dialed and sent version message but not in establish state func (this *NetServer) GetOutConnectingListLen() (count uint) { - this.ConnectingNodes.RLock() - defer this.ConnectingNodes.RUnlock() - return uint(len(this.ConnectingAddrs)) + this.connectingNodes.RLock() + defer this.connectingNodes.RUnlock() + return uint(this.connectingNodes.ConnectingAddrs.Len()) } //check peer from connecting list func (this *NetServer) IsAddrFromConnecting(addr string) bool { - this.ConnectingNodes.Lock() - defer this.ConnectingNodes.Unlock() - for _, a := range this.ConnectingAddrs { - if strings.Compare(a, addr) == 0 { - return true - } - } - return false + this.connectingNodes.Lock() + defer this.connectingNodes.Unlock() + return this.connectingNodes.ConnectingAddrs.Has(addr) } //find exist peer from addr map @@ -519,12 +512,7 @@ func (this *NetServer) GetPeerAddressCount() (count uint) { func (this *NetServer) AddInConnRecord(addr string) { this.inConnRecord.Lock() defer this.inConnRecord.Unlock() - for _, a := range this.inConnRecord.InConnectingAddrs { - if strings.Compare(a, addr) == 0 { - return - } - } - this.inConnRecord.InConnectingAddrs = append(this.inConnRecord.InConnectingAddrs, addr) + this.inConnRecord.InConnectingAddrs.Insert(addr) log.Debugf("[p2p]add in record %s", addr) } @@ -532,12 +520,8 @@ func (this *NetServer) AddInConnRecord(addr string) { func (this *NetServer) IsAddrInInConnRecord(addr string) bool { this.inConnRecord.RLock() defer this.inConnRecord.RUnlock() - for _, a := range this.inConnRecord.InConnectingAddrs { - if strings.Compare(a, addr) == 0 { - return true - } - } - return false + + return this.inConnRecord.InConnectingAddrs.Has(addr) } //IsIPInInConnRecord return result whether the IP is in inConnRecordList @@ -545,7 +529,7 @@ func (this *NetServer) IsIPInInConnRecord(ip string) bool { this.inConnRecord.RLock() defer this.inConnRecord.RUnlock() var ipRecord string - for _, addr := range this.inConnRecord.InConnectingAddrs { + for addr := range this.inConnRecord.InConnectingAddrs { ipRecord, _ = common.ParseIPAddr(addr) if 0 == strings.Compare(ipRecord, ip) { return true @@ -558,21 +542,15 @@ func (this *NetServer) IsIPInInConnRecord(ip string) bool { func (this *NetServer) RemoveFromInConnRecord(addr string) { this.inConnRecord.Lock() defer this.inConnRecord.Unlock() - addrs := []string{} - for _, a := range this.inConnRecord.InConnectingAddrs { - if strings.Compare(a, addr) != 0 { - addrs = append(addrs, a) - } - } log.Debugf("[p2p]remove in record %s", addr) - this.inConnRecord.InConnectingAddrs = addrs + this.inConnRecord.InConnectingAddrs.Delete(addr) } //GetInConnRecordLen return length of inConnRecordList func (this *NetServer) GetInConnRecordLen() int { this.inConnRecord.RLock() defer this.inConnRecord.RUnlock() - return len(this.inConnRecord.InConnectingAddrs) + return this.inConnRecord.InConnectingAddrs.Len() } //GetIpCountInInConnRecord return count of in connections with single ip @@ -581,7 +559,7 @@ func (this *NetServer) GetIpCountInInConnRecord(ip string) uint { defer this.inConnRecord.RUnlock() var count uint var ipRecord string - for _, addr := range this.inConnRecord.InConnectingAddrs { + for addr := range this.inConnRecord.InConnectingAddrs { ipRecord, _ = common.ParseIPAddr(addr) if 0 == strings.Compare(ipRecord, ip) { count++ @@ -594,12 +572,7 @@ func (this *NetServer) GetIpCountInInConnRecord(ip string) uint { func (this *NetServer) AddOutConnRecord(addr string) { this.outConnRecord.Lock() defer this.outConnRecord.Unlock() - for _, a := range this.outConnRecord.OutConnectingAddrs { - if strings.Compare(a, addr) == 0 { - return - } - } - this.outConnRecord.OutConnectingAddrs = append(this.outConnRecord.OutConnectingAddrs, addr) + this.outConnRecord.OutConnectingAddrs.Insert(addr) log.Debugf("[p2p]add out record %s", addr) } @@ -607,33 +580,21 @@ func (this *NetServer) AddOutConnRecord(addr string) { func (this *NetServer) IsAddrInOutConnRecord(addr string) bool { this.outConnRecord.RLock() defer this.outConnRecord.RUnlock() - for _, a := range this.outConnRecord.OutConnectingAddrs { - if strings.Compare(a, addr) == 0 { - return true - } - } - return false + return this.outConnRecord.OutConnectingAddrs.Has(addr) } //RemoveOutConnRecord remove out connection from outConnRecord func (this *NetServer) RemoveFromOutConnRecord(addr string) { this.outConnRecord.Lock() defer this.outConnRecord.Unlock() - addrs := []string{} - for _, a := range this.outConnRecord.OutConnectingAddrs { - if strings.Compare(a, addr) != 0 { - addrs = append(addrs, a) - } - } - log.Debugf("[p2p]remove out record %s", addr) - this.outConnRecord.OutConnectingAddrs = addrs + this.outConnRecord.OutConnectingAddrs.Delete(addr) } //GetOutConnRecordLen return length of outConnRecord func (this *NetServer) GetOutConnRecordLen() int { this.outConnRecord.RLock() defer this.outConnRecord.RUnlock() - return len(this.outConnRecord.OutConnectingAddrs) + return this.outConnRecord.OutConnectingAddrs.Len() } //AddrValid whether the addr could be connect or accept @@ -664,5 +625,4 @@ func (this *NetServer) SetOwnAddress(addr string) { log.Infof("[p2p]set own address %s", addr) this.OwnAddress = addr } - } diff --git a/p2pserver/net/netserver/netserver_test.go b/p2pserver/net/netserver/netserver_test.go index 73cc794a4a..71faa565ba 100644 --- a/p2pserver/net/netserver/netserver_test.go +++ b/p2pserver/net/netserver/netserver_test.go @@ -26,6 +26,7 @@ import ( "github.com/ontio/ontology/common/log" "github.com/ontio/ontology/p2pserver/common" "github.com/ontio/ontology/p2pserver/peer" + "github.com/stretchr/testify/require" ) func init() { @@ -116,3 +117,71 @@ func TestNetServerNbrPeer(t *testing.T) { } } + +func TestConnectingNodeAPI(t *testing.T) { + a := require.New(t) + server := NewNetServer() + + a.Equal(server.GetOutConnectingListLen(), uint(0), "fail to test GetOutConnectingListLen") + + addOK := server.AddOutConnectingList("192.168.1.1:28339") + a.Equal(server.GetOutConnectingListLen(), uint(1), "fail to test AddOutConnectingList") + a.Equal(addOK, true, "fail to test AddOutConnectingList") + + // add same + addOK = server.AddOutConnectingList("192.168.1.1:28339") + a.Equal(server.GetOutConnectingListLen(), uint(1), "fail to test AddOutConnectingList") + a.Equal(addOK, false, "fail to test AddOutConnectingList") + + // add new + server.AddOutConnectingList("192.168.2.2:2") + a.Equal(server.GetOutConnectingListLen(), uint(2), "fail to test AddOutConnectingList") + + // test exist + a.Equal(server.IsAddrFromConnecting("192.168.2.2:2"), true, "fail to test IsAddrFromConnecting") + a.Equal(server.IsAddrFromConnecting("192.168.2.3:3"), false, "fail to test IsAddrFromConnecting") + + server.RemoveFromConnectingList("192.168.2.2:2") + a.Equal(server.GetOutConnectingListLen(), uint(1), "fail to test RemoveFromConnectingList") +} + +func TestInConnAPI(t *testing.T) { + a := require.New(t) + si := NewNetServer() + server, ok := si.(*NetServer) + a.True(ok, "fail to cast P2PServer") + + a.Equal(server.GetInConnRecordLen(), int(0), "fail to test GetInConnRecordLen") + server.AddInConnRecord("192.168.1.1:1024") + a.Equal(server.GetInConnRecordLen(), int(1), "fail to test AddInConnRecord") + server.AddInConnRecord("192.168.1.1:1024") + a.Equal(server.GetInConnRecordLen(), int(1), "fail to test GetInConnRecordLen") + server.AddInConnRecord("192.168.1.2:2048") + a.Equal(server.GetInConnRecordLen(), int(2), "fail to test AddInConnRecord") + server.RemoveFromInConnRecord("192.168.1.2:2048") + a.Equal(server.GetInConnRecordLen(), int(1), "fail to test RemoveFromInConnRecord") + // same IP, different port + server.AddInConnRecord("192.168.1.1:2048") + a.Equal(server.GetInConnRecordLen(), int(2), "fail to test RemoveFromInConnRecord") + + a.Equal(server.GetIpCountInInConnRecord("192.168.1.1"), uint(2), "fail to test GetIpCountInInConnRecord") +} + +func TestOutConnAPI(t *testing.T) { + a := require.New(t) + si := NewNetServer() + server, ok := si.(*NetServer) + a.True(ok, "fail to case P2PServer") + + a.Equal(server.GetOutConnRecordLen(), int(0), "fail to test GetOutConnRecordLen") + server.AddOutConnRecord("192.168.1.1:200") + a.Equal(server.GetOutConnRecordLen(), int(1), "fail to test AddOutConnRecord") + server.AddOutConnRecord("192.168.1.1:200") + a.Equal(server.GetOutConnRecordLen(), int(1), "fail to test AddOutConnRecord") + server.AddOutConnRecord("192.168.1.1:300") + a.Equal(server.GetOutConnRecordLen(), int(2), "fail to test AddOutConnRecord") + server.RemoveFromOutConnRecord("192.168.1.1:300") + a.Equal(server.GetOutConnRecordLen(), int(1), "fail to test RemoveFromOutConnRecord") + a.Equal(server.IsAddrInOutConnRecord("192.168.1.1:300"), false, "fail to test IsAddrInOutConnRecord") + a.Equal(server.IsAddrInOutConnRecord("192.168.1.1:200"), true, "fail to test IsAddrInOutConnRecord") +} diff --git a/p2pserver/p2pserver.go b/p2pserver/p2pserver.go index f30b63f3a3..35bf48fd57 100644 --- a/p2pserver/p2pserver.go +++ b/p2pserver/p2pserver.go @@ -494,7 +494,7 @@ func (this *P2PServer) heartBeatService() { this.timeout() case <-this.quitHeartBeat: t.Stop() - break + return } } } diff --git a/smartcontract/service/native/auth/auth.go b/smartcontract/service/native/auth/auth.go index 6a79de0690..ac88f904b4 100644 --- a/smartcontract/service/native/auth/auth.go +++ b/smartcontract/service/native/auth/auth.go @@ -26,7 +26,6 @@ import ( "github.com/ontio/ontology/account" "github.com/ontio/ontology/common" "github.com/ontio/ontology/common/log" - "github.com/ontio/ontology/common/serialization" "github.com/ontio/ontology/errors" "github.com/ontio/ontology/smartcontract/service/native" "github.com/ontio/ontology/smartcontract/service/native/utils" @@ -62,8 +61,8 @@ func initContractAdmin(native *native.NativeService, contractAddr common.Address func InitContractAdmin(native *native.NativeService) ([]byte, error) { param := new(InitContractAdminParam) - rd := bytes.NewReader(native.Input) - if err := param.Deserialize(rd); err != nil { + source := common.NewZeroCopySource(native.Input) + if err := param.Deserialization(source); err != nil { return nil, fmt.Errorf("[initContractAdmin] deserialize param failed: %v", err) } cxt := native.ContextRef.CallingContext() @@ -115,8 +114,7 @@ func transfer(native *native.NativeService, contractAddr common.Address, newAdmi func Transfer(native *native.NativeService) ([]byte, error) { //deserialize param param := new(TransferParam) - rd := bytes.NewReader(native.Input) - err := param.Deserialize(rd) + err := param.Deserialization(common.NewZeroCopySource(native.Input)) if err != nil { return nil, fmt.Errorf("[transfer] deserialize param failed: %v", err) } @@ -146,8 +144,8 @@ func Transfer(native *native.NativeService) ([]byte, error) { func AssignFuncsToRole(native *native.NativeService) ([]byte, error) { //deserialize input param param := new(FuncsToRoleParam) - rd := bytes.NewReader(native.Input) - if err := param.Deserialize(rd); err != nil { + source := common.NewZeroCopySource(native.Input) + if err := param.Deserialization(source); err != nil { return nil, fmt.Errorf("[assignFuncsToRole] deserialize param failed: %v", err) } @@ -270,8 +268,8 @@ func assignToRole(native *native.NativeService, param *OntIDsToRoleParam) (bool, func AssignOntIDsToRole(native *native.NativeService) ([]byte, error) { //deserialize param param := new(OntIDsToRoleParam) - rd := bytes.NewReader(native.Input) - if err := param.Deserialize(rd); err != nil { + source := common.NewZeroCopySource(native.Input) + if err := param.Deserialization(source); err != nil { return nil, fmt.Errorf("[assignOntIDsToRole] deserialize param failed: %v", err) } @@ -460,8 +458,8 @@ func delegate(native *native.NativeService, contractAddr common.Address, from [] func Delegate(native *native.NativeService) ([]byte, error) { //deserialize param param := &DelegateParam{} - rd := bytes.NewReader(native.Input) - err := param.Deserialize(rd) + source := common.NewZeroCopySource(native.Input) + err := param.Deserialization(source) if err != nil { return nil, fmt.Errorf("[delegate] deserialize param failed: %v", err) } @@ -537,8 +535,8 @@ func withdraw(native *native.NativeService, contractAddr common.Address, initiat func Withdraw(native *native.NativeService) ([]byte, error) { //deserialize param param := &WithdrawParam{} - rd := bytes.NewReader(native.Input) - err := param.Deserialize(rd) + source := common.NewZeroCopySource(native.Input) + err := param.Deserialization(source) if err != nil { return nil, fmt.Errorf("[withdraw] deserialize param failed: %v", err) } @@ -617,8 +615,8 @@ func verifyToken(native *native.NativeService, contractAddr common.Address, call func VerifyToken(native *native.NativeService) ([]byte, error) { //deserialize param param := &VerifyTokenParam{} - rd := bytes.NewReader(native.Input) - err := param.Deserialize(rd) + source := common.NewZeroCopySource(native.Input) + err := param.Deserialization(source) if err != nil { return nil, fmt.Errorf("[verifyToken] deserialize param failed: %v", err) } @@ -640,14 +638,10 @@ func VerifyToken(native *native.NativeService) ([]byte, error) { } func verifySig(native *native.NativeService, ontID []byte, keyNo uint64) (bool, error) { - bf := new(bytes.Buffer) - if err := serialization.WriteVarBytes(bf, ontID); err != nil { - return false, err - } - if err := utils.WriteVarUint(bf, keyNo); err != nil { - return false, err - } - args := bf.Bytes() + sink := common.NewZeroCopySink(nil) + sink.WriteVarBytes(ontID) + utils.EncodeVarUint(sink, keyNo) + args := sink.Bytes() ret, err := native.NativeCall(utils.OntIDContractAddress, "verifySignature", args) if err != nil { return false, err diff --git a/smartcontract/service/native/auth/param.go b/smartcontract/service/native/auth/param.go index 62ce241b37..9c54ba7aaf 100644 --- a/smartcontract/service/native/auth/param.go +++ b/smartcontract/service/native/auth/param.go @@ -24,7 +24,6 @@ import ( "math" "github.com/ontio/ontology/common" - "github.com/ontio/ontology/common/serialization" "github.com/ontio/ontology/smartcontract/service/native/utils" ) @@ -33,17 +32,18 @@ type InitContractAdminParam struct { AdminOntID []byte } -func (this *InitContractAdminParam) Serialize(w io.Writer) error { - if err := serialization.WriteVarBytes(w, this.AdminOntID); err != nil { - return err - } - return nil +func (this *InitContractAdminParam) Serialization(sink *common.ZeroCopySink) { + sink.WriteVarBytes(this.AdminOntID) } -func (this *InitContractAdminParam) Deserialize(rd io.Reader) error { - var err error - if this.AdminOntID, err = serialization.ReadVarBytes(rd); err != nil { - return err +func (this *InitContractAdminParam) Deserialization(source *common.ZeroCopySource) error { + var irregular, eof bool + this.AdminOntID, _, irregular, eof = source.NextVarBytes() + if irregular { + return common.ErrIrregularData + } + if eof { + return io.ErrUnexpectedEOF } return nil } @@ -55,28 +55,22 @@ type TransferParam struct { KeyNo uint64 } -func (this *TransferParam) Serialize(w io.Writer) error { - if err := serializeAddress(w, this.ContractAddr); err != nil { - return err - } - if err := serialization.WriteVarBytes(w, this.NewAdminOntID); err != nil { - return err - } - if err := utils.WriteVarUint(w, this.KeyNo); err != nil { - return nil - } - return nil +func (this *TransferParam) Serialization(sink *common.ZeroCopySink) { + serializeAddress(sink, this.ContractAddr) + sink.WriteVarBytes(this.NewAdminOntID) + utils.EncodeVarUint(sink, this.KeyNo) } -func (this *TransferParam) Deserialize(rd io.Reader) error { +func (this *TransferParam) Deserialization(source *common.ZeroCopySource) error { var err error - if this.ContractAddr, err = utils.ReadAddress(rd); err != nil { + if this.ContractAddr, err = utils.DecodeAddress(source); err != nil { return err } - if this.NewAdminOntID, err = serialization.ReadVarBytes(rd); err != nil { - return err + var irregular, eof bool + if this.NewAdminOntID, _, irregular, eof = source.NextVarBytes(); irregular || eof { + return fmt.Errorf("irregular:%v, eof:%v", irregular, eof) } - if this.KeyNo, err = utils.ReadVarUint(rd); err != nil { + if this.KeyNo, err = utils.DecodeVarUint(source); err != nil { return err } return nil @@ -91,56 +85,43 @@ type FuncsToRoleParam struct { KeyNo uint64 } -func (this *FuncsToRoleParam) Serialize(w io.Writer) error { - if err := serializeAddress(w, this.ContractAddr); err != nil { - return err - } - if err := serialization.WriteVarBytes(w, this.AdminOntID); err != nil { - return err - } - if err := serialization.WriteVarBytes(w, this.Role); err != nil { - return err - } - if err := utils.WriteVarUint(w, uint64(len(this.FuncNames))); err != nil { - return err - } +func (this *FuncsToRoleParam) Serialization(sink *common.ZeroCopySink) { + utils.EncodeAddress(sink, this.ContractAddr) + sink.WriteVarBytes(this.AdminOntID) + sink.WriteVarBytes(this.Role) + utils.EncodeVarUint(sink, uint64(len(this.FuncNames))) for _, fn := range this.FuncNames { - if err := serialization.WriteString(w, fn); err != nil { - return err - } - } - if err := utils.WriteVarUint(w, this.KeyNo); err != nil { - return nil + sink.WriteString(fn) } - return nil + utils.EncodeVarUint(sink, this.KeyNo) } -func (this *FuncsToRoleParam) Deserialize(rd io.Reader) error { +func (this *FuncsToRoleParam) Deserialization(source *common.ZeroCopySource) error { var err error var fnLen uint64 var i uint64 - if this.ContractAddr, err = utils.ReadAddress(rd); err != nil { + if this.ContractAddr, err = utils.DecodeAddress(source); err != nil { return err } - if this.AdminOntID, err = serialization.ReadVarBytes(rd); err != nil { - return err + if this.AdminOntID, err = utils.DecodeVarBytes(source); err != nil { + return fmt.Errorf("AdminOntID Deserialization error: %s", err) } - if this.Role, err = serialization.ReadVarBytes(rd); err != nil { - return err + if this.Role, err = utils.DecodeVarBytes(source); err != nil { + return fmt.Errorf("Role Deserialization error: %s", err) } - if fnLen, err = utils.ReadVarUint(rd); err != nil { + if fnLen, err = utils.DecodeVarUint(source); err != nil { return err } this.FuncNames = make([]string, 0) for i = 0; i < fnLen; i++ { - fn, err := serialization.ReadString(rd) + fn, err := utils.DecodeString(source) if err != nil { - return err + return fmt.Errorf("FuncNames Deserialization error: %s", err) } this.FuncNames = append(this.FuncNames, fn) } - if this.KeyNo, err = utils.ReadVarUint(rd); err != nil { + if this.KeyNo, err = utils.DecodeVarUint(source); err != nil { return err } return nil @@ -154,54 +135,42 @@ type OntIDsToRoleParam struct { KeyNo uint64 } -func (this *OntIDsToRoleParam) Serialize(w io.Writer) error { - if err := serializeAddress(w, this.ContractAddr); err != nil { - return err - } - if err := serialization.WriteVarBytes(w, this.AdminOntID); err != nil { - return err - } - if err := serialization.WriteVarBytes(w, this.Role); err != nil { - return err - } - if err := utils.WriteVarUint(w, uint64(len(this.Persons))); err != nil { - return err - } +func (this *OntIDsToRoleParam) Serialization(sink *common.ZeroCopySink) { + serializeAddress(sink, this.ContractAddr) + sink.WriteVarBytes(this.AdminOntID) + sink.WriteVarBytes(this.Role) + + utils.EncodeVarUint(sink, uint64(len(this.Persons))) for _, p := range this.Persons { - if err := serialization.WriteVarBytes(w, p); err != nil { - return err - } + sink.WriteVarBytes(p) } - if err := utils.WriteVarUint(w, this.KeyNo); err != nil { - return nil - } - return nil + utils.EncodeVarUint(sink, this.KeyNo) } -func (this *OntIDsToRoleParam) Deserialize(rd io.Reader) error { +func (this *OntIDsToRoleParam) Deserialization(source *common.ZeroCopySource) error { var err error var pLen uint64 - if this.ContractAddr, err = utils.ReadAddress(rd); err != nil { + if this.ContractAddr, err = utils.DecodeAddress(source); err != nil { return err } - if this.AdminOntID, err = serialization.ReadVarBytes(rd); err != nil { - return err + if this.AdminOntID, err = utils.DecodeVarBytes(source); err != nil { + return fmt.Errorf("AdminOntID Deserialization error: %s", err) } - if this.Role, err = serialization.ReadVarBytes(rd); err != nil { - return err + if this.Role, err = utils.DecodeVarBytes(source); err != nil { + return fmt.Errorf("Role Deserialization error: %s", err) } - if pLen, err = utils.ReadVarUint(rd); err != nil { + if pLen, err = utils.DecodeVarUint(source); err != nil { return err } this.Persons = make([][]byte, 0) for i := uint64(0); i < pLen; i++ { - p, err := serialization.ReadVarBytes(rd) + p, err := utils.DecodeVarBytes(source) if err != nil { - return err + return fmt.Errorf("Persons Deserialization error: %s", err) } this.Persons = append(this.Persons, p) } - if this.KeyNo, err = utils.ReadVarUint(rd); err != nil { + if this.KeyNo, err = utils.DecodeVarUint(source); err != nil { return err } return nil @@ -217,53 +186,38 @@ type DelegateParam struct { KeyNo uint64 } -func (this *DelegateParam) Serialize(w io.Writer) error { - if err := serializeAddress(w, this.ContractAddr); err != nil { - return err - } - if err := serialization.WriteVarBytes(w, this.From); err != nil { - return err - } - if err := serialization.WriteVarBytes(w, this.To); err != nil { - return err - } - if err := serialization.WriteVarBytes(w, this.Role); err != nil { - return err - } - if err := utils.WriteVarUint(w, this.Period); err != nil { - return err - } - if err := utils.WriteVarUint(w, uint64(this.Level)); err != nil { - return err - } - if err := utils.WriteVarUint(w, this.KeyNo); err != nil { - return err - } - return nil +func (this *DelegateParam) Serialization(sink *common.ZeroCopySink) { + serializeAddress(sink, this.ContractAddr) + sink.WriteVarBytes(this.From) + sink.WriteVarBytes(this.To) + sink.WriteVarBytes(this.Role) + utils.EncodeVarUint(sink, this.Period) + utils.EncodeVarUint(sink, uint64(this.Level)) + utils.EncodeVarUint(sink, this.KeyNo) } -func (this *DelegateParam) Deserialize(rd io.Reader) error { +func (this *DelegateParam) Deserialization(source *common.ZeroCopySource) error { var err error var level uint64 - if this.ContractAddr, err = utils.ReadAddress(rd); err != nil { + if this.ContractAddr, err = utils.DecodeAddress(source); err != nil { return err } - if this.From, err = serialization.ReadVarBytes(rd); err != nil { - return err + if this.From, err = utils.DecodeVarBytes(source); err != nil { + return fmt.Errorf("From Deserialization error: %s", err) } - if this.To, err = serialization.ReadVarBytes(rd); err != nil { - return err + if this.To, err = utils.DecodeVarBytes(source); err != nil { + return fmt.Errorf("To Deserialization error: %s", err) } - if this.Role, err = serialization.ReadVarBytes(rd); err != nil { - return err + if this.Role, err = utils.DecodeVarBytes(source); err != nil { + return fmt.Errorf("Role Deserialization error: %s", err) } - if this.Period, err = utils.ReadVarUint(rd); err != nil { + if this.Period, err = utils.DecodeVarUint(source); err != nil { return err } - if level, err = utils.ReadVarUint(rd); err != nil { + if level, err = utils.DecodeVarUint(source); err != nil { return err } - if this.KeyNo, err = utils.ReadVarUint(rd); err != nil { + if this.KeyNo, err = utils.DecodeVarUint(source); err != nil { return err } if level > math.MaxInt8 || this.Period > math.MaxUint32 { @@ -281,39 +235,28 @@ type WithdrawParam struct { KeyNo uint64 } -func (this *WithdrawParam) Serialize(w io.Writer) error { - if err := serializeAddress(w, this.ContractAddr); err != nil { - return err - } - if err := serialization.WriteVarBytes(w, this.Initiator); err != nil { - return err - } - if err := serialization.WriteVarBytes(w, this.Delegate); err != nil { - return err - } - if err := serialization.WriteVarBytes(w, this.Role); err != nil { - return err - } - if err := utils.WriteVarUint(w, this.KeyNo); err != nil { - return err - } - return nil +func (this *WithdrawParam) Serialization(sink *common.ZeroCopySink) { + serializeAddress(sink, this.ContractAddr) + sink.WriteVarBytes(this.Initiator) + sink.WriteVarBytes(this.Delegate) + sink.WriteVarBytes(this.Role) + utils.EncodeVarUint(sink, this.KeyNo) } -func (this *WithdrawParam) Deserialize(rd io.Reader) error { +func (this *WithdrawParam) Deserialization(source *common.ZeroCopySource) error { var err error - if this.ContractAddr, err = utils.ReadAddress(rd); err != nil { + if this.ContractAddr, err = utils.DecodeAddress(source); err != nil { return err } - if this.Initiator, err = serialization.ReadVarBytes(rd); err != nil { - return err + if this.Initiator, err = utils.DecodeVarBytes(source); err != nil { + return fmt.Errorf("Initiator Deserialization error: %s", err) } - if this.Delegate, err = serialization.ReadVarBytes(rd); err != nil { - return err + if this.Delegate, err = utils.DecodeVarBytes(source); err != nil { + return fmt.Errorf("Delegate Deserialization error: %s", err) } - if this.Role, err = serialization.ReadVarBytes(rd); err != nil { - return err + if this.Role, err = utils.DecodeVarBytes(source); err != nil { + return fmt.Errorf("Role Deserialization error: %s", err) } - if this.KeyNo, err = utils.ReadVarUint(rd); err != nil { + if this.KeyNo, err = utils.DecodeVarUint(source); err != nil { return err } return nil @@ -326,34 +269,25 @@ type VerifyTokenParam struct { KeyNo uint64 } -func (this *VerifyTokenParam) Serialize(w io.Writer) error { - if err := serializeAddress(w, this.ContractAddr); err != nil { - return err - } - if err := serialization.WriteVarBytes(w, this.Caller); err != nil { - return err - } - if err := serialization.WriteString(w, this.Fn); err != nil { - return err - } - if err := utils.WriteVarUint(w, this.KeyNo); err != nil { - return err - } - return nil +func (this *VerifyTokenParam) Serialization(sink *common.ZeroCopySink) { + serializeAddress(sink, this.ContractAddr) + sink.WriteVarBytes(this.Caller) + sink.WriteString(this.Fn) + utils.EncodeVarUint(sink, this.KeyNo) } -func (this *VerifyTokenParam) Deserialize(rd io.Reader) error { +func (this *VerifyTokenParam) Deserialization(source *common.ZeroCopySource) error { var err error - if this.ContractAddr, err = utils.ReadAddress(rd); err != nil { + if this.ContractAddr, err = utils.DecodeAddress(source); err != nil { return err } - if this.Caller, err = serialization.ReadVarBytes(rd); err != nil { - return err //deserialize caller error + if this.Caller, err = utils.DecodeVarBytes(source); err != nil { + return fmt.Errorf("Caller Deserialization error: %s", err) } - if this.Fn, err = serialization.ReadString(rd); err != nil { - return err + if this.Fn, err = utils.DecodeString(source); err != nil { + return fmt.Errorf("Fn Deserialization error: %s", err) } - if this.KeyNo, err = utils.ReadVarUint(rd); err != nil { + if this.KeyNo, err = utils.DecodeVarUint(source); err != nil { return err } return nil diff --git a/smartcontract/service/native/auth/param_test.go b/smartcontract/service/native/auth/param_test.go index 4ce39204f5..49e58d01e4 100644 --- a/smartcontract/service/native/auth/param_test.go +++ b/smartcontract/service/native/auth/param_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/assert" + "github.com/ontio/ontology/common" "github.com/ontio/ontology/smartcontract/service/native/utils" ) @@ -53,14 +54,13 @@ func TestSerialization_Init(t *testing.T) { param := &InitContractAdminParam{ AdminOntID: admin, } - bf := new(bytes.Buffer) - if err := param.Serialize(bf); err != nil { - t.Fatal(err) - } - rd := bytes.NewReader(bf.Bytes()) + sink := common.NewZeroCopySink(nil) + param.Serialization(sink) + rd := common.NewZeroCopySource(sink.Bytes()) param2 := new(InitContractAdminParam) - if err := param2.Deserialize(rd); err != nil { + + if err := param2.Deserialization(rd); err != nil { t.Fatal(err) } @@ -74,14 +74,12 @@ func TestSerialization_Transfer(t *testing.T) { ContractAddr: OntContractAddr, NewAdminOntID: newAdmin, } - bf := new(bytes.Buffer) - if err := param.Serialize(bf); err != nil { - t.Fatal(err) - } - rd := bytes.NewReader(bf.Bytes()) + sink := common.NewZeroCopySink(nil) + param.Serialization(sink) + rd := common.NewZeroCopySource(sink.Bytes()) param2 := new(TransferParam) - if err := param2.Deserialize(rd); err != nil { + if err := param2.Deserialization(rd); err != nil { t.Fatal(err) } @@ -95,14 +93,12 @@ func TestSerialization_AssignFuncs(t *testing.T) { Role: []byte("role"), FuncNames: funcs, } - bf := new(bytes.Buffer) - if err := param.Serialize(bf); err != nil { - t.Fatal(err) - } - rd := bytes.NewReader(bf.Bytes()) + bf := common.NewZeroCopySink(nil) + param.Serialization(bf) + rd := common.NewZeroCopySource(bf.Bytes()) param2 := new(FuncsToRoleParam) - if err := param2.Deserialize(rd); err != nil { + if err := param2.Deserialization(rd); err != nil { t.Fatal(err) } @@ -116,13 +112,11 @@ func TestSerialization_AssignOntIDs(t *testing.T) { Role: []byte(role), Persons: [][]byte{[]byte{0x03, 0x04, 0x05, 0x06}, []byte{0x07, 0x08, 0x09, 0x0a}}, } - bf := new(bytes.Buffer) - if err := param.Serialize(bf); err != nil { - t.Fatal(err) - } - rd := bytes.NewReader(bf.Bytes()) + bf := common.NewZeroCopySink(nil) + param.Serialization(bf) + rd := common.NewZeroCopySource(bf.Bytes()) param2 := new(OntIDsToRoleParam) - if err := param2.Deserialize(rd); err != nil { + if err := param2.Deserialization(rd); err != nil { t.Fatal(err) } @@ -138,13 +132,11 @@ func TestSerialization_Delegate(t *testing.T) { Period: 60 * 60 * 24, Level: 3, } - bf := new(bytes.Buffer) - if err := param.Serialize(bf); err != nil { - t.Fatal(err) - } - rd := bytes.NewReader(bf.Bytes()) + bf := common.NewZeroCopySink(nil) + param.Serialization(bf) + rd := common.NewZeroCopySource(bf.Bytes()) param2 := new(DelegateParam) - if err := param2.Deserialize(rd); err != nil { + if err := param2.Deserialization(rd); err != nil { t.Fatal(err) } assert.Equal(t, param, param2) @@ -157,13 +149,11 @@ func TestSerialization_Withdraw(t *testing.T) { Delegate: p2, Role: []byte(role), } - bf := new(bytes.Buffer) - if err := param.Serialize(bf); err != nil { - t.Fatal(err) - } - rd := bytes.NewReader(bf.Bytes()) + bf := common.NewZeroCopySink(nil) + param.Serialization(bf) + rd := common.NewZeroCopySource(bf.Bytes()) param2 := new(WithdrawParam) - if err := param2.Deserialize(rd); err != nil { + if err := param2.Deserialization(rd); err != nil { t.Fatal(err) } assert.Equal(t, param, param2) @@ -175,13 +165,11 @@ func TestSerialization_VerifyToken(t *testing.T) { Caller: p1, Fn: "foo1", } - bf := new(bytes.Buffer) - if err := param.Serialize(bf); err != nil { - t.Fatal(err) - } - rd := bytes.NewReader(bf.Bytes()) + bf := common.NewZeroCopySink(nil) + param.Serialization(bf) + rd := common.NewZeroCopySource(bf.Bytes()) param2 := new(VerifyTokenParam) - if err := param2.Deserialize(rd); err != nil { + if err := param2.Deserialization(rd); err != nil { t.Fatal(err) } assert.Equal(t, param, param2) diff --git a/smartcontract/service/native/auth/state.go b/smartcontract/service/native/auth/state.go index 473ad7465c..d93abb7fba 100644 --- a/smartcontract/service/native/auth/state.go +++ b/smartcontract/service/native/auth/state.go @@ -22,7 +22,8 @@ import ( "io" "strings" - "github.com/ontio/ontology/common/serialization" + "github.com/ontio/ontology/common" + "github.com/ontio/ontology/smartcontract/service/native/utils" ) /* @@ -47,28 +48,23 @@ func (this *roleFuncs) ContainsFunc(fn string) bool { return false } -func (this *roleFuncs) Serialize(w io.Writer) error { - if err := serialization.WriteUint32(w, uint32(len(this.funcNames))); err != nil { - return err - } +func (this *roleFuncs) Serialization(sink *common.ZeroCopySink) { + sink.WriteUint32(uint32(len(this.funcNames))) + this.funcNames = StringsDedupAndSort(this.funcNames) for _, fn := range this.funcNames { - if err := serialization.WriteString(w, fn); err != nil { - return err - } + sink.WriteString(fn) } - return nil } -func (this *roleFuncs) Deserialize(rd io.Reader) error { - var err error - fnLen, err := serialization.ReadUint32(rd) - if err != nil { - return err +func (this *roleFuncs) Deserialization(source *common.ZeroCopySource) error { + fnLen, eof := source.NextUint32() + if eof { + return io.ErrUnexpectedEOF } funcNames := make([]string, 0) for i := uint32(0); i < fnLen; i++ { - fn, err := serialization.ReadString(rd) + fn, err := utils.DecodeString(source) if err != nil { return err } @@ -86,33 +82,27 @@ type AuthToken struct { level uint8 } -func (this *AuthToken) Serialize(w io.Writer) error { - if err := serialization.WriteVarBytes(w, this.role); err != nil { - return err - } - if err := serialization.WriteUint32(w, this.expireTime); err != nil { - return err - } - if err := serialization.WriteUint8(w, this.level); err != nil { - return err - } - return nil +func (this *AuthToken) Serialization(sink *common.ZeroCopySink) { + sink.WriteVarBytes(this.role) + sink.WriteUint32(this.expireTime) + sink.WriteUint8(this.level) } -func (this *AuthToken) Deserialize(rd io.Reader) error { +func (this *AuthToken) Deserialization(source *common.ZeroCopySource) error { //rd := bytes.NewReader(data) var err error - this.role, err = serialization.ReadVarBytes(rd) + this.role, err = utils.DecodeVarBytes(source) if err != nil { return err } - this.expireTime, err = serialization.ReadUint32(rd) - if err != nil { - return err + var eof bool + this.expireTime, eof = source.NextUint32() + if eof { + return io.ErrUnexpectedEOF } - this.level, err = serialization.ReadUint8(rd) - if err != nil { - return err + this.level, eof = source.NextUint8() + if eof { + return io.ErrUnexpectedEOF } return nil } @@ -122,23 +112,18 @@ type DelegateStatus struct { AuthToken } -func (this *DelegateStatus) Serialize(w io.Writer) error { - if err := serialization.WriteVarBytes(w, this.root); err != nil { - return err - } - if err := this.AuthToken.Serialize(w); err != nil { - return err - } - return nil +func (this *DelegateStatus) Serialization(sink *common.ZeroCopySink) { + sink.WriteVarBytes(this.root) + this.AuthToken.Serialization(sink) } -func (this *DelegateStatus) Deserialize(rd io.Reader) error { +func (this *DelegateStatus) Deserialization(source *common.ZeroCopySource) error { var err error - this.root, err = serialization.ReadVarBytes(rd) + this.root, err = utils.DecodeVarBytes(source) if err != nil { return err } - err = this.AuthToken.Deserialize(rd) + err = this.AuthToken.Deserialization(source) return err } @@ -146,27 +131,22 @@ type Status struct { status []*DelegateStatus } -func (this *Status) Serialize(w io.Writer) error { - if err := serialization.WriteUint32(w, uint32(len(this.status))); err != nil { - return err - } +func (this *Status) Serialization(sink *common.ZeroCopySink) { + sink.WriteUint32(uint32(len(this.status))) for _, s := range this.status { - if err := s.Serialize(w); err != nil { - return err - } + s.Serialization(sink) } - return nil } -func (this *Status) Deserialize(rd io.Reader) error { - sLen, err := serialization.ReadUint32(rd) - if err != nil { - return err +func (this *Status) Deserialization(source *common.ZeroCopySource) error { + sLen, eof := source.NextUint32() + if eof { + return io.ErrUnexpectedEOF } this.status = make([]*DelegateStatus, 0) for i := uint32(0); i < sLen; i++ { s := new(DelegateStatus) - err = s.Deserialize(rd) + err := s.Deserialization(source) if err != nil { return err } @@ -179,27 +159,22 @@ type roleTokens struct { tokens []*AuthToken } -func (this *roleTokens) Serialize(w io.Writer) error { - if err := serialization.WriteUint32(w, uint32(len(this.tokens))); err != nil { - return err - } +func (this *roleTokens) Serialization(sink *common.ZeroCopySink) { + sink.WriteUint32(uint32(len(this.tokens))) for _, token := range this.tokens { - if err := token.Serialize(w); err != nil { - return err - } + token.Serialization(sink) } - return nil } -func (this *roleTokens) Deserialize(rd io.Reader) error { - tLen, err := serialization.ReadUint32(rd) - if err != nil { - return err +func (this *roleTokens) Deserialization(source *common.ZeroCopySource) error { + tLen, eof := source.NextUint32() + if eof { + return io.ErrUnexpectedEOF } this.tokens = make([]*AuthToken, 0) for i := uint32(0); i < tLen; i++ { tok := new(AuthToken) - err = tok.Deserialize(rd) + err := tok.Deserialization(source) if err != nil { return err } diff --git a/smartcontract/service/native/auth/state_test.go b/smartcontract/service/native/auth/state_test.go index 4faec346ec..3828c8f1d4 100644 --- a/smartcontract/service/native/auth/state_test.go +++ b/smartcontract/service/native/auth/state_test.go @@ -20,6 +20,7 @@ package auth import ( "bytes" + "github.com/ontio/ontology/common" "testing" ) @@ -28,13 +29,11 @@ func TestSerRoleFuncs(t *testing.T) { []string{"foo1", "foo2"}, //[]string{}, } - bf := new(bytes.Buffer) - if err := param.Serialize(bf); err != nil { - t.Fatal(err) - } - rd := bytes.NewReader(bf.Bytes()) + bf := common.NewZeroCopySink(nil) + param.Serialization(bf) + rd := common.NewZeroCopySource(bf.Bytes()) param2 := new(roleFuncs) - if err := param2.Deserialize(rd); err != nil { + if err := param2.Deserialization(rd); err != nil { t.Fatal(err) } @@ -55,13 +54,11 @@ func TestSerAuthToken(t *testing.T) { level: 2, } - bf := new(bytes.Buffer) - if err := param.Serialize(bf); err != nil { - t.Fatal(err) - } - rd := bytes.NewReader(bf.Bytes()) + bf := common.NewZeroCopySink(nil) + param.Serialization(bf) + rd := common.NewZeroCopySource(bf.Bytes()) param2 := new(AuthToken) - if err := param2.Deserialize(rd); err != nil { + if err := param2.Deserialization(rd); err != nil { t.Fatal(err) } @@ -82,13 +79,11 @@ func TestSerDelegateStatus(t *testing.T) { root: []byte{0x01, 0x02, 0x03, 0x04, 0x05}, AuthToken: *token, } - bf := new(bytes.Buffer) - if err := s1.Serialize(bf); err != nil { - t.Fatal(err) - } - rd := bytes.NewReader(bf.Bytes()) + bf := common.NewZeroCopySink(nil) + s1.Serialization(bf) + rd := common.NewZeroCopySource(bf.Bytes()) s2 := new(DelegateStatus) - if err := s2.Deserialize(rd); err != nil { + if err := s2.Deserialization(rd); err != nil { t.Fatal(err) } diff --git a/smartcontract/service/native/auth/utils.go b/smartcontract/service/native/auth/utils.go index 5e44087a5b..307762e97e 100644 --- a/smartcontract/service/native/auth/utils.go +++ b/smartcontract/service/native/auth/utils.go @@ -19,13 +19,10 @@ package auth import ( - "bytes" "fmt" - "io" "sort" "github.com/ontio/ontology/common" - "github.com/ontio/ontology/common/serialization" "github.com/ontio/ontology/smartcontract/event" "github.com/ontio/ontology/smartcontract/service/native" "github.com/ontio/ontology/smartcontract/service/native/utils" @@ -84,9 +81,9 @@ func getRoleFunc(native *native.NativeService, contractAddr common.Address, role if item == nil { //is not set return nil, nil } - rd := bytes.NewReader(item.Value) + source := common.NewZeroCopySource(item.Value) rF := new(roleFuncs) - err = rF.Deserialize(rd) + err = rF.Deserialization(source) if err != nil { return nil, fmt.Errorf("deserialize roleFuncs object failed. data: %x", item.Value) } @@ -95,12 +92,7 @@ func getRoleFunc(native *native.NativeService, contractAddr common.Address, role func putRoleFunc(native *native.NativeService, contractAddr common.Address, role []byte, funcs *roleFuncs) error { key := concatRoleFuncKey(native, contractAddr, role) - bf := new(bytes.Buffer) - err := funcs.Serialize(bf) - if err != nil { - return fmt.Errorf("serialize roleFuncs failed, caused by %v", err) - } - utils.PutBytes(native, key, bf.Bytes()) + utils.PutBytes(native, key, common.SerializeToBytes(funcs)) return nil } @@ -123,9 +115,9 @@ func getOntIDToken(native *native.NativeService, contractAddr common.Address, on if item == nil { //is not set return nil, nil } - rd := bytes.NewReader(item.Value) + source := common.NewZeroCopySource(item.Value) rT := new(roleTokens) - err = rT.Deserialize(rd) + err = rT.Deserialization(source) if err != nil { return nil, fmt.Errorf("deserialize roleTokens object failed. data: %x", item.Value) } @@ -134,12 +126,7 @@ func getOntIDToken(native *native.NativeService, contractAddr common.Address, on func putOntIDToken(native *native.NativeService, contractAddr common.Address, ontID []byte, tokens *roleTokens) error { key := concatOntIDTokenKey(native, contractAddr, ontID) - bf := new(bytes.Buffer) - err := tokens.Serialize(bf) - if err != nil { - return fmt.Errorf("serialize roleFuncs failed, caused by %v", err) - } - utils.PutBytes(native, key, bf.Bytes()) + utils.PutBytes(native, key, common.SerializeToBytes(tokens)) return nil } @@ -163,8 +150,8 @@ func getDelegateStatus(native *native.NativeService, contractAddr common.Address return nil, nil } status := new(Status) - rd := bytes.NewReader(item.Value) - err = status.Deserialize(rd) + source := common.NewZeroCopySource(item.Value) + err = status.Deserialization(source) if err != nil { return nil, fmt.Errorf("deserialize Status object failed. data: %x", item.Value) } @@ -173,12 +160,7 @@ func getDelegateStatus(native *native.NativeService, contractAddr common.Address func putDelegateStatus(native *native.NativeService, contractAddr common.Address, ontID []byte, status *Status) error { key := concatDelegateStatusKey(native, contractAddr, ontID) - bf := new(bytes.Buffer) - err := status.Serialize(bf) - if err != nil { - return fmt.Errorf("serialize Status failed, caused by %v", err) - } - utils.PutBytes(native, key, bf.Bytes()) + utils.PutBytes(native, key, common.SerializeToBytes(status)) return nil } @@ -208,10 +190,6 @@ func pushEvent(native *native.NativeService, s interface{}) { native.Notifications = append(native.Notifications, event) } -func serializeAddress(w io.Writer, addr common.Address) error { - err := serialization.WriteVarBytes(w, addr[:]) - if err != nil { - return err - } - return nil +func serializeAddress(sink *common.ZeroCopySink, addr common.Address) { + sink.WriteVarBytes(addr[:]) } diff --git a/smartcontract/service/native/global_params/global_params.go b/smartcontract/service/native/global_params/global_params.go index 1e3805921d..9610bd8d51 100644 --- a/smartcontract/service/native/global_params/global_params.go +++ b/smartcontract/service/native/global_params/global_params.go @@ -19,11 +19,9 @@ package global_params import ( - "bytes" "fmt" "github.com/ontio/ontology/common" - "github.com/ontio/ontology/common/serialization" "github.com/ontio/ontology/errors" "github.com/ontio/ontology/smartcontract/service/native" "github.com/ontio/ontology/smartcontract/service/native/utils" @@ -67,19 +65,19 @@ func ParamInit(native *native.NativeService) ([]byte, error) { } initParams := Params{} - args, err := serialization.ReadVarBytes(bytes.NewBuffer(native.Input)) + args, err := utils.DecodeVarBytes(common.NewZeroCopySource(native.Input)) if err != nil { return utils.BYTE_FALSE, errors.NewDetailErr(err, errors.ErrNoCode, "init param, read native input failed!") } - argsBuffer := bytes.NewBuffer(args) - if err := initParams.Deserialize(argsBuffer); err != nil { + source := common.NewZeroCopySource(args) + if err := initParams.Deserialization(source); err != nil { return utils.BYTE_FALSE, errors.NewDetailErr(err, errors.ErrNoCode, "init param, deserialize params failed!") } native.CacheDB.Put(generateParamKey(contract, CURRENT_VALUE), getParamStorageItem(initParams).ToArray()) native.CacheDB.Put(generateParamKey(contract, PREPARE_VALUE), getParamStorageItem(initParams).ToArray()) var admin common.Address - if admin, err = utils.ReadAddress(argsBuffer); err != nil { + if admin, err = utils.DecodeAddress(source); err != nil { return utils.BYTE_FALSE, errors.NewDetailErr(err, errors.ErrNoCode, "init param, deserialize admin failed!") } native.CacheDB.Put(generateAdminKey(contract, false), getRoleStorageItem(admin).ToArray()) @@ -90,7 +88,7 @@ func ParamInit(native *native.NativeService) ([]byte, error) { func AcceptAdmin(native *native.NativeService) ([]byte, error) { var destinationAdmin common.Address - destinationAdmin, err := utils.ReadAddress(bytes.NewBuffer(native.Input)) + destinationAdmin, err := utils.DecodeAddress(common.NewZeroCopySource(native.Input)) if err != nil { return utils.BYTE_FALSE, errors.NewErr("accept admin, deserialize admin failed!") } @@ -120,7 +118,7 @@ func TransferAdmin(native *native.NativeService) ([]byte, error) { if !native.ContextRef.CheckWitness(admin) { return utils.BYTE_FALSE, errors.NewErr("transfer admin, authentication failed!") } - destinationAdmin, err := utils.ReadAddress(bytes.NewBuffer(native.Input)) + destinationAdmin, err := utils.DecodeAddress(common.NewZeroCopySource(native.Input)) if err != nil { return utils.BYTE_FALSE, errors.NewErr("transfer admin, deserialize admin failed!") } @@ -140,7 +138,7 @@ func SetOperator(native *native.NativeService) ([]byte, error) { if !native.ContextRef.CheckWitness(admin) { return utils.BYTE_FALSE, errors.NewErr("set operator, authentication failed!") } - destinationOperator, err := utils.ReadAddress(bytes.NewBuffer(native.Input)) + destinationOperator, err := utils.DecodeAddress(common.NewZeroCopySource(native.Input)) if err != nil { return utils.BYTE_FALSE, errors.NewErr("set operator, deserialize operator failed!") } @@ -160,7 +158,7 @@ func SetGlobalParam(native *native.NativeService) ([]byte, error) { return utils.BYTE_FALSE, errors.NewErr("set param, authentication failed!") } params := Params{} - if err := params.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := params.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, errors.NewErr("set param, deserialize failed!") } if len(params) == 0 { @@ -185,7 +183,7 @@ func SetGlobalParam(native *native.NativeService) ([]byte, error) { func GetGlobalParam(native *native.NativeService) ([]byte, error) { var paramNameList ParamNameList - if err := paramNameList.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := paramNameList.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, errors.NewErr("get param, deserialize failed!") } if len(paramNameList) == 0 { @@ -209,12 +207,7 @@ func GetGlobalParam(native *native.NativeService) ([]byte, error) { params.SetParam(Param{Key: paramName, Value: ""}) } } - result := new(bytes.Buffer) - err = params.Serialize(result) - if err != nil { - return utils.BYTE_FALSE, errors.NewDetailErr(err, errors.ErrNoCode, "get param, serialize result error!") - } - return result.Bytes(), nil + return common.SerializeToBytes(params), nil } func CreateSnapshot(native *native.NativeService) ([]byte, error) { diff --git a/smartcontract/service/native/global_params/param_test.go b/smartcontract/service/native/global_params/param_test.go index 54098d7e27..5f97d43c15 100644 --- a/smartcontract/service/native/global_params/param_test.go +++ b/smartcontract/service/native/global_params/param_test.go @@ -19,10 +19,10 @@ package global_params import ( - "bytes" "strconv" "testing" + "github.com/ontio/ontology/common" "github.com/stretchr/testify/assert" ) @@ -33,12 +33,11 @@ func TestParams_Serialize_Deserialize(t *testing.T) { v := "value" + strconv.Itoa(i) params.SetParam(Param{k, v}) } - bf := new(bytes.Buffer) - if err := params.Serialize(bf); err != nil { - t.Fatalf("params serialize error: %v", err) - } + sink := common.NewZeroCopySink(nil) + params.Serialization(sink) deserializeParams := Params{} - if err := deserializeParams.Deserialize(bf); err != nil { + source := common.NewZeroCopySource(sink.Bytes()) + if err := deserializeParams.Deserialization(source); err != nil { t.Fatalf("params deserialize error: %v", err) } for i := 0; i < 10; i++ { @@ -55,11 +54,11 @@ func TestParamNameList_Serialize_Deserialize(t *testing.T) { for i := 0; i < 3; i++ { nameList = append(nameList, strconv.Itoa(i)) } - bf := new(bytes.Buffer) - err := nameList.Serialize(bf) - assert.Nil(t, err) + sink := common.NewZeroCopySink(nil) + nameList.Serialization(sink) deserializeNameList := ParamNameList{} - err = deserializeNameList.Deserialize(bf) + source := common.NewZeroCopySource(sink.Bytes()) + err := deserializeNameList.Deserialization(source) assert.Nil(t, err) assert.Equal(t, nameList, deserializeNameList) } diff --git a/smartcontract/service/native/global_params/states.go b/smartcontract/service/native/global_params/states.go index f62d4d9e2c..912f646f56 100644 --- a/smartcontract/service/native/global_params/states.go +++ b/smartcontract/service/native/global_params/states.go @@ -19,10 +19,8 @@ package global_params import ( - "io" - "fmt" - "github.com/ontio/ontology/common/serialization" + "github.com/ontio/ontology/common" "github.com/ontio/ontology/errors" "github.com/ontio/ontology/smartcontract/service/native/utils" ) @@ -55,63 +53,51 @@ func (params *Params) GetParam(key string) (int, Param) { return -1, Param{} } -func (params *Params) Serialize(w io.Writer) error { +func (params *Params) Serialization(sink *common.ZeroCopySink) { paramNum := len(*params) - if err := utils.WriteVarUint(w, uint64(paramNum)); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "param config, serialize params length error!") - } + utils.EncodeVarUint(sink, uint64(paramNum)) for _, param := range *params { - if err := serialization.WriteString(w, param.Key); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, fmt.Sprintf("param config, serialize param key %v error!", param.Key)) - } - if err := serialization.WriteString(w, param.Value); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, fmt.Sprintf("param config, serialize param value %v error!", param.Value)) - } + sink.WriteString(param.Key) + sink.WriteString(param.Value) } - return nil } - -func (params *Params) Deserialize(r io.Reader) error { - paramNum, err := utils.ReadVarUint(r) +func (params *Params) Deserialization(source *common.ZeroCopySource) error { + paramNum, err := utils.DecodeVarUint(source) if err != nil { return errors.NewDetailErr(err, errors.ErrNoCode, "param config, deserialize params length error!") } + for i := 0; uint64(i) < paramNum; i++ { param := Param{} - param.Key, err = serialization.ReadString(r) - if err != nil { + var irregular, eof bool + param.Key, _, irregular, eof = source.NextString() + if irregular || eof { return errors.NewDetailErr(err, errors.ErrNoCode, fmt.Sprintf("param config, deserialize param key %v error!", param.Key)) } - param.Value, err = serialization.ReadString(r) - if err != nil { + param.Value, _, irregular, eof = source.NextString() + if irregular || eof { return errors.NewDetailErr(err, errors.ErrNoCode, fmt.Sprintf("param config, deserialize param value %v error!", param.Value)) } *params = append(*params, param) } return nil } - -func (nameList *ParamNameList) Serialize(w io.Writer) error { +func (nameList *ParamNameList) Serialization(sink *common.ZeroCopySink) { nameNum := len(*nameList) - if err := utils.WriteVarUint(w, uint64(nameNum)); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "param config, serialize param name list length error!") - } + utils.EncodeVarUint(sink, uint64(nameNum)) for _, value := range *nameList { - if err := serialization.WriteString(w, value); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, fmt.Sprintf("param config, serialize param name %v error!", value)) - } + sink.WriteString(value) } - return nil } -func (nameList *ParamNameList) Deserialize(r io.Reader) error { - nameNum, err := utils.ReadVarUint(r) +func (nameList *ParamNameList) Deserialization(source *common.ZeroCopySource) error { + nameNum, err := utils.DecodeVarUint(source) if err != nil { return errors.NewDetailErr(err, errors.ErrNoCode, "param config, deserialize param name list length error!") } for i := 0; uint64(i) < nameNum; i++ { - name, err := serialization.ReadString(r) - if err != nil { + name, _, irregular, eof := source.NextString() + if irregular || eof { return errors.NewDetailErr(err, errors.ErrNoCode, fmt.Sprintf("param config, deserialize param name %v error!", name)) } *nameList = append(*nameList, name) diff --git a/smartcontract/service/native/global_params/utils.go b/smartcontract/service/native/global_params/utils.go index 03fcfef426..e67f40ea9c 100644 --- a/smartcontract/service/native/global_params/utils.go +++ b/smartcontract/service/native/global_params/utils.go @@ -19,8 +19,6 @@ package global_params import ( - "bytes" - "github.com/ontio/ontology/common" "github.com/ontio/ontology/common/config" cstates "github.com/ontio/ontology/core/states" @@ -37,15 +35,13 @@ const ( ) func getRoleStorageItem(role common.Address) *cstates.StorageItem { - bf := new(bytes.Buffer) - utils.WriteAddress(bf, role) + bf := common.NewZeroCopySink(nil) + utils.EncodeAddress(bf, role) return &cstates.StorageItem{Value: bf.Bytes()} } func getParamStorageItem(params Params) *cstates.StorageItem { - bf := new(bytes.Buffer) - params.Serialize(bf) - return &cstates.StorageItem{Value: bf.Bytes()} + return &cstates.StorageItem{Value: common.SerializeToBytes(¶ms)} } func generateParamKey(contract common.Address, valueType paramType) []byte { @@ -72,8 +68,7 @@ func getStorageParam(native *native.NativeService, key []byte) (Params, error) { if err != nil || item == nil { return params, err } - bf := bytes.NewBuffer(item.Value) - err = params.Deserialize(bf) + err = params.Deserialization(common.NewZeroCopySource(item.Value)) return params, err } @@ -83,8 +78,8 @@ func GetStorageRole(native *native.NativeService, key []byte) (common.Address, e if err != nil || item == nil { return role, err } - bf := bytes.NewBuffer(item.Value) - role, err = utils.ReadAddress(bf) + bf := common.NewZeroCopySource(item.Value) + role, err = utils.DecodeAddress(bf) return role, err } diff --git a/smartcontract/service/native/governance/governance.go b/smartcontract/service/native/governance/governance.go index a5a00db7f3..e151efa038 100644 --- a/smartcontract/service/native/governance/governance.go +++ b/smartcontract/service/native/governance/governance.go @@ -22,7 +22,6 @@ package governance import ( - "bytes" "encoding/hex" "fmt" "math" @@ -30,7 +29,6 @@ import ( "github.com/ontio/ontology/common" "github.com/ontio/ontology/common/config" "github.com/ontio/ontology/common/constants" - "github.com/ontio/ontology/common/serialization" cstates "github.com/ontio/ontology/core/states" "github.com/ontio/ontology/smartcontract/service/native" "github.com/ontio/ontology/smartcontract/service/native/global_params" @@ -157,11 +155,11 @@ func RegisterGovernanceContract(native *native.NativeService) { //Init governance contract, include vbft config, global param and ontid admin. func InitConfig(native *native.NativeService) ([]byte, error) { configuration := new(config.VBFTConfig) - buf, err := serialization.ReadVarBytes(bytes.NewBuffer(native.Input)) + buf, err := utils.DecodeVarBytes(common.NewZeroCopySource(native.Input)) if err != nil { return utils.BYTE_FALSE, fmt.Errorf("serialization.ReadVarBytes, contract params deserialize error: %v", err) } - if err := configuration.Deserialize(bytes.NewBuffer(buf)); err != nil { + if err := configuration.Deserialization(common.NewZeroCopySource(buf)); err != nil { return utils.BYTE_FALSE, fmt.Errorf("deserialize, contract params deserialize error: %v", err) } contract := native.ContextRef.CurrentContext().ContractAddress @@ -334,7 +332,7 @@ func RegisterCandidateTransferFrom(native *native.NativeService) ([]byte, error) //Unregister a registered candidate node, will remove node from pool, and unfreeze deposit ont. func UnRegisterCandidate(native *native.NativeService) ([]byte, error) { params := new(UnRegisterCandidateParam) - if err := params.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := params.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, fmt.Errorf("deserialize, contract params deserialize error: %v", err) } address := params.Address @@ -397,7 +395,7 @@ func UnRegisterCandidate(native *native.NativeService) ([]byte, error) { //Only approved candidate node can participate in consensus selection and get ong bonus. func ApproveCandidate(native *native.NativeService) ([]byte, error) { params := new(ApproveCandidateParam) - if err := params.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := params.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, fmt.Errorf("deserialize, contract params deserialize error: %v", err) } @@ -527,7 +525,7 @@ func ApproveCandidate(native *native.NativeService) ([]byte, error) { //Only approved candidate node can participate in consensus selection and get ong bonus. func RejectCandidate(native *native.NativeService) ([]byte, error) { params := new(RejectCandidateParam) - if err := params.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := params.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, fmt.Errorf("deserialize, contract params deserialize error: %v", err) } @@ -591,7 +589,7 @@ func RejectCandidate(native *native.NativeService) ([]byte, error) { //Node in black list can't be registered. func BlackNode(native *native.NativeService) ([]byte, error) { params := new(BlackNodeParam) - if err := params.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := params.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, fmt.Errorf("deserialize, contract params deserialize error: %v", err) } @@ -635,12 +633,8 @@ func BlackNode(native *native.NativeService) ([]byte, error) { Address: peerPoolItem.Address, InitPos: peerPoolItem.InitPos, } - bf := new(bytes.Buffer) - if err := blackListItem.Serialize(bf); err != nil { - return utils.BYTE_FALSE, fmt.Errorf("serialize, serialize blackListItem error: %v", err) - } //put peer into black list - native.CacheDB.Put(utils.ConcatKey(contract, []byte(BLACK_LIST), peerPubkeyPrefix), cstates.GenRawStorageItem(bf.Bytes())) + native.CacheDB.Put(utils.ConcatKey(contract, []byte(BLACK_LIST), peerPubkeyPrefix), cstates.GenRawStorageItem(common.SerializeToBytes(blackListItem))) //change peerPool status if peerPoolItem.Status == ConsensusStatus { commit = true @@ -666,7 +660,7 @@ func BlackNode(native *native.NativeService) ([]byte, error) { //Remove a node from black list, allow it to be registered func WhiteNode(native *native.NativeService) ([]byte, error) { params := new(WhiteNodeParam) - if err := params.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := params.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, fmt.Errorf("deserialize, contract params deserialize error: %v", err) } @@ -708,7 +702,7 @@ func WhiteNode(native *native.NativeService) ([]byte, error) { //Remove node from pool and unfreeze deposit next epoch(candidate node) / next next epoch(consensus node) func QuitNode(native *native.NativeService) ([]byte, error) { params := new(QuitNodeParam) - if err := params.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := params.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, fmt.Errorf("deserialize, contract params deserialize error: %v", err) } address := params.Address @@ -801,7 +795,7 @@ func UnAuthorizeForPeer(native *native.NativeService) ([]byte, error) { PeerPubkeyList: make([]string, 0), PosList: make([]uint32, 0), } - if err := params.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := params.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, fmt.Errorf("deserialize, contract params deserialize error: %v", err) } address := params.Address @@ -911,7 +905,7 @@ func Withdraw(native *native.NativeService) ([]byte, error) { PeerPubkeyList: make([]string, 0), WithdrawList: make([]uint32, 0), } - if err := params.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := params.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, fmt.Errorf("deserialize, contract params deserialize error: %v", err) } address := params.Address @@ -1036,7 +1030,7 @@ func UpdateConfig(native *native.NativeService) ([]byte, error) { } configuration := new(Configuration) - if err := configuration.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := configuration.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, fmt.Errorf("deserialize, deserialize configuration error: %v", err) } @@ -1124,7 +1118,7 @@ func UpdateGlobalParam(native *native.NativeService) ([]byte, error) { } globalParam := new(GlobalParam) - if err := globalParam.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := globalParam.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, fmt.Errorf("deserialize, deserialize globalParam error: %v", err) } @@ -1178,7 +1172,7 @@ func UpdateGlobalParam2(native *native.NativeService) ([]byte, error) { contract := native.ContextRef.CurrentContext().ContractAddress globalParam2 := new(GlobalParam2) - if err := globalParam2.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := globalParam2.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, fmt.Errorf("deserialize, deserialize globalParam2 error: %v", err) } @@ -1215,7 +1209,7 @@ func UpdateSplitCurve(native *native.NativeService) ([]byte, error) { } splitCurve := new(SplitCurve) - if err := splitCurve.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := splitCurve.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, fmt.Errorf("deserialize, deserialize splitCurve error: %v", err) } contract := native.ContextRef.CurrentContext().ContractAddress @@ -1244,7 +1238,7 @@ func TransferPenalty(native *native.NativeService) ([]byte, error) { } param := new(TransferPenaltyParam) - if err := param.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := param.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, fmt.Errorf("deserialize, deserialize transferPenaltyParam error: %v", err) } contract := native.ContextRef.CurrentContext().ContractAddress @@ -1260,7 +1254,7 @@ func TransferPenalty(native *native.NativeService) ([]byte, error) { //Withdraw unbounded ONG according to deposit ONT in this governance contract func WithdrawOng(native *native.NativeService) ([]byte, error) { params := new(WithdrawOngParam) - if err := params.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := params.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, fmt.Errorf("deserialize, deserialize transferPenaltyParam error: %v", err) } contract := native.ContextRef.CurrentContext().ContractAddress @@ -1306,7 +1300,7 @@ func ChangeMaxAuthorization(native *native.NativeService) ([]byte, error) { return utils.BYTE_FALSE, fmt.Errorf("block num is not reached for this func") } params := new(ChangeMaxAuthorizationParam) - if err := params.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := params.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, fmt.Errorf("deserialize, deserialize changeMaxAuthorizationParam error: %v", err) } @@ -1367,7 +1361,7 @@ func SetPeerCost(native *native.NativeService) ([]byte, error) { return utils.BYTE_FALSE, fmt.Errorf("block num is not reached for this func") } params := new(SetPeerCostParam) - if err := params.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := params.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, fmt.Errorf("deserialize, deserialize setPeerCostParam error: %v", err) } if params.PeerCost > 100 { @@ -1421,7 +1415,7 @@ func WithdrawFee(native *native.NativeService) ([]byte, error) { return utils.BYTE_FALSE, fmt.Errorf("block num is not reached for this func") } params := new(WithdrawFeeParam) - if err := params.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := params.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, fmt.Errorf("deserialize, deserialize withdrawFeeParam error: %v", err) } @@ -1470,7 +1464,7 @@ func AddInitPos(native *native.NativeService) ([]byte, error) { return utils.BYTE_FALSE, fmt.Errorf("block num is not reached for this func") } params := new(ChangeInitPosParam) - if err := params.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := params.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, fmt.Errorf("deserialize, deserialize changeInitPosParam error: %v", err) } @@ -1536,7 +1530,7 @@ func ReduceInitPos(native *native.NativeService) ([]byte, error) { return utils.BYTE_FALSE, fmt.Errorf("block num is not reached for this func") } params := new(ChangeInitPosParam) - if err := params.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := params.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, fmt.Errorf("deserialize, deserialize changeInitPosParam error: %v", err) } @@ -1641,7 +1635,7 @@ func SetPromisePos(native *native.NativeService) ([]byte, error) { contract := native.ContextRef.CurrentContext().ContractAddress promisePos := new(PromisePos) - if err := promisePos.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := promisePos.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return utils.BYTE_FALSE, fmt.Errorf("deserialize, contract params deserialize error: %v", err) } //update promise pos diff --git a/smartcontract/service/native/governance/method.go b/smartcontract/service/native/governance/method.go index 98f71138da..2ef48b6185 100644 --- a/smartcontract/service/native/governance/method.go +++ b/smartcontract/service/native/governance/method.go @@ -19,7 +19,6 @@ package governance import ( - "bytes" "encoding/hex" "fmt" "math/big" @@ -34,7 +33,7 @@ import ( func registerCandidate(native *native.NativeService, flag string) error { params := new(RegisterCandidateParam) - if err := params.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := params.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return fmt.Errorf("deserialize, contract params deserialize error: %v", err) } contract := native.ContextRef.CurrentContext().ContractAddress @@ -150,7 +149,7 @@ func authorizeForPeer(native *native.NativeService, flag string) error { PeerPubkeyList: make([]string, 0), PosList: make([]uint32, 0), } - if err := params.Deserialize(bytes.NewBuffer(native.Input)); err != nil { + if err := params.Deserialization(common.NewZeroCopySource(native.Input)); err != nil { return fmt.Errorf("deserialize, contract params deserialize error: %v", err) } contract := native.ContextRef.CurrentContext().ContractAddress @@ -283,7 +282,7 @@ func normalQuit(native *native.NativeService, contract common.Address, peerPoolI return fmt.Errorf("authorizeInfoStore is not available!:%v", err) } var authorizeInfo AuthorizeInfo - if err := authorizeInfo.Deserialize(bytes.NewBuffer(authorizeInfoStore)); err != nil { + if err := authorizeInfo.Deserialization(common.NewZeroCopySource(authorizeInfoStore)); err != nil { return fmt.Errorf("deserialize, deserialize authorizeInfo error: %v", err) } authorizeInfo.WithdrawUnfreezePos = authorizeInfo.ConsensusPos + authorizeInfo.CandidatePos + authorizeInfo.NewPos + @@ -354,7 +353,7 @@ func blackQuit(native *native.NativeService, contract common.Address, peerPoolIt return fmt.Errorf("authorizeInfoStore is not available!:%v", err) } var authorizeInfo AuthorizeInfo - if err := authorizeInfo.Deserialize(bytes.NewBuffer(authorizeInfoStore)); err != nil { + if err := authorizeInfo.Deserialization(common.NewZeroCopySource(authorizeInfoStore)); err != nil { return fmt.Errorf("deserialize, deserialize authorizeInfo error: %v", err) } total := authorizeInfo.ConsensusPos + authorizeInfo.CandidatePos + authorizeInfo.NewPos + authorizeInfo.WithdrawConsensusPos + @@ -406,7 +405,7 @@ func consensusToConsensus(native *native.NativeService, contract common.Address, return fmt.Errorf("authorizeInfoStore is not available!:%v", err) } var authorizeInfo AuthorizeInfo - if err := authorizeInfo.Deserialize(bytes.NewBuffer(authorizeInfoStore)); err != nil { + if err := authorizeInfo.Deserialization(common.NewZeroCopySource(authorizeInfoStore)); err != nil { return fmt.Errorf("deserialize, deserialize authorizeInfo error: %v", err) } if authorizeInfo.CandidatePos != 0 { @@ -449,7 +448,7 @@ func unConsensusToConsensus(native *native.NativeService, contract common.Addres return fmt.Errorf("authorizeInfoStore is not available!:%v", err) } var authorizeInfo AuthorizeInfo - if err := authorizeInfo.Deserialize(bytes.NewBuffer(authorizeInfoStore)); err != nil { + if err := authorizeInfo.Deserialization(common.NewZeroCopySource(authorizeInfoStore)); err != nil { return fmt.Errorf("deserialize, deserialize authorizeInfo error: %v", err) } if authorizeInfo.ConsensusPos != 0 { @@ -492,7 +491,7 @@ func consensusToUnConsensus(native *native.NativeService, contract common.Addres return fmt.Errorf("authorizeInfoStore is not available!:%v", err) } var authorizeInfo AuthorizeInfo - if err := authorizeInfo.Deserialize(bytes.NewBuffer(authorizeInfoStore)); err != nil { + if err := authorizeInfo.Deserialization(common.NewZeroCopySource(authorizeInfoStore)); err != nil { return fmt.Errorf("deserialize, deserialize authorizeInfo error: %v", err) } if authorizeInfo.CandidatePos != 0 { @@ -535,7 +534,7 @@ func unConsensusToUnConsensus(native *native.NativeService, contract common.Addr return fmt.Errorf("authorizeInfoStore is not available!:%v", err) } var authorizeInfo AuthorizeInfo - if err := authorizeInfo.Deserialize(bytes.NewBuffer(authorizeInfoStore)); err != nil { + if err := authorizeInfo.Deserialization(common.NewZeroCopySource(authorizeInfoStore)); err != nil { return fmt.Errorf("deserialize, deserialize authorizeInfo error: %v", err) } if authorizeInfo.ConsensusPos != 0 { @@ -1353,7 +1352,7 @@ func splitNodeFee(native *native.NativeService, contract common.Address, peerPub return fmt.Errorf("authorizeInfoStore is not available!:%v", err) } var authorizeInfo AuthorizeInfo - if err := authorizeInfo.Deserialize(bytes.NewBuffer(authorizeInfoStore)); err != nil { + if err := authorizeInfo.Deserialization(common.NewZeroCopySource(authorizeInfoStore)); err != nil { return fmt.Errorf("deserialize, deserialize authorizeInfo error: %v", err) } diff --git a/smartcontract/service/native/governance/param.go b/smartcontract/service/native/governance/param.go index 12131de313..ec05dee8dd 100644 --- a/smartcontract/service/native/governance/param.go +++ b/smartcontract/service/native/governance/param.go @@ -20,11 +20,9 @@ package governance import ( "fmt" - "io" "math" "github.com/ontio/ontology/common" - "github.com/ontio/ontology/common/serialization" "github.com/ontio/ontology/smartcontract/service/native/utils" ) @@ -36,46 +34,35 @@ type RegisterCandidateParam struct { KeyNo uint32 } -func (this *RegisterCandidateParam) Serialize(w io.Writer) error { - if err := serialization.WriteString(w, this.PeerPubkey); err != nil { - return fmt.Errorf("serialization.WriteString, request peerPubkey error: %v", err) - } - if err := serialization.WriteVarBytes(w, this.Address[:]); err != nil { - return fmt.Errorf("serialization.WriteVarBytes, address address error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.InitPos)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize initPos error: %v", err) - } - if err := serialization.WriteVarBytes(w, this.Caller); err != nil { - return fmt.Errorf("serialization.WriteVarBytes, serialize caller error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.KeyNo)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize keyNo error: %v", err) - } - return nil +func (this *RegisterCandidateParam) Serialization(sink *common.ZeroCopySink) { + sink.WriteString(this.PeerPubkey) + sink.WriteVarBytes(this.Address[:]) + utils.EncodeVarUint(sink, uint64(this.InitPos)) + sink.WriteVarBytes(this.Caller) + utils.EncodeVarUint(sink, uint64(this.KeyNo)) } -func (this *RegisterCandidateParam) Deserialize(r io.Reader) error { - peerPubkey, err := serialization.ReadString(r) - if err != nil { - return fmt.Errorf("serialization.ReadString, deserialize peerPubkey error: %v", err) +func (this *RegisterCandidateParam) Deserialization(source *common.ZeroCopySource) error { + peerPubkey, _, irregular, eof := source.NextString() + if irregular || eof { + return fmt.Errorf("serialization.ReadString, deserialize peerPubkey irregular: %v, eof: %v", irregular, eof) } - address, err := utils.ReadAddress(r) + address, err := utils.DecodeAddress(source) if err != nil { return fmt.Errorf("utils.ReadAddress, deserialize address error: %v", err) } - initPos, err := utils.ReadVarUint(r) + initPos, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize initPos error: %v", err) } if initPos > math.MaxUint32 { return fmt.Errorf("initPos larger than max of uint32") } - caller, err := serialization.ReadVarBytes(r) - if err != nil { - return fmt.Errorf("serialization.ReadVarBytes, deserialize caller error: %v", err) + caller, _, irregular, eof := source.NextVarBytes() + if irregular || eof { + return fmt.Errorf("serialization.ReadVarBytes, deserialize caller irregular: %v, eof: %v", irregular, eof) } - keyNo, err := utils.ReadVarUint(r) + keyNo, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize keyNo error: %v", err) } @@ -95,22 +82,17 @@ type UnRegisterCandidateParam struct { Address common.Address } -func (this *UnRegisterCandidateParam) Serialize(w io.Writer) error { - if err := serialization.WriteString(w, this.PeerPubkey); err != nil { - return fmt.Errorf("serialization.WriteString, request peerPubkey error: %v", err) - } - if err := serialization.WriteVarBytes(w, this.Address[:]); err != nil { - return fmt.Errorf("serialization.WriteVarBytes, address address error: %v", err) - } - return nil +func (this *UnRegisterCandidateParam) Serialization(sink *common.ZeroCopySink) { + sink.WriteString(this.PeerPubkey) + sink.WriteVarBytes(this.Address[:]) } -func (this *UnRegisterCandidateParam) Deserialize(r io.Reader) error { - peerPubkey, err := serialization.ReadString(r) +func (this *UnRegisterCandidateParam) Deserialization(source *common.ZeroCopySource) error { + peerPubkey, err := utils.DecodeString(source) if err != nil { return fmt.Errorf("serialization.ReadString, deserialize peerPubkey error: %v", err) } - address, err := utils.ReadAddress(r) + address, err := utils.DecodeAddress(source) if err != nil { return fmt.Errorf("utils.ReadAddress, deserialize address error: %v", err) } @@ -124,22 +106,17 @@ type QuitNodeParam struct { Address common.Address } -func (this *QuitNodeParam) Serialize(w io.Writer) error { - if err := serialization.WriteString(w, this.PeerPubkey); err != nil { - return fmt.Errorf("serialization.WriteString, deserialize peerPubkey error: %v", err) - } - if err := serialization.WriteVarBytes(w, this.Address[:]); err != nil { - return fmt.Errorf("serialization.WriteVarBytes, address address error: %v", err) - } - return nil +func (this *QuitNodeParam) Serialization(sink *common.ZeroCopySink) { + sink.WriteString(this.PeerPubkey) + sink.WriteVarBytes(this.Address[:]) } -func (this *QuitNodeParam) Deserialize(r io.Reader) error { - peerPubkey, err := serialization.ReadString(r) +func (this *QuitNodeParam) Deserialization(source *common.ZeroCopySource) error { + peerPubkey, err := utils.DecodeString(source) if err != nil { return fmt.Errorf("serialization.ReadString, deserialize peerPubkey error: %v", err) } - address, err := utils.ReadAddress(r) + address, err := utils.DecodeAddress(source) if err != nil { return fmt.Errorf("utils.ReadAddress, deserialize address error: %v", err) } @@ -152,15 +129,12 @@ type ApproveCandidateParam struct { PeerPubkey string } -func (this *ApproveCandidateParam) Serialize(w io.Writer) error { - if err := serialization.WriteString(w, this.PeerPubkey); err != nil { - return fmt.Errorf("serialization.WriteString, serialize peerPubkey error: %v", err) - } - return nil +func (this *ApproveCandidateParam) Serialization(sink *common.ZeroCopySink) { + sink.WriteString(this.PeerPubkey) } -func (this *ApproveCandidateParam) Deserialize(r io.Reader) error { - peerPubkey, err := serialization.ReadString(r) +func (this *ApproveCandidateParam) Deserialization(source *common.ZeroCopySource) error { + peerPubkey, err := utils.DecodeString(source) if err != nil { return fmt.Errorf("serialization.ReadString, deserialize peerPubkey error: %v", err) } @@ -172,15 +146,12 @@ type RejectCandidateParam struct { PeerPubkey string } -func (this *RejectCandidateParam) Serialize(w io.Writer) error { - if err := serialization.WriteString(w, this.PeerPubkey); err != nil { - return fmt.Errorf("serialization.WriteString, serialize peerPubkey error: %v", err) - } - return nil +func (this *RejectCandidateParam) Serialization(sink *common.ZeroCopySink) { + sink.WriteString(this.PeerPubkey) } -func (this *RejectCandidateParam) Deserialize(r io.Reader) error { - peerPubkey, err := serialization.ReadString(r) +func (this *RejectCandidateParam) Deserialization(source *common.ZeroCopySource) error { + peerPubkey, err := utils.DecodeString(source) if err != nil { return fmt.Errorf("serialization.ReadString, deserialize peerPubkey error: %v", err) } @@ -192,28 +163,23 @@ type BlackNodeParam struct { PeerPubkeyList []string } -func (this *BlackNodeParam) Serialize(w io.Writer) error { - if err := utils.WriteVarUint(w, uint64(len(this.PeerPubkeyList))); err != nil { - return fmt.Errorf("serialization.WriteVarUint, serialize peerPubkeyList length error: %v", err) - } +func (this *BlackNodeParam) Serialization(sink *common.ZeroCopySink) { + utils.EncodeVarUint(sink, uint64(len(this.PeerPubkeyList))) for _, v := range this.PeerPubkeyList { - if err := serialization.WriteString(w, v); err != nil { - return fmt.Errorf("serialization.WriteString, serialize peerPubkey error: %v", err) - } + sink.WriteString(v) } - return nil } -func (this *BlackNodeParam) Deserialize(r io.Reader) error { - n, err := utils.ReadVarUint(r) +func (this *BlackNodeParam) Deserialization(source *common.ZeroCopySource) error { + n, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("serialization.ReadVarUint, deserialize peerPubkeyList length error: %v", err) } peerPubkeyList := make([]string, 0) for i := 0; uint64(i) < n; i++ { - k, err := serialization.ReadString(r) - if err != nil { - return fmt.Errorf("serialization.ReadString, deserialize peerPubkey error: %v", err) + k, _, irregular, eof := source.NextString() + if irregular || eof { + return fmt.Errorf("serialization.ReadString, deserialize peerPubkey irregular:%v, eof: %v", irregular, eof) } peerPubkeyList = append(peerPubkeyList, k) } @@ -225,15 +191,12 @@ type WhiteNodeParam struct { PeerPubkey string } -func (this *WhiteNodeParam) Serialize(w io.Writer) error { - if err := serialization.WriteString(w, this.PeerPubkey); err != nil { - return fmt.Errorf("serialization.WriteString, serialize peerPubkey error: %v", err) - } - return nil +func (this *WhiteNodeParam) Serialization(sink *common.ZeroCopySink) { + sink.WriteString(this.PeerPubkey) } -func (this *WhiteNodeParam) Deserialize(r io.Reader) error { - peerPubkey, err := serialization.ReadString(r) +func (this *WhiteNodeParam) Deserialization(source *common.ZeroCopySource) error { + peerPubkey, err := utils.DecodeString(source) if err != nil { return fmt.Errorf("serialization.ReadString, deserialize peerPubkey error: %v", err) } @@ -247,41 +210,31 @@ type AuthorizeForPeerParam struct { PosList []uint32 } -func (this *AuthorizeForPeerParam) Serialize(w io.Writer) error { +func (this *AuthorizeForPeerParam) Serialization(sink *common.ZeroCopySink) error { if len(this.PeerPubkeyList) > 1024 { return fmt.Errorf("length of input list > 1024") } if len(this.PeerPubkeyList) != len(this.PosList) { return fmt.Errorf("length of PeerPubkeyList != length of PosList") } - if err := serialization.WriteVarBytes(w, this.Address[:]); err != nil { - return fmt.Errorf("serialization.WriteVarBytes, address address error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(len(this.PeerPubkeyList))); err != nil { - return fmt.Errorf("serialization.WriteVarUint, serialize peerPubkeyList length error: %v", err) - } + sink.WriteVarBytes(this.Address[:]) + utils.EncodeVarUint(sink, uint64(len(this.PeerPubkeyList))) for _, v := range this.PeerPubkeyList { - if err := serialization.WriteString(w, v); err != nil { - return fmt.Errorf("serialization.WriteString, serialize peerPubkey error: %v", err) - } - } - if err := utils.WriteVarUint(w, uint64(len(this.PosList))); err != nil { - return fmt.Errorf("serialization.WriteVarUint, serialize posList length error: %v", err) + sink.WriteString(v) } + utils.EncodeVarUint(sink, uint64(len(this.PosList))) for _, v := range this.PosList { - if err := utils.WriteVarUint(w, uint64(v)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize pos error: %v", err) - } + utils.EncodeVarUint(sink, uint64(v)) } return nil } -func (this *AuthorizeForPeerParam) Deserialize(r io.Reader) error { - address, err := utils.ReadAddress(r) +func (this *AuthorizeForPeerParam) Deserialization(source *common.ZeroCopySource) error { + address, err := utils.DecodeAddress(source) if err != nil { return fmt.Errorf("utils.ReadAddress, deserialize address error: %v", err) } - n, err := utils.ReadVarUint(r) + n, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("serialization.ReadVarUint, deserialize peerPubkeyList length error: %v", err) } @@ -290,19 +243,19 @@ func (this *AuthorizeForPeerParam) Deserialize(r io.Reader) error { } peerPubkeyList := make([]string, 0) for i := 0; uint64(i) < n; i++ { - k, err := serialization.ReadString(r) - if err != nil { - return fmt.Errorf("serialization.ReadString, deserialize peerPubkey error: %v", err) + k, _, irregular, eof := source.NextString() + if irregular || eof { + return fmt.Errorf("serialization.ReadString, deserialize peerPubkey irregular: %v,eof: %v", irregular, eof) } peerPubkeyList = append(peerPubkeyList, k) } - m, err := utils.ReadVarUint(r) + m, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("serialization.ReadVarUint, deserialize posList length error: %v", err) } posList := make([]uint32, 0) for i := 0; uint64(i) < m; i++ { - k, err := utils.ReadVarUint(r) + k, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize pos error: %v", err) } @@ -326,41 +279,32 @@ type WithdrawParam struct { WithdrawList []uint32 } -func (this *WithdrawParam) Serialize(w io.Writer) error { +func (this *WithdrawParam) Serialization(sink *common.ZeroCopySink) error { if len(this.PeerPubkeyList) > 1024 { return fmt.Errorf("length of input list > 1024") } if len(this.PeerPubkeyList) != len(this.WithdrawList) { return fmt.Errorf("length of PeerPubkeyList != length of WithdrawList, contract params error") } - if err := serialization.WriteVarBytes(w, this.Address[:]); err != nil { - return fmt.Errorf("serialization.WriteVarBytes, address address error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(len(this.PeerPubkeyList))); err != nil { - return fmt.Errorf("serialization.WriteVarUint, serialize peerPubkeyList length error: %v", err) - } + sink.WriteVarBytes(this.Address[:]) + + utils.EncodeVarUint(sink, uint64(len(this.PeerPubkeyList))) for _, v := range this.PeerPubkeyList { - if err := serialization.WriteString(w, v); err != nil { - return fmt.Errorf("serialization.WriteString, serialize peerPubkey error: %v", err) - } - } - if err := utils.WriteVarUint(w, uint64(len(this.WithdrawList))); err != nil { - return fmt.Errorf("serialization.WriteVarUint, serialize withdrawList length error: %v", err) + sink.WriteString(v) } + utils.EncodeVarUint(sink, uint64(len(this.WithdrawList))) for _, v := range this.WithdrawList { - if err := utils.WriteVarUint(w, uint64(v)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize withdraw error: %v", err) - } + utils.EncodeVarUint(sink, uint64(v)) } return nil } -func (this *WithdrawParam) Deserialize(r io.Reader) error { - address, err := utils.ReadAddress(r) +func (this *WithdrawParam) Deserialization(source *common.ZeroCopySource) error { + address, err := utils.DecodeAddress(source) if err != nil { return fmt.Errorf("utils.ReadAddress, deserialize address error: %v", err) } - n, err := utils.ReadVarUint(r) + n, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("serialization.ReadVarUint, deserialize peerPubkeyList length error: %v", err) } @@ -369,19 +313,19 @@ func (this *WithdrawParam) Deserialize(r io.Reader) error { } peerPubkeyList := make([]string, 0) for i := 0; uint64(i) < n; i++ { - k, err := serialization.ReadString(r) + k, err := utils.DecodeString(source) if err != nil { return fmt.Errorf("serialization.ReadString, deserialize peerPubkey error: %v", err) } peerPubkeyList = append(peerPubkeyList, k) } - m, err := utils.ReadVarUint(r) + m, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("serialization.ReadVarUint, deserialize withdrawList length error: %v", err) } withdrawList := make([]uint32, 0) for i := 0; uint64(i) < m; i++ { - k, err := utils.ReadVarUint(r) + k, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize withdraw error: %v", err) } @@ -410,64 +354,47 @@ type Configuration struct { MaxBlockChangeView uint32 } -func (this *Configuration) Serialize(w io.Writer) error { - if err := utils.WriteVarUint(w, uint64(this.N)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize n error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.C)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize c error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.K)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize k error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.L)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize l error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.BlockMsgDelay)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize block_msg_delay error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.HashMsgDelay)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize hash_msg_delay error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.PeerHandshakeTimeout)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize peer_handshake_timeout error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.MaxBlockChangeView)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize max_block_change_view error: %v", err) - } - return nil +func (this *Configuration) Serialization(sink *common.ZeroCopySink) { + utils.EncodeVarUint(sink, uint64(this.N)) + utils.EncodeVarUint(sink, uint64(this.C)) + utils.EncodeVarUint(sink, uint64(this.K)) + utils.EncodeVarUint(sink, uint64(this.L)) + utils.EncodeVarUint(sink, uint64(this.BlockMsgDelay)) + utils.EncodeVarUint(sink, uint64(this.HashMsgDelay)) + utils.EncodeVarUint(sink, uint64(this.PeerHandshakeTimeout)) + utils.EncodeVarUint(sink, uint64(this.MaxBlockChangeView)) } -func (this *Configuration) Deserialize(r io.Reader) error { - n, err := utils.ReadVarUint(r) +func (this *Configuration) Deserialization(source *common.ZeroCopySource) error { + n, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize n error: %v", err) } - c, err := utils.ReadVarUint(r) + c, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize c error: %v", err) } - k, err := utils.ReadVarUint(r) + k, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize k error: %v", err) } - l, err := utils.ReadVarUint(r) + l, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize l error: %v", err) } - blockMsgDelay, err := utils.ReadVarUint(r) + blockMsgDelay, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize blockMsgDelay error: %v", err) } - hashMsgDelay, err := utils.ReadVarUint(r) + hashMsgDelay, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize hashMsgDelay error: %v", err) } - peerHandshakeTimeout, err := utils.ReadVarUint(r) + peerHandshakeTimeout, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize peerHandshakeTimeout error: %v", err) } - maxBlockChangeView, err := utils.ReadVarUint(r) + maxBlockChangeView, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize maxBlockChangeView error: %v", err) } @@ -511,23 +438,18 @@ type PreConfig struct { SetView uint32 } -func (this *PreConfig) Serialize(w io.Writer) error { - if err := this.Configuration.Serialize(w); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize configuration error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.SetView)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize setView error: %v", err) - } - return nil +func (this *PreConfig) Serialization(sink *common.ZeroCopySink) { + this.Configuration.Serialization(sink) + utils.EncodeVarUint(sink, uint64(this.SetView)) } -func (this *PreConfig) Deserialize(r io.Reader) error { +func (this *PreConfig) Deserialization(source *common.ZeroCopySource) error { config := new(Configuration) - err := config.Deserialize(r) + err := config.Deserialization(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize configuration error: %v", err) } - setView, err := utils.ReadVarUint(r) + setView, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize setView error: %v", err) } @@ -550,64 +472,47 @@ type GlobalParam struct { Penalty uint32 //authorize pos penalty percentage } -func (this *GlobalParam) Serialize(w io.Writer) error { - if err := utils.WriteVarUint(w, this.CandidateFee); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize candidateFee error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.MinInitStake)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize minInitStake error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.CandidateNum)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize candidateNum error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.PosLimit)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize posLimit error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.A)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize a error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.B)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize b error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.Yita)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize yita error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.Penalty)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize penalty error: %v", err) - } - return nil +func (this *GlobalParam) Serialization(sink *common.ZeroCopySink) { + utils.EncodeVarUint(sink, this.CandidateFee) + utils.EncodeVarUint(sink, uint64(this.MinInitStake)) + utils.EncodeVarUint(sink, uint64(this.CandidateNum)) + utils.EncodeVarUint(sink, uint64(this.PosLimit)) + utils.EncodeVarUint(sink, uint64(this.A)) + utils.EncodeVarUint(sink, uint64(this.B)) + utils.EncodeVarUint(sink, uint64(this.Yita)) + utils.EncodeVarUint(sink, uint64(this.Penalty)) } -func (this *GlobalParam) Deserialize(r io.Reader) error { - candidateFee, err := utils.ReadVarUint(r) +func (this *GlobalParam) Deserialization(source *common.ZeroCopySource) error { + candidateFee, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize candidateFee error: %v", err) } - minInitStake, err := utils.ReadVarUint(r) + minInitStake, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize minInitStake error: %v", err) } - candidateNum, err := utils.ReadVarUint(r) + candidateNum, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize candidateNum error: %v", err) } - posLimit, err := utils.ReadVarUint(r) + posLimit, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize posLimit error: %v", err) } - a, err := utils.ReadVarUint(r) + a, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize a error: %v", err) } - b, err := utils.ReadVarUint(r) + b, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize b error: %v", err) } - yita, err := utils.ReadVarUint(r) + yita, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize yita error: %v", err) } - penalty, err := utils.ReadVarUint(r) + penalty, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize penalty error: %v", err) } @@ -654,70 +559,55 @@ type GlobalParam2 struct { Field6 []byte //reserved field } -func (this *GlobalParam2) Serialize(w io.Writer) error { +func (this *GlobalParam2) Serialization(sink *common.ZeroCopySink) error { if this.MinAuthorizePos == 0 { return fmt.Errorf("globalParam2.MinAuthorizePos can not be 0") } if this.DappFee > 100 { return fmt.Errorf("globalParam2.DappFee must <= 100") } - if err := utils.WriteVarUint(w, uint64(this.MinAuthorizePos)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize minAuthorizePos error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.CandidateFeeSplitNum)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize candidateFeeSplitNum error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.DappFee)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize dappFee error: %v", err) - } - if err := serialization.WriteVarBytes(w, this.Field2); err != nil { - return fmt.Errorf("serialization.WriteVarBytes, serialize field2 error: %v", err) - } - if err := serialization.WriteVarBytes(w, this.Field3); err != nil { - return fmt.Errorf("serialization.WriteVarBytes, serialize field3 error: %v", err) - } - if err := serialization.WriteVarBytes(w, this.Field4); err != nil { - return fmt.Errorf("serialization.WriteVarBytes, serialize field4 error: %v", err) - } - if err := serialization.WriteVarBytes(w, this.Field5); err != nil { - return fmt.Errorf("serialization.WriteVarBytes, serialize field5 error: %v", err) - } - if err := serialization.WriteVarBytes(w, this.Field6); err != nil { - return fmt.Errorf("serialization.WriteVarBytes, serialize field6 error: %v", err) - } + utils.EncodeVarUint(sink, uint64(this.MinAuthorizePos)) + utils.EncodeVarUint(sink, uint64(this.CandidateFeeSplitNum)) + + utils.EncodeVarUint(sink, uint64(this.DappFee)) + sink.WriteVarBytes(this.Field2) + sink.WriteVarBytes(this.Field3) + sink.WriteVarBytes(this.Field4) + sink.WriteVarBytes(this.Field5) + sink.WriteVarBytes(this.Field6) return nil } -func (this *GlobalParam2) Deserialize(r io.Reader) error { - minAuthorizePos, err := utils.ReadVarUint(r) +func (this *GlobalParam2) Deserialization(source *common.ZeroCopySource) error { + minAuthorizePos, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize minAuthorizePos error: %v", err) } - candidateFeeSplitNum, err := utils.ReadVarUint(r) + candidateFeeSplitNum, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize candidateFeeSplitNum error: %v", err) } - dappFee, err := utils.ReadVarUint(r) + dappFee, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize dappFee error: %v", err) } - field2, err := serialization.ReadVarBytes(r) + field2, err := utils.DecodeVarBytes(source) if err != nil { return fmt.Errorf("serialization.ReadVarBytes, deserialize field2 error: %v", err) } - field3, err := serialization.ReadVarBytes(r) + field3, err := utils.DecodeVarBytes(source) if err != nil { return fmt.Errorf("serialization.ReadVarBytes, deserialize field3 error: %v", err) } - field4, err := serialization.ReadVarBytes(r) + field4, err := utils.DecodeVarBytes(source) if err != nil { return fmt.Errorf("serialization.ReadVarBytes, deserialize field4 error: %v", err) } - field5, err := serialization.ReadVarBytes(r) + field5, err := utils.DecodeVarBytes(source) if err != nil { return fmt.Errorf("serialization.ReadVarBytes, deserialize field5 error: %v", err) } - field6, err := serialization.ReadVarBytes(r) + field6, err := utils.DecodeVarBytes(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize field6 error: %v", err) } @@ -748,29 +638,25 @@ type SplitCurve struct { Yi []uint32 } -func (this *SplitCurve) Serialize(w io.Writer) error { +func (this *SplitCurve) Serialization(sink *common.ZeroCopySink) error { if len(this.Yi) != 101 { return fmt.Errorf("length of split curve != 101") } - if err := utils.WriteVarUint(w, uint64(len(this.Yi))); err != nil { - return fmt.Errorf("serialization.WriteVarUint, serialize Yi length error: %v", err) - } + utils.EncodeVarUint(sink, uint64(len(this.Yi))) for _, v := range this.Yi { - if err := utils.WriteVarUint(w, uint64(v)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize splitCurve error: %v", err) - } + utils.EncodeVarUint(sink, uint64(v)) } return nil } -func (this *SplitCurve) Deserialize(r io.Reader) error { - n, err := utils.ReadVarUint(r) +func (this *SplitCurve) Deserialization(source *common.ZeroCopySource) error { + n, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("serialization.ReadVarUint, deserialize Yi length error: %v", err) } yi := make([]uint32, 0) for i := 0; uint64(i) < n; i++ { - k, err := utils.ReadVarUint(r) + k, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize splitCurve error: %v", err) } @@ -788,22 +674,17 @@ type TransferPenaltyParam struct { Address common.Address } -func (this *TransferPenaltyParam) Serialize(w io.Writer) error { - if err := serialization.WriteString(w, this.PeerPubkey); err != nil { - return fmt.Errorf("serialization.WriteString, serialize peerPubkey error: %v", err) - } - if err := serialization.WriteVarBytes(w, this.Address[:]); err != nil { - return fmt.Errorf("serialization.WriteVarBytes, serialize address error: %v", err) - } - return nil +func (this *TransferPenaltyParam) Serialization(sink *common.ZeroCopySink) { + sink.WriteString(this.PeerPubkey) + sink.WriteVarBytes(this.Address[:]) } -func (this *TransferPenaltyParam) Deserialize(r io.Reader) error { - peerPubkey, err := serialization.ReadString(r) +func (this *TransferPenaltyParam) Deserialization(source *common.ZeroCopySource) error { + peerPubkey, err := utils.DecodeString(source) if err != nil { return fmt.Errorf("serialization.ReadString, deserialize peerPubkey error: %v", err) } - address, err := utils.ReadAddress(r) + address, err := utils.DecodeAddress(source) if err != nil { return fmt.Errorf("utils.ReadAddress, deserialize address error: %v", err) } @@ -816,15 +697,12 @@ type WithdrawOngParam struct { Address common.Address } -func (this *WithdrawOngParam) Serialize(w io.Writer) error { - if err := serialization.WriteVarBytes(w, this.Address[:]); err != nil { - return fmt.Errorf("serialization.WriteVarBytes, serialize address error: %v", err) - } - return nil +func (this *WithdrawOngParam) Serialization(sink *common.ZeroCopySink) { + sink.WriteVarBytes(this.Address[:]) } -func (this *WithdrawOngParam) Deserialize(r io.Reader) error { - address, err := utils.ReadAddress(r) +func (this *WithdrawOngParam) Deserialization(source *common.ZeroCopySource) error { + address, err := utils.DecodeAddress(source) if err != nil { return fmt.Errorf("utils.ReadAddress, deserialize address error: %v", err) } @@ -838,29 +716,23 @@ type ChangeMaxAuthorizationParam struct { MaxAuthorize uint32 } -func (this *ChangeMaxAuthorizationParam) Serialize(w io.Writer) error { - if err := serialization.WriteString(w, this.PeerPubkey); err != nil { - return fmt.Errorf("serialization.WriteString, serialize peerPubkey error: %v", err) - } - if err := serialization.WriteVarBytes(w, this.Address[:]); err != nil { - return fmt.Errorf("serialization.WriteVarBytes, serialize address error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.MaxAuthorize)); err != nil { - return fmt.Errorf("utils.WriteVarUint, serialize maxAuthorize error: %v", err) - } - return nil +func (this *ChangeMaxAuthorizationParam) Serialization(sink *common.ZeroCopySink) { + sink.WriteString(this.PeerPubkey) + sink.WriteVarBytes(this.Address[:]) + + utils.EncodeVarUint(sink, uint64(this.MaxAuthorize)) } -func (this *ChangeMaxAuthorizationParam) Deserialize(r io.Reader) error { - peerPubkey, err := serialization.ReadString(r) +func (this *ChangeMaxAuthorizationParam) Deserialization(source *common.ZeroCopySource) error { + peerPubkey, err := utils.DecodeString(source) if err != nil { return fmt.Errorf("serialization.ReadString, deserialize peerPubkey error: %v", err) } - address, err := utils.ReadAddress(r) + address, err := utils.DecodeAddress(source) if err != nil { return fmt.Errorf("utils.ReadAddress, deserialize address error: %v", err) } - maxAuthorize, err := utils.ReadVarUint(r) + maxAuthorize, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("utils.ReadVarUint, deserialize maxAuthorize error: %v", err) } @@ -879,29 +751,23 @@ type SetPeerCostParam struct { PeerCost uint32 } -func (this *SetPeerCostParam) Serialize(w io.Writer) error { - if err := serialization.WriteString(w, this.PeerPubkey); err != nil { - return fmt.Errorf("serialization.WriteString, serialize peerPubkey error: %v", err) - } - if err := serialization.WriteVarBytes(w, this.Address[:]); err != nil { - return fmt.Errorf("serialization.WriteVarBytes, serialize address error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.PeerCost)); err != nil { - return fmt.Errorf("serialization.WriteBool, serialize peerCost error: %v", err) - } +func (this *SetPeerCostParam) Serialization(sink *common.ZeroCopySink) error { + sink.WriteString(this.PeerPubkey) + sink.WriteVarBytes(this.Address[:]) + utils.EncodeVarUint(sink, uint64(this.PeerCost)) return nil } -func (this *SetPeerCostParam) Deserialize(r io.Reader) error { - peerPubkey, err := serialization.ReadString(r) +func (this *SetPeerCostParam) Deserialization(source *common.ZeroCopySource) error { + peerPubkey, err := utils.DecodeString(source) if err != nil { return fmt.Errorf("serialization.ReadString, deserialize peerPubkey error: %v", err) } - address, err := utils.ReadAddress(r) + address, err := utils.DecodeAddress(source) if err != nil { return fmt.Errorf("utils.ReadAddress, deserialize address error: %v", err) } - peerCost, err := utils.ReadVarUint(r) + peerCost, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("serialization.ReadBool, deserialize peerCost error: %v", err) } @@ -918,15 +784,12 @@ type WithdrawFeeParam struct { Address common.Address } -func (this *WithdrawFeeParam) Serialize(w io.Writer) error { - if err := serialization.WriteVarBytes(w, this.Address[:]); err != nil { - return fmt.Errorf("serialization.WriteVarBytes, serialize address error: %v", err) - } - return nil +func (this *WithdrawFeeParam) Serialization(sink *common.ZeroCopySink) { + sink.WriteVarBytes(this.Address[:]) } -func (this *WithdrawFeeParam) Deserialize(r io.Reader) error { - address, err := utils.ReadAddress(r) +func (this *WithdrawFeeParam) Deserialization(source *common.ZeroCopySource) error { + address, err := utils.DecodeAddress(source) if err != nil { return fmt.Errorf("utils.ReadAddress, deserialize address error: %v", err) } @@ -939,22 +802,17 @@ type PromisePos struct { PromisePos uint64 } -func (this *PromisePos) Serialize(w io.Writer) error { - if err := serialization.WriteString(w, this.PeerPubkey); err != nil { - return fmt.Errorf("serialization.WriteString, serialize peerPubkey error: %v", err) - } - if err := utils.WriteVarUint(w, this.PromisePos); err != nil { - return fmt.Errorf("serialization.WriteBool, serialize promisePos error: %v", err) - } - return nil +func (this *PromisePos) Serialization(sink *common.ZeroCopySink) { + sink.WriteString(this.PeerPubkey) + utils.EncodeVarUint(sink, this.PromisePos) } -func (this *PromisePos) Deserialize(r io.Reader) error { - peerPubkey, err := serialization.ReadString(r) +func (this *PromisePos) Deserialization(source *common.ZeroCopySource) error { + peerPubkey, err := utils.DecodeString(source) if err != nil { return fmt.Errorf("serialization.ReadString, deserialize peerPubkey error: %v", err) } - promisePos, err := utils.ReadVarUint(r) + promisePos, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("serialization.ReadBool, deserialize promisePos error: %v", err) } @@ -969,29 +827,22 @@ type ChangeInitPosParam struct { Pos uint32 } -func (this *ChangeInitPosParam) Serialize(w io.Writer) error { - if err := serialization.WriteString(w, this.PeerPubkey); err != nil { - return fmt.Errorf("serialization.WriteString, serialize peerPubkey error: %v", err) - } - if err := serialization.WriteVarBytes(w, this.Address[:]); err != nil { - return fmt.Errorf("serialization.WriteVarBytes, serialize address error: %v", err) - } - if err := utils.WriteVarUint(w, uint64(this.Pos)); err != nil { - return fmt.Errorf("serialization.WriteBool, serialize pos error: %v", err) - } - return nil +func (this *ChangeInitPosParam) Serialization(sink *common.ZeroCopySink) { + sink.WriteString(this.PeerPubkey) + sink.WriteVarBytes(this.Address[:]) + utils.EncodeVarUint(sink, uint64(this.Pos)) } -func (this *ChangeInitPosParam) Deserialize(r io.Reader) error { - peerPubkey, err := serialization.ReadString(r) +func (this *ChangeInitPosParam) Deserialization(source *common.ZeroCopySource) error { + peerPubkey, err := utils.DecodeString(source) if err != nil { return fmt.Errorf("serialization.ReadString, deserialize peerPubkey error: %v", err) } - address, err := utils.ReadAddress(r) + address, err := utils.DecodeAddress(source) if err != nil { return fmt.Errorf("utils.ReadAddress, deserialize address error: %v", err) } - pos, err := utils.ReadVarUint(r) + pos, err := utils.DecodeVarUint(source) if err != nil { return fmt.Errorf("serialization.ReadBool, deserialize pos error: %v", err) } diff --git a/smartcontract/service/native/governance/states.go b/smartcontract/service/native/governance/states.go index c171a2e8ab..bd5a40d336 100644 --- a/smartcontract/service/native/governance/states.go +++ b/smartcontract/service/native/governance/states.go @@ -25,21 +25,19 @@ import ( "github.com/ontio/ontology/common" "github.com/ontio/ontology/common/serialization" + "github.com/ontio/ontology/smartcontract/service/native/utils" ) type Status uint8 -func (this *Status) Serialize(w io.Writer) error { - if err := serialization.WriteUint8(w, uint8(*this)); err != nil { - return fmt.Errorf("serialization.WriteUint8, serialize status error: %v", err) - } - return nil +func (this *Status) Serialization(sink *common.ZeroCopySink) { + sink.WriteUint8(uint8(*this)) } -func (this *Status) Deserialize(r io.Reader) error { - status, err := serialization.ReadUint8(r) - if err != nil { - return fmt.Errorf("serialization.ReadUint8, deserialize status error: %v", err) +func (this *Status) Deserialization(source *common.ZeroCopySource) error { + status, eof := source.NextUint8() + if eof { + return fmt.Errorf("serialization.ReadUint8, deserialize status error: %v", io.ErrUnexpectedEOF) } *this = Status(status) return nil @@ -51,32 +49,25 @@ type BlackListItem struct { InitPos uint64 //initPos of this peer } -func (this *BlackListItem) Serialize(w io.Writer) error { - if err := serialization.WriteString(w, this.PeerPubkey); err != nil { - return fmt.Errorf("serialization.WriteString, serialize peerPubkey error: %v", err) - } - if err := this.Address.Serialize(w); err != nil { - return fmt.Errorf("address.Serialize, serialize address error: %v", err) - } - if err := serialization.WriteUint64(w, this.InitPos); err != nil { - return fmt.Errorf("serialization.WriteUint64, serialize initPos error: %v", err) - } - return nil +func (this *BlackListItem) Serialization(sink *common.ZeroCopySink) { + sink.WriteString(this.PeerPubkey) + this.Address.Serialization(sink) + sink.WriteUint64(this.InitPos) } -func (this *BlackListItem) Deserialize(r io.Reader) error { - peerPubkey, err := serialization.ReadString(r) +func (this *BlackListItem) Deserialization(source *common.ZeroCopySource) error { + peerPubkey, err := utils.DecodeString(source) if err != nil { return fmt.Errorf("serialization.ReadString, deserialize peerPubkey error: %v", err) } address := new(common.Address) - err = address.Deserialize(r) + err = address.Deserialization(source) if err != nil { return fmt.Errorf("address.Deserialize, deserialize address error: %v", err) } - initPos, err := serialization.ReadUint64(r) - if err != nil { - return fmt.Errorf("serialization.ReadUint64, deserialize initPos error: %v", err) + initPos, eof := source.NextUint64() + if eof { + return fmt.Errorf("serialization.ReadUint64, deserialize initPos error: %v", io.ErrUnexpectedEOF) } this.PeerPubkey = peerPubkey this.Address = *address @@ -92,10 +83,9 @@ type PeerPoolMap struct { PeerPoolMap map[string]*PeerPoolItem } -func (this *PeerPoolMap) Serialize(w io.Writer) error { - if err := serialization.WriteUint32(w, uint32(len(this.PeerPoolMap))); err != nil { - return fmt.Errorf("serialization.WriteUint32, serialize PeerPoolMap length error: %v", err) - } +func (this *PeerPoolMap) Serialization(sink *common.ZeroCopySink) error { + sink.WriteUint32(uint32(len(this.PeerPoolMap))) + var peerPoolItemList []*PeerPoolItem for _, v := range this.PeerPoolMap { peerPoolItemList = append(peerPoolItemList, v) @@ -104,22 +94,21 @@ func (this *PeerPoolMap) Serialize(w io.Writer) error { return peerPoolItemList[i].PeerPubkey > peerPoolItemList[j].PeerPubkey }) for _, v := range peerPoolItemList { - if err := v.Serialize(w); err != nil { - return fmt.Errorf("serialize peerPool error: %v", err) - } + v.Serialization(sink) } return nil } -func (this *PeerPoolMap) Deserialize(r io.Reader) error { - n, err := serialization.ReadUint32(r) +func (this *PeerPoolMap) Deserialization(source *common.ZeroCopySource) error { + + n, err := utils.DecodeUint32(source) if err != nil { return fmt.Errorf("serialization.ReadUint32, deserialize PeerPoolMap length error: %v", err) } peerPoolMap := make(map[string]*PeerPoolItem) for i := 0; uint32(i) < n; i++ { peerPoolItem := new(PeerPoolItem) - if err := peerPoolItem.Deserialize(r); err != nil { + if err := peerPoolItem.Deserialization(source); err != nil { return fmt.Errorf("deserialize peerPool error: %v", err) } peerPoolMap[peerPoolItem.PeerPubkey] = peerPoolItem @@ -137,52 +126,39 @@ type PeerPoolItem struct { TotalPos uint64 //total authorize pos this peer received } -func (this *PeerPoolItem) Serialize(w io.Writer) error { - if err := serialization.WriteUint32(w, this.Index); err != nil { - return fmt.Errorf("serialization.WriteUint32, serialize address error: %v", err) - } - if err := serialization.WriteString(w, this.PeerPubkey); err != nil { - return fmt.Errorf("serialization.WriteString, serialize peerPubkey error: %v", err) - } - if err := this.Address.Serialize(w); err != nil { - return fmt.Errorf("address.Serialize, serialize address error: %v", err) - } - if err := this.Status.Serialize(w); err != nil { - return fmt.Errorf("this.Status.Serialize, serialize Status error: %v", err) - } - if err := serialization.WriteUint64(w, this.InitPos); err != nil { - return fmt.Errorf("serialization.WriteUint64, serialize initPos error: %v", err) - } - if err := serialization.WriteUint64(w, this.TotalPos); err != nil { - return fmt.Errorf("serialization.WriteUint64, serialize totalPos error: %v", err) - } - return nil +func (this *PeerPoolItem) Serialization(sink *common.ZeroCopySink) { + sink.WriteUint32(this.Index) + sink.WriteString(this.PeerPubkey) + this.Address.Serialization(sink) + this.Status.Serialization(sink) + sink.WriteUint64(this.InitPos) + sink.WriteUint64(this.TotalPos) } -func (this *PeerPoolItem) Deserialize(r io.Reader) error { - index, err := serialization.ReadUint32(r) - if err != nil { - return fmt.Errorf("serialization.ReadUint32, deserialize index error: %v", err) +func (this *PeerPoolItem) Deserialization(source *common.ZeroCopySource) error { + index, eof := source.NextUint32() + if eof { + return fmt.Errorf("serialization.ReadUint32, deserialize index error: %v", io.ErrUnexpectedEOF) } - peerPubkey, err := serialization.ReadString(r) + peerPubkey, err := utils.DecodeString(source) if err != nil { return fmt.Errorf("serialization.ReadString, deserialize peerPubkey error: %v", err) } address := new(common.Address) - err = address.Deserialize(r) + err = address.Deserialization(source) if err != nil { return fmt.Errorf("address.Deserialize, deserialize address error: %v", err) } status := new(Status) - err = status.Deserialize(r) + err = status.Deserialization(source) if err != nil { return fmt.Errorf("status.Deserialize. deserialize status error: %v", err) } - initPos, err := serialization.ReadUint64(r) - if err != nil { - return fmt.Errorf("serialization.ReadUint64, deserialize initPos error: %v", err) + initPos, eof := source.NextUint64() + if eof { + return fmt.Errorf("serialization.ReadUint64, deserialize initPos error: %v", io.ErrUnexpectedEOF) } - totalPos, err := serialization.ReadUint64(r) + totalPos, err := utils.DecodeUint64(source) if err != nil { return fmt.Errorf("serialization.ReadUint64, deserialize totalPos error: %v", err) } @@ -206,65 +182,48 @@ type AuthorizeInfo struct { WithdrawUnfreezePos uint64 //unfrozen pos, can withdraw at any time } -func (this *AuthorizeInfo) Serialize(w io.Writer) error { - if err := serialization.WriteString(w, this.PeerPubkey); err != nil { - return fmt.Errorf("serialization.WriteString, request peerPubkey error: %v", err) - } - if err := this.Address.Serialize(w); err != nil { - return fmt.Errorf("address.Serialize, serialize address error: %v", err) - } - if err := serialization.WriteUint64(w, this.ConsensusPos); err != nil { - return fmt.Errorf("serialization.WriteUint64, serialize consensusPos error: %v", err) - } - if err := serialization.WriteUint64(w, this.CandidatePos); err != nil { - return fmt.Errorf("serialization.WriteUint64, serialize candidatePos error: %v", err) - } - if err := serialization.WriteUint64(w, this.NewPos); err != nil { - return fmt.Errorf("serialization.WriteUint64, serialize newPos error: %v", err) - } - if err := serialization.WriteUint64(w, this.WithdrawConsensusPos); err != nil { - return fmt.Errorf("serialization.WriteUint64, serialize withdrawConsensusPos error: %v", err) - } - if err := serialization.WriteUint64(w, this.WithdrawCandidatePos); err != nil { - return fmt.Errorf("serialization.WriteUint64, serialize withdrawCandidatePos error: %v", err) - } - if err := serialization.WriteUint64(w, this.WithdrawUnfreezePos); err != nil { - return fmt.Errorf("serialization.WriteUint64, serialize withDrawUnfreezePos error: %v", err) - } - return nil +func (this *AuthorizeInfo) Serialization(sink *common.ZeroCopySink) { + sink.WriteString(this.PeerPubkey) + this.Address.Serialization(sink) + sink.WriteUint64(this.ConsensusPos) + sink.WriteUint64(this.CandidatePos) + sink.WriteUint64(this.NewPos) + sink.WriteUint64(this.WithdrawConsensusPos) + sink.WriteUint64(this.WithdrawCandidatePos) + sink.WriteUint64(this.WithdrawUnfreezePos) } -func (this *AuthorizeInfo) Deserialize(r io.Reader) error { - peerPubkey, err := serialization.ReadString(r) +func (this *AuthorizeInfo) Deserialization(source *common.ZeroCopySource) error { + peerPubkey, err := utils.DecodeString(source) if err != nil { return fmt.Errorf("serialization.ReadString, deserialize peerPubkey error: %v", err) } address := new(common.Address) - err = address.Deserialize(r) + err = address.Deserialization(source) if err != nil { return fmt.Errorf("address.Deserialize, deserialize address error: %v", err) } - consensusPos, err := serialization.ReadUint64(r) + consensusPos, err := utils.DecodeUint64(source) if err != nil { return fmt.Errorf("serialization.ReadUint64. deserialize consensusPos error: %v", err) } - candidatePos, err := serialization.ReadUint64(r) + candidatePos, err := utils.DecodeUint64(source) if err != nil { return fmt.Errorf("serialization.ReadUint64. deserialize candidatePos error: %v", err) } - newPos, err := serialization.ReadUint64(r) + newPos, err := utils.DecodeUint64(source) if err != nil { return fmt.Errorf("serialization.ReadUint64. deserialize newPos error: %v", err) } - withDrawConsensusPos, err := serialization.ReadUint64(r) + withDrawConsensusPos, err := utils.DecodeUint64(source) if err != nil { return fmt.Errorf("serialization.ReadUint64. deserialize withDrawConsensusPos error: %v", err) } - withDrawCandidatePos, err := serialization.ReadUint64(r) + withDrawCandidatePos, err := utils.DecodeUint64(source) if err != nil { return fmt.Errorf("serialization.ReadUint64. deserialize withDrawCandidatePos error: %v", err) } - withDrawUnfreezePos, err := serialization.ReadUint64(r) + withDrawUnfreezePos, err := utils.DecodeUint64(source) if err != nil { return fmt.Errorf("serialization.ReadUint64. deserialize withDrawUnfreezePos error: %v", err) } @@ -329,30 +288,23 @@ type TotalStake struct { //table record each address's total stake in this contr TimeOffset uint32 } -func (this *TotalStake) Serialize(w io.Writer) error { - if err := this.Address.Serialize(w); err != nil { - return fmt.Errorf("address.Serialize, serialize address error: %v", err) - } - if err := serialization.WriteUint64(w, this.Stake); err != nil { - return fmt.Errorf("serialization.WriteUint64, serialize stake error: %v", err) - } - if err := serialization.WriteUint32(w, this.TimeOffset); err != nil { - return fmt.Errorf("serialization.WriteUint32, serialize timeOffset error: %v", err) - } - return nil +func (this *TotalStake) Serialization(sink *common.ZeroCopySink) { + this.Address.Serialization(sink) + sink.WriteUint64(this.Stake) + sink.WriteUint32(this.TimeOffset) } -func (this *TotalStake) Deserialize(r io.Reader) error { +func (this *TotalStake) Deserialization(source *common.ZeroCopySource) error { address := new(common.Address) - err := address.Deserialize(r) + err := address.Deserialization(source) if err != nil { return fmt.Errorf("address.Deserialize, deserialize address error: %v", err) } - stake, err := serialization.ReadUint64(r) + stake, err := utils.DecodeUint64(source) if err != nil { return fmt.Errorf("serialization.ReadUint64, deserialize stake error: %v", err) } - timeOffset, err := serialization.ReadUint32(r) + timeOffset, err := utils.DecodeUint32(source) if err != nil { return fmt.Errorf("serialization.ReadUint64, deserialize timeOffset error: %v", err) } @@ -370,43 +322,32 @@ type PenaltyStake struct { //table record penalty stake of peer Amount uint64 //unbound ong that this penalty unbounded } -func (this *PenaltyStake) Serialize(w io.Writer) error { - if err := serialization.WriteString(w, this.PeerPubkey); err != nil { - return fmt.Errorf("serialization.WriteString, serialize peerPubkey error: %v", err) - } - if err := serialization.WriteUint64(w, this.InitPos); err != nil { - return fmt.Errorf("serialization.WriteUint64, serialize initPos error: %v", err) - } - if err := serialization.WriteUint64(w, this.AuthorizePos); err != nil { - return fmt.Errorf("serialization.WriteUint64, serialize authorizePos error: %v", err) - } - if err := serialization.WriteUint32(w, this.TimeOffset); err != nil { - return fmt.Errorf("serialization.WriteUint32, serialize timeOffset error: %v", err) - } - if err := serialization.WriteUint64(w, this.Amount); err != nil { - return fmt.Errorf("serialization.WriteUint64, serialize amount error: %v", err) - } - return nil +func (this *PenaltyStake) Serialization(sink *common.ZeroCopySink) { + sink.WriteString(this.PeerPubkey) + sink.WriteUint64(this.InitPos) + sink.WriteUint64(this.AuthorizePos) + sink.WriteUint32(this.TimeOffset) + sink.WriteUint64(this.Amount) } -func (this *PenaltyStake) Deserialize(r io.Reader) error { - peerPubkey, err := serialization.ReadString(r) +func (this *PenaltyStake) Deserialization(source *common.ZeroCopySource) error { + peerPubkey, err := utils.DecodeString(source) if err != nil { return fmt.Errorf("serialization.ReadString, deserialize peerPubkey error: %v", err) } - initPos, err := serialization.ReadUint64(r) + initPos, err := utils.DecodeUint64(source) if err != nil { return fmt.Errorf("serialization.ReadUint64. deserialize initPos error: %v", err) } - authorizePos, err := serialization.ReadUint64(r) + authorizePos, err := utils.DecodeUint64(source) if err != nil { return fmt.Errorf("serialization.ReadUint64. deserialize authorizePos error: %v", err) } - timeOffset, err := serialization.ReadUint32(r) + timeOffset, err := utils.DecodeUint32(source) if err != nil { return fmt.Errorf("serialization.ReadUint64, deserialize timeOffset error: %v", err) } - amount, err := serialization.ReadUint64(r) + amount, err := utils.DecodeUint64(source) if err != nil { return fmt.Errorf("serialization.ReadUint64. deserialize amount error: %v", err) } @@ -438,71 +379,52 @@ type PeerAttributes struct { Field4 []byte //reserved field } -func (this *PeerAttributes) Serialize(w io.Writer) error { - if err := serialization.WriteString(w, this.PeerPubkey); err != nil { - return fmt.Errorf("serialization.WriteBool, serialize peerPubkey error: %v", err) - } - if err := serialization.WriteUint64(w, this.MaxAuthorize); err != nil { - return fmt.Errorf("serialization.WriteUint64, serialize maxAuthorize error: %v", err) - } - if err := serialization.WriteUint64(w, this.T2PeerCost); err != nil { - return fmt.Errorf("serialization.WriteUint64, serialize oldPeerCost error: %v", err) - } - if err := serialization.WriteUint64(w, this.T1PeerCost); err != nil { - return fmt.Errorf("serialization.WriteUint64, serialize newPeerCost error: %v", err) - } - if err := serialization.WriteUint64(w, this.TPeerCost); err != nil { - return fmt.Errorf("serialization.WriteUint64, serialize newPeerCost error: %v", err) - } - if err := serialization.WriteVarBytes(w, this.Field1); err != nil { - return fmt.Errorf("serialization.WriteVarBytes, serialize field1 error: %v", err) - } - if err := serialization.WriteVarBytes(w, this.Field2); err != nil { - return fmt.Errorf("serialization.WriteVarBytes, serialize field2 error: %v", err) - } - if err := serialization.WriteVarBytes(w, this.Field3); err != nil { - return fmt.Errorf("serialization.WriteVarBytes, serialize field3 error: %v", err) - } - if err := serialization.WriteVarBytes(w, this.Field4); err != nil { - return fmt.Errorf("serialization.WriteVarBytes, serialize field4 error: %v", err) - } - return nil +func (this *PeerAttributes) Serialization(sink *common.ZeroCopySink) { + sink.WriteString(this.PeerPubkey) + sink.WriteUint64(this.MaxAuthorize) + sink.WriteUint64(this.T2PeerCost) + sink.WriteUint64(this.T1PeerCost) + sink.WriteUint64(this.TPeerCost) + sink.WriteVarBytes(this.Field1) + sink.WriteVarBytes(this.Field2) + sink.WriteVarBytes(this.Field3) + sink.WriteVarBytes(this.Field4) } -func (this *PeerAttributes) Deserialize(r io.Reader) error { - peerPubkey, err := serialization.ReadString(r) +func (this *PeerAttributes) Deserialization(source *common.ZeroCopySource) error { + peerPubkey, err := utils.DecodeString(source) if err != nil { return fmt.Errorf("serialization.ReadString, deserialize peerPubkey error: %v", err) } - maxAuthorize, err := serialization.ReadUint64(r) - if err != nil { - return fmt.Errorf("serialization.ReadBool, deserialize maxAuthorize error: %v", err) + maxAuthorize, eof := source.NextUint64() + if eof { + return fmt.Errorf("serialization.ReadBool, deserialize maxAuthorize error: %v", io.ErrUnexpectedEOF) } - t2PeerCost, err := serialization.ReadUint64(r) - if err != nil { - return fmt.Errorf("serialization.ReadUint64, deserialize t2PeerCost error: %v", err) + t2PeerCost, eof := source.NextUint64() + if eof { + return fmt.Errorf("serialization.ReadUint64, deserialize t2PeerCost error: %v", io.ErrUnexpectedEOF) } - t1PeerCost, err := serialization.ReadUint64(r) - if err != nil { - return fmt.Errorf("serialization.ReadUint64, deserialize t1PeerCost error: %v", err) + t1PeerCost, eof := source.NextUint64() + if eof { + return fmt.Errorf("serialization.ReadUint64, deserialize t1PeerCost error: %v", io.ErrUnexpectedEOF) } - tPeerCost, err := serialization.ReadUint64(r) - if err != nil { - return fmt.Errorf("serialization.ReadUint64, deserialize tPeerCost error: %v", err) + tPeerCost, eof := source.NextUint64() + if eof { + return fmt.Errorf("serialization.ReadUint64, deserialize tPeerCost error: %v", io.ErrUnexpectedEOF) } - field1, err := serialization.ReadVarBytes(r) + field1, err := utils.DecodeVarBytes(source) if err != nil { return fmt.Errorf("serialization.ReadVarBytes. deserialize field1 error: %v", err) } - field2, err := serialization.ReadVarBytes(r) + field2, err := utils.DecodeVarBytes(source) if err != nil { return fmt.Errorf("serialization.ReadVarBytes. deserialize field2 error: %v", err) } - field3, err := serialization.ReadVarBytes(r) + field3, err := utils.DecodeVarBytes(source) if err != nil { return fmt.Errorf("serialization.ReadVarBytes, deserialize field3 error: %v", err) } - field4, err := serialization.ReadVarBytes(r) + field4, err := utils.DecodeVarBytes(source) if err != nil { return fmt.Errorf("serialization.ReadVarBytes. deserialize field4 error: %v", err) } @@ -523,23 +445,18 @@ type SplitFeeAddress struct { //table record each address's ong motivation Amount uint64 } -func (this *SplitFeeAddress) Serialize(w io.Writer) error { - if err := this.Address.Serialize(w); err != nil { - return fmt.Errorf("address.Serialize, serialize address error: %v", err) - } - if err := serialization.WriteUint64(w, this.Amount); err != nil { - return fmt.Errorf("serialization.WriteUint64, serialize amount error: %v", err) - } - return nil +func (this *SplitFeeAddress) Serialization(sink *common.ZeroCopySink) { + this.Address.Serialization(sink) + sink.WriteUint64(this.Amount) } -func (this *SplitFeeAddress) Deserialize(r io.Reader) error { +func (this *SplitFeeAddress) Deserialization(source *common.ZeroCopySource) error { address := new(common.Address) - err := address.Deserialize(r) + err := address.Deserialization(source) if err != nil { return fmt.Errorf("address.Deserialize, deserialize address error: %v", err) } - amount, err := serialization.ReadUint64(r) + amount, err := utils.DecodeUint64(source) if err != nil { return fmt.Errorf("serialization.ReadUint64, deserialize amount error: %v", err) } diff --git a/smartcontract/service/native/governance/utils.go b/smartcontract/service/native/governance/utils.go index 88d12a6691..eb29d8c6ef 100644 --- a/smartcontract/service/native/governance/utils.go +++ b/smartcontract/service/native/governance/utils.go @@ -51,27 +51,27 @@ func GetPeerPoolMap(native *native.NativeService, contract common.Address, view return nil, fmt.Errorf("getPeerPoolMap, peerPoolMap is nil") } item := cstates.StorageItem{} - err = item.Deserialize(bytes.NewBuffer(peerPoolMapBytes)) + err = item.Deserialization(common.NewZeroCopySource(peerPoolMapBytes)) if err != nil { return nil, fmt.Errorf("deserialize PeerPoolMap error:%v", err) } peerPoolMapStore := item.Value - if err := peerPoolMap.Deserialize(bytes.NewBuffer(peerPoolMapStore)); err != nil { + if err := peerPoolMap.Deserialization(common.NewZeroCopySource(peerPoolMapStore)); err != nil { return nil, fmt.Errorf("deserialize, deserialize peerPoolMap error: %v", err) } return peerPoolMap, nil } func putPeerPoolMap(native *native.NativeService, contract common.Address, view uint32, peerPoolMap *PeerPoolMap) error { - bf := new(bytes.Buffer) - if err := peerPoolMap.Serialize(bf); err != nil { + sink := common.NewZeroCopySink(nil) + if err := peerPoolMap.Serialization(sink); err != nil { return fmt.Errorf("serialize, serialize peerPoolMap error: %v", err) } viewBytes, err := GetUint32Bytes(view) if err != nil { return fmt.Errorf("getUint32Bytes, get viewBytes error: %v", err) } - native.CacheDB.Put(utils.ConcatKey(contract, []byte(PEER_POOL), viewBytes), cstates.GenRawStorageItem(bf.Bytes())) + native.CacheDB.Put(utils.ConcatKey(contract, []byte(PEER_POOL), viewBytes), cstates.GenRawStorageItem(sink.Bytes())) return nil } @@ -138,10 +138,8 @@ func appCallTransfer(native *native.NativeService, contract common.Address, from transfers := ont.Transfers{ States: sts, } - sink := common.NewZeroCopySink(nil) - transfers.Serialization(sink) - if _, err := native.NativeCall(contract, "transfer", sink.Bytes()); err != nil { + if _, err := native.NativeCall(contract, "transfer", common.SerializeToBytes(&transfers)); err != nil { return fmt.Errorf("appCallTransfer, appCall error: %v", err) } return nil @@ -170,21 +168,14 @@ func appCallTransferFrom(native *native.NativeService, contract common.Address, To: to, Value: amount, } - sink := common.NewZeroCopySink(nil) - params.Serialization(sink) - if _, err := native.NativeCall(contract, "transferFrom", sink.Bytes()); err != nil { + if _, err := native.NativeCall(contract, "transferFrom", common.SerializeToBytes(params)); err != nil { return fmt.Errorf("appCallTransferFrom, appCall error: %v", err) } return nil } func getOngBalance(native *native.NativeService, address common.Address) (uint64, error) { - bf := new(bytes.Buffer) - err := utils.WriteAddress(bf, address) - if err != nil { - return 0, fmt.Errorf("getOngBalance, utils.WriteAddress error: %v", err) - } sink := common.ZeroCopySink{} utils.EncodeAddress(&sink, address) @@ -231,12 +222,10 @@ func GetBytesUint32(b []byte) (uint32, error) { return num, nil } -func GetUint64Bytes(num uint64) ([]byte, error) { - bf := new(bytes.Buffer) - if err := serialization.WriteUint64(bf, num); err != nil { - return nil, fmt.Errorf("serialization.WriteUint64, serialize uint64 error: %v", err) - } - return bf.Bytes(), nil +func GetUint64Bytes(num uint64) []byte { + sink := common.NewZeroCopySink(nil) + sink.WriteUint64(num) + return sink.Bytes() } func GetBytesUint64(b []byte) (uint64, error) { @@ -260,7 +249,7 @@ func getGlobalParam(native *native.NativeService, contract common.Address) (*Glo if err != nil { return nil, fmt.Errorf("getGlobalParam, deserialize from raw storage item err:%v", err) } - if err := globalParam.Deserialize(bytes.NewBuffer(value)); err != nil { + if err := globalParam.Deserialization(common.NewZeroCopySource(value)); err != nil { return nil, fmt.Errorf("deserialize, deserialize globalParam error: %v", err) } } @@ -268,11 +257,7 @@ func getGlobalParam(native *native.NativeService, contract common.Address) (*Glo } func putGlobalParam(native *native.NativeService, contract common.Address, globalParam *GlobalParam) error { - bf := new(bytes.Buffer) - if err := globalParam.Serialize(bf); err != nil { - return fmt.Errorf("serialize, serialize globalParam error: %v", err) - } - native.CacheDB.Put(utils.ConcatKey(contract, []byte(GLOBAL_PARAM)), cstates.GenRawStorageItem(bf.Bytes())) + native.CacheDB.Put(utils.ConcatKey(contract, []byte(GLOBAL_PARAM)), cstates.GenRawStorageItem(common.SerializeToBytes(globalParam))) return nil } @@ -297,7 +282,7 @@ func getGlobalParam2(native *native.NativeService, contract common.Address) (*Gl if err != nil { return nil, fmt.Errorf("getGlobalParam2, globalParam2Bytes is not available") } - if err := globalParam2.Deserialize(bytes.NewBuffer(value)); err != nil { + if err := globalParam2.Deserialization(common.NewZeroCopySource(value)); err != nil { return nil, fmt.Errorf("deserialize, deserialize getGlobalParam2 error: %v", err) } } @@ -305,11 +290,11 @@ func getGlobalParam2(native *native.NativeService, contract common.Address) (*Gl } func putGlobalParam2(native *native.NativeService, contract common.Address, globalParam2 *GlobalParam2) error { - bf := new(bytes.Buffer) - if err := globalParam2.Serialize(bf); err != nil { + sink := common.NewZeroCopySink(nil) + if err := globalParam2.Serialization(sink); err != nil { return fmt.Errorf("serialize, serialize globalParam2 error: %v", err) } - native.CacheDB.Put(utils.ConcatKey(contract, []byte(GLOBAL_PARAM2)), cstates.GenRawStorageItem(bf.Bytes())) + native.CacheDB.Put(utils.ConcatKey(contract, []byte(GLOBAL_PARAM2)), cstates.GenRawStorageItem(sink.Bytes())) return nil } @@ -402,18 +387,14 @@ func getConfig(native *native.NativeService, contract common.Address) (*Configur if err != nil { return nil, fmt.Errorf("getConfig, deserialize from raw storage item err:%v", err) } - if err := config.Deserialize(bytes.NewBuffer(value)); err != nil { + if err := config.Deserialization(common.NewZeroCopySource(value)); err != nil { return nil, fmt.Errorf("deserialize, deserialize config error: %v", err) } return config, nil } func putConfig(native *native.NativeService, contract common.Address, config *Configuration) error { - bf := new(bytes.Buffer) - if err := config.Serialize(bf); err != nil { - return fmt.Errorf("serialize, serialize config error: %v", err) - } - native.CacheDB.Put(utils.ConcatKey(contract, []byte(VBFT_CONFIG)), cstates.GenRawStorageItem(bf.Bytes())) + native.CacheDB.Put(utils.ConcatKey(contract, []byte(VBFT_CONFIG)), cstates.GenRawStorageItem(common.SerializeToBytes(config))) return nil } @@ -428,7 +409,7 @@ func getPreConfig(native *native.NativeService, contract common.Address) (*PreCo if err != nil { return nil, fmt.Errorf("getConfig, preConfigBytes is not available") } - if err := preConfig.Deserialize(bytes.NewBuffer(preConfigStore)); err != nil { + if err := preConfig.Deserialization(common.NewZeroCopySource(preConfigStore)); err != nil { return nil, fmt.Errorf("deserialize, deserialize preConfig error: %v", err) } } @@ -436,11 +417,7 @@ func getPreConfig(native *native.NativeService, contract common.Address) (*PreCo } func putPreConfig(native *native.NativeService, contract common.Address, preConfig *PreConfig) error { - bf := new(bytes.Buffer) - if err := preConfig.Serialize(bf); err != nil { - return fmt.Errorf("serialize, serialize preConfig error: %v", err) - } - native.CacheDB.Put(utils.ConcatKey(contract, []byte(PRE_CONFIG)), cstates.GenRawStorageItem(bf.Bytes())) + native.CacheDB.Put(utils.ConcatKey(contract, []byte(PRE_CONFIG)), cstates.GenRawStorageItem(common.SerializeToBytes(preConfig))) return nil } @@ -493,10 +470,7 @@ func getSplitFee(native *native.NativeService, contract common.Address) (uint64, } func putSplitFee(native *native.NativeService, contract common.Address, splitFee uint64) error { - splitFeeBytes, err := GetUint64Bytes(splitFee) - if err != nil { - return fmt.Errorf("GetUint64Bytes, get splitFeeBytes error: %v", err) - } + splitFeeBytes := GetUint64Bytes(splitFee) native.CacheDB.Put(utils.ConcatKey(contract, []byte(SPLIT_FEE)), cstates.GenRawStorageItem(splitFeeBytes)) return nil } @@ -514,7 +488,7 @@ func getSplitFeeAddress(native *native.NativeService, contract common.Address, a if err != nil { return nil, fmt.Errorf("getSplitFeeAddress, splitFeeAddressBytes is not available") } - err = splitFeeAddress.Deserialize(bytes.NewBuffer(splitFeeAddressStore)) + err = splitFeeAddress.Deserialization(common.NewZeroCopySource(splitFeeAddressStore)) if err != nil { return nil, fmt.Errorf("deserialize, deserialize splitFeeAddress error: %v", err) } @@ -523,12 +497,8 @@ func getSplitFeeAddress(native *native.NativeService, contract common.Address, a } func putSplitFeeAddress(native *native.NativeService, contract common.Address, address common.Address, splitFeeAddress *SplitFeeAddress) error { - bf := new(bytes.Buffer) - if err := splitFeeAddress.Serialize(bf); err != nil { - return fmt.Errorf("serialize, serialize splitFeeAddress error: %v", err) - } native.CacheDB.Put(utils.ConcatKey(contract, []byte(SPLIT_FEE_ADDRESS), address[:]), - cstates.GenRawStorageItem(bf.Bytes())) + cstates.GenRawStorageItem(common.SerializeToBytes(splitFeeAddress))) return nil } @@ -551,7 +521,7 @@ func getAuthorizeInfo(native *native.NativeService, contract common.Address, pee if err != nil { return nil, fmt.Errorf("getAuthorizeInfo, deserialize from raw storage item err:%v", err) } - if err := authorizeInfo.Deserialize(bytes.NewBuffer(authorizeInfoStore)); err != nil { + if err := authorizeInfo.Deserialization(common.NewZeroCopySource(authorizeInfoStore)); err != nil { return nil, fmt.Errorf("deserialize, deserialize authorizeInfo error: %v", err) } } @@ -563,12 +533,8 @@ func putAuthorizeInfo(native *native.NativeService, contract common.Address, aut if err != nil { return fmt.Errorf("hex.DecodeString, peerPubkey format error: %v", err) } - bf := new(bytes.Buffer) - if err := authorizeInfo.Serialize(bf); err != nil { - return fmt.Errorf("serialize, serialize authorizeInfo error: %v", err) - } native.CacheDB.Put(utils.ConcatKey(contract, AUTHORIZE_INFO_POOL, peerPubkeyPrefix, - authorizeInfo.Address[:]), cstates.GenRawStorageItem(bf.Bytes())) + authorizeInfo.Address[:]), cstates.GenRawStorageItem(common.SerializeToBytes(authorizeInfo))) return nil } @@ -590,7 +556,7 @@ func getPenaltyStake(native *native.NativeService, contract common.Address, peer if err != nil { return nil, fmt.Errorf("getPenaltyStake, deserialize from raw storage item err:%v", err) } - if err := penaltyStake.Deserialize(bytes.NewBuffer(penaltyStakeStore)); err != nil { + if err := penaltyStake.Deserialization(common.NewZeroCopySource(penaltyStakeStore)); err != nil { return nil, fmt.Errorf("deserialize, deserialize authorizeInfo error: %v", err) } } @@ -602,12 +568,8 @@ func putPenaltyStake(native *native.NativeService, contract common.Address, pena if err != nil { return fmt.Errorf("hex.DecodeString, peerPubkey format error: %v", err) } - bf := new(bytes.Buffer) - if err := penaltyStake.Serialize(bf); err != nil { - return fmt.Errorf("serialize, serialize authorizeInfo error: %v", err) - } native.CacheDB.Put(utils.ConcatKey(contract, []byte(PENALTY_STAKE), peerPubkeyPrefix), - cstates.GenRawStorageItem(bf.Bytes())) + cstates.GenRawStorageItem(common.SerializeToBytes(penaltyStake))) return nil } @@ -625,7 +587,7 @@ func getTotalStake(native *native.NativeService, contract common.Address, addres if err != nil { return nil, fmt.Errorf("getTotalStake, deserialize from raw storage item err:%v", err) } - if err := totalStake.Deserialize(bytes.NewBuffer(totalStakeStore)); err != nil { + if err := totalStake.Deserialization(common.NewZeroCopySource(totalStakeStore)); err != nil { return nil, fmt.Errorf("deserialize, deserialize authorizeInfo error: %v", err) } } @@ -633,12 +595,8 @@ func getTotalStake(native *native.NativeService, contract common.Address, addres } func putTotalStake(native *native.NativeService, contract common.Address, totalStake *TotalStake) error { - bf := new(bytes.Buffer) - if err := totalStake.Serialize(bf); err != nil { - return fmt.Errorf("serialize, serialize authorizeInfo error: %v", err) - } native.CacheDB.Put(utils.ConcatKey(contract, []byte(TOTAL_STAKE), totalStake.Address[:]), - cstates.GenRawStorageItem(bf.Bytes())) + cstates.GenRawStorageItem(common.SerializeToBytes(totalStake))) return nil } @@ -655,7 +613,7 @@ func getSplitCurve(native *native.NativeService, contract common.Address) (*Spli if err != nil { return nil, fmt.Errorf("getSplitCurve, deserialize from raw storage item err:%v", err) } - if err := splitCurve.Deserialize(bytes.NewBuffer(splitCurveStore)); err != nil { + if err := splitCurve.Deserialization(common.NewZeroCopySource(splitCurveStore)); err != nil { return nil, fmt.Errorf("deserialize, deserialize splitCurve error: %v", err) } } @@ -663,8 +621,8 @@ func getSplitCurve(native *native.NativeService, contract common.Address) (*Spli } func putSplitCurve(native *native.NativeService, contract common.Address, splitCurve *SplitCurve) error { - bf := new(bytes.Buffer) - if err := splitCurve.Serialize(bf); err != nil { + bf := common.NewZeroCopySink(nil) + if err := splitCurve.Serialization(bf); err != nil { return fmt.Errorf("serialize, serialize splitCurve error: %v", err) } native.CacheDB.Put(utils.ConcatKey(contract, []byte(SPLIT_CURVE)), cstates.GenRawStorageItem(bf.Bytes())) @@ -672,35 +630,23 @@ func putSplitCurve(native *native.NativeService, contract common.Address, splitC } func appCallInitContractAdmin(native *native.NativeService, adminOntID []byte) error { - bf := new(bytes.Buffer) params := &auth.InitContractAdminParam{ AdminOntID: adminOntID, } - err := params.Serialize(bf) - if err != nil { - return fmt.Errorf("appCallInitContractAdmin, param serialize error: %v", err) - } - - if _, err := native.NativeCall(utils.AuthContractAddress, "initContractAdmin", bf.Bytes()); err != nil { + if _, err := native.NativeCall(utils.AuthContractAddress, "initContractAdmin", common.SerializeToBytes(params)); err != nil { return fmt.Errorf("appCallInitContractAdmin, appCall error: %v", err) } return nil } func appCallVerifyToken(native *native.NativeService, contract common.Address, caller []byte, fn string, keyNo uint64) error { - bf := new(bytes.Buffer) params := &auth.VerifyTokenParam{ ContractAddr: contract, Caller: caller, Fn: fn, KeyNo: keyNo, } - err := params.Serialize(bf) - if err != nil { - return fmt.Errorf("appCallVerifyToken, param serialize error: %v", err) - } - - ok, err := native.NativeCall(utils.AuthContractAddress, "verifyToken", bf.Bytes()) + ok, err := native.NativeCall(utils.AuthContractAddress, "verifyToken", common.SerializeToBytes(params)) if err != nil { return fmt.Errorf("appCallVerifyToken, appCall error: %v", err) } @@ -731,7 +677,7 @@ func getPeerAttributes(native *native.NativeService, contract common.Address, pe if err != nil { return nil, fmt.Errorf("getPeerAttributes, peerAttributesStore is not available") } - if err := peerAttributes.Deserialize(bytes.NewBuffer(peerAttributesStore)); err != nil { + if err := peerAttributes.Deserialization(common.NewZeroCopySource(peerAttributesStore)); err != nil { return nil, fmt.Errorf("deserialize, deserialize peerAttributes error: %v", err) } } @@ -753,11 +699,7 @@ func putPeerAttributes(native *native.NativeService, contract common.Address, pe if err != nil { return fmt.Errorf("hex.DecodeString, peerPubkey format error: %v", err) } - bf := new(bytes.Buffer) - if err := peerAttributes.Serialize(bf); err != nil { - return fmt.Errorf("serialize, serialize peerAttributes error: %v", err) - } - native.CacheDB.Put(utils.ConcatKey(contract, []byte(PEER_ATTRIBUTES), peerPubkeyPrefix), cstates.GenRawStorageItem(bf.Bytes())) + native.CacheDB.Put(utils.ConcatKey(contract, []byte(PEER_ATTRIBUTES), peerPubkeyPrefix), cstates.GenRawStorageItem(common.SerializeToBytes(peerAttributes))) return nil } @@ -775,7 +717,7 @@ func getPromisePos(native *native.NativeService, contract common.Address, peerPu return nil, fmt.Errorf("get value from promisePosBytes err:%v", err) } promisePos := new(PromisePos) - if err := promisePos.Deserialize(bytes.NewBuffer(promisePosStore)); err != nil { + if err := promisePos.Deserialization(common.NewZeroCopySource(promisePosStore)); err != nil { return nil, fmt.Errorf("deserialize, deserialize promisePos error: %v", err) } return promisePos, nil @@ -786,12 +728,8 @@ func putPromisePos(native *native.NativeService, contract common.Address, promis if err != nil { return fmt.Errorf("hex.DecodeString, peerPubkey format error: %v", err) } - bf := new(bytes.Buffer) - if err := promisePos.Serialize(bf); err != nil { - return fmt.Errorf("serialize, serialize promisePos error: %v", err) - } native.CacheDB.Put(utils.ConcatKey(contract, []byte(PROMISE_POS), peerPubkeyPrefix), - cstates.GenRawStorageItem(bf.Bytes())) + cstates.GenRawStorageItem(common.SerializeToBytes(promisePos))) return nil } @@ -814,9 +752,7 @@ func getGasAddress(native *native.NativeService, contract common.Address) (*GasA } func putGasAddress(native *native.NativeService, contract common.Address, gasAddress *GasAddress) error { - sink := common.NewZeroCopySink(nil) - gasAddress.Serialization(sink) native.CacheDB.Put(utils.ConcatKey(contract, []byte(GAS_ADDRESS)), - cstates.GenRawStorageItem(sink.Bytes())) + cstates.GenRawStorageItem(common.SerializeToBytes(gasAddress))) return nil } diff --git a/smartcontract/service/native/ont/ont.go b/smartcontract/service/native/ont/ont.go index 8b825d27b5..08d0315240 100644 --- a/smartcontract/service/native/ont/ont.go +++ b/smartcontract/service/native/ont/ont.go @@ -19,7 +19,6 @@ package ont import ( - "bytes" "fmt" "math/big" @@ -274,7 +273,7 @@ func grantOng(native *native.NativeService, contract, address common.Address, ba } func getApproveArgs(native *native.NativeService, contract, ongContract, address common.Address, value uint64) ([]byte, error) { - bf := new(bytes.Buffer) + bf := common.NewZeroCopySink(nil) approve := State{ From: contract, To: address, @@ -287,9 +286,6 @@ func getApproveArgs(native *native.NativeService, contract, ongContract, address } approve.Value += stateValue - - if err := approve.Serialize(bf); err != nil { - return nil, err - } + approve.Serialization(bf) return bf.Bytes(), nil } diff --git a/smartcontract/service/native/ont/states.go b/smartcontract/service/native/ont/states.go index 1ffd57c903..23fea54b88 100644 --- a/smartcontract/service/native/ont/states.go +++ b/smartcontract/service/native/ont/states.go @@ -19,11 +19,7 @@ package ont import ( - "fmt" - "io" - "github.com/ontio/ontology/common" - "github.com/ontio/ontology/errors" "github.com/ontio/ontology/smartcontract/service/native/utils" ) @@ -32,18 +28,6 @@ type Transfers struct { States []State } -func (this *Transfers) Serialize(w io.Writer) error { - if err := utils.WriteVarUint(w, uint64(len(this.States))); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "[TokenTransfer] Serialize States length error!") - } - for _, v := range this.States { - if err := v.Serialize(w); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "[TokenTransfer] Serialize States error!") - } - } - return nil -} - func (this *Transfers) Serialization(sink *common.ZeroCopySink) { utils.EncodeVarUint(sink, uint64(len(this.States))) for _, v := range this.States { @@ -51,21 +35,6 @@ func (this *Transfers) Serialization(sink *common.ZeroCopySink) { } } -func (this *Transfers) Deserialize(r io.Reader) error { - n, err := utils.ReadVarUint(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "[TokenTransfer] Deserialize states length error!") - } - for i := 0; uint64(i) < n; i++ { - var state State - if err := state.Deserialize(r); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "[TokenTransfer] Deserialize states error!") - } - this.States = append(this.States, state) - } - return nil -} - func (this *Transfers) Deserialization(source *common.ZeroCopySource) error { n, err := utils.DecodeVarUint(source) if err != nil { @@ -87,43 +56,12 @@ type State struct { Value uint64 } -func (this *State) Serialize(w io.Writer) error { - if err := utils.WriteAddress(w, this.From); err != nil { - return fmt.Errorf("[State] serialize from error:%v", err) - } - if err := utils.WriteAddress(w, this.To); err != nil { - return fmt.Errorf("[State] serialize to error:%v", err) - } - if err := utils.WriteVarUint(w, this.Value); err != nil { - return fmt.Errorf("[State] serialize value error:%v", err) - } - return nil -} - func (this *State) Serialization(sink *common.ZeroCopySink) { utils.EncodeAddress(sink, this.From) utils.EncodeAddress(sink, this.To) utils.EncodeVarUint(sink, this.Value) } -func (this *State) Deserialize(r io.Reader) error { - var err error - this.From, err = utils.ReadAddress(r) - if err != nil { - return fmt.Errorf("[State] deserialize from error:%v", err) - } - this.To, err = utils.ReadAddress(r) - if err != nil { - return fmt.Errorf("[State] deserialize to error:%v", err) - } - - this.Value, err = utils.ReadVarUint(r) - if err != nil { - return err - } - return nil -} - func (this *State) Deserialization(source *common.ZeroCopySource) error { var err error this.From, err = utils.DecodeAddress(source) @@ -148,22 +86,6 @@ type TransferFrom struct { Value uint64 } -func (this *TransferFrom) Serialize(w io.Writer) error { - if err := utils.WriteAddress(w, this.Sender); err != nil { - return fmt.Errorf("[TransferFrom] serialize sender error:%v", err) - } - if err := utils.WriteAddress(w, this.From); err != nil { - return fmt.Errorf("[TransferFrom] serialize from error:%v", err) - } - if err := utils.WriteAddress(w, this.To); err != nil { - return fmt.Errorf("[TransferFrom] serialize to error:%v", err) - } - if err := utils.WriteVarUint(w, this.Value); err != nil { - return fmt.Errorf("[TransferFrom] serialize value error:%v", err) - } - return nil -} - func (this *TransferFrom) Serialization(sink *common.ZeroCopySink) { utils.EncodeAddress(sink, this.Sender) utils.EncodeAddress(sink, this.From) @@ -171,30 +93,6 @@ func (this *TransferFrom) Serialization(sink *common.ZeroCopySink) { utils.EncodeVarUint(sink, this.Value) } -func (this *TransferFrom) Deserialize(r io.Reader) error { - var err error - this.Sender, err = utils.ReadAddress(r) - if err != nil { - return fmt.Errorf("[TransferFrom] deserialize sender error:%v", err) - } - - this.From, err = utils.ReadAddress(r) - if err != nil { - return fmt.Errorf("[TransferFrom] deserialize from error:%v", err) - } - - this.To, err = utils.ReadAddress(r) - if err != nil { - return fmt.Errorf("[TransferFrom] deserialize to error:%v", err) - } - - this.Value, err = utils.ReadVarUint(r) - if err != nil { - return err - } - return nil -} - func (this *TransferFrom) Deserialization(source *common.ZeroCopySource) error { var err error this.Sender, err = utils.DecodeAddress(source) diff --git a/smartcontract/service/native/ont/states_test.go b/smartcontract/service/native/ont/states_test.go index 673dccc139..2dede3317f 100644 --- a/smartcontract/service/native/ont/states_test.go +++ b/smartcontract/service/native/ont/states_test.go @@ -19,7 +19,6 @@ package ont import ( - "bytes" "testing" "github.com/ontio/ontology/common" @@ -32,13 +31,12 @@ func TestState_Serialize(t *testing.T) { To: common.AddressFromVmCode([]byte{4, 5, 6}), Value: 1, } - bf := new(bytes.Buffer) - if err := state.Serialize(bf); err != nil { - t.Fatal("state serialize fail!") - } + sink := common.NewZeroCopySink(nil) + state.Serialization(sink) state2 := State{} - if err := state2.Deserialize(bf); err != nil { + source := common.NewZeroCopySource(sink.Bytes()) + if err := state2.Deserialization(source); err != nil { t.Fatal("state deserialize fail!") } diff --git a/smartcontract/service/native/ont/utils.go b/smartcontract/service/native/ont/utils.go index d877d78a88..2c0a198cf7 100644 --- a/smartcontract/service/native/ont/utils.go +++ b/smartcontract/service/native/ont/utils.go @@ -19,12 +19,10 @@ package ont import ( - "bytes" "fmt" "github.com/ontio/ontology/common" "github.com/ontio/ontology/common/config" - "github.com/ontio/ontology/common/serialization" cstates "github.com/ontio/ontology/core/states" "github.com/ontio/ontology/errors" "github.com/ontio/ontology/smartcontract/event" @@ -59,9 +57,9 @@ func AddNotifications(native *native.NativeService, contract common.Address, sta } func GetToUInt64StorageItem(toBalance, value uint64) *cstates.StorageItem { - bf := new(bytes.Buffer) - serialization.WriteUint64(bf, toBalance+value) - return &cstates.StorageItem{Value: bf.Bytes()} + sink := common.NewZeroCopySink(nil) + sink.WriteUint64(toBalance + value) + return &cstates.StorageItem{Value: sink.Bytes()} } func GenTotalSupplyKey(contract common.Address) []byte { diff --git a/smartcontract/service/native/ontid/attribute.go b/smartcontract/service/native/ontid/attribute.go index 3f0babb72b..3158f6b08e 100644 --- a/smartcontract/service/native/ontid/attribute.go +++ b/smartcontract/service/native/ontid/attribute.go @@ -18,14 +18,20 @@ package ontid import ( - "bytes" "errors" "fmt" - "io" - - "github.com/ontio/ontology/common/serialization" + "github.com/ontio/ontology/common" "github.com/ontio/ontology/smartcontract/service/native" "github.com/ontio/ontology/smartcontract/service/native/utils" + "io" +) + +const ( + MAX_KEY_SIZE = 32 + MAX_TYPE_SIZE = 16 + MAX_VALUE_SIZE = 512 * 1024 + + MAX_NUM = 100 ) type attribute struct { @@ -34,63 +40,76 @@ type attribute struct { valueType []byte } -func (this *attribute) Value() ([]byte, error) { - var buf bytes.Buffer - err := serialization.WriteVarBytes(&buf, this.value) - if err != nil { - return nil, err - } - err = serialization.WriteVarBytes(&buf, this.valueType) - if err != nil { - return nil, err - } - return buf.Bytes(), nil +func (this *attribute) Value() []byte { + sink := common.NewZeroCopySink(nil) + sink.WriteVarBytes(this.value) + sink.WriteVarBytes(this.valueType) + return sink.Bytes() } func (this *attribute) SetValue(data []byte) error { - buf := bytes.NewBuffer(data) - val, err := serialization.ReadVarBytes(buf) - if err != nil { - return err + source := common.NewZeroCopySource(data) + val, _, irregular, eof := source.NextVarBytes() + if irregular { + return common.ErrIrregularData } - vt, err := serialization.ReadVarBytes(buf) - if err != nil { - return err + if eof { + return io.ErrUnexpectedEOF + } + + vt, _, irregular, eof := source.NextVarBytes() + if irregular { + return common.ErrIrregularData } + if eof { + return io.ErrUnexpectedEOF + } + this.valueType = vt this.value = val return nil } -func (this *attribute) Serialize(w io.Writer) error { - err := serialization.WriteVarBytes(w, this.key) - if err != nil { - return err +func (this *attribute) Serialization(sink *common.ZeroCopySink) { + sink.WriteVarBytes(this.key) + sink.WriteVarBytes(this.valueType) + sink.WriteVarBytes(this.value) +} + +func (this *attribute) Deserialization(source *common.ZeroCopySource) error { + k, _, irregular, eof := source.NextVarBytes() + if irregular { + return common.ErrIrregularData } - err = serialization.WriteVarBytes(w, this.valueType) - if err != nil { - return err + if eof { + return io.ErrUnexpectedEOF } - err = serialization.WriteVarBytes(w, this.value) - if err != nil { - return err + if len(k) > MAX_KEY_SIZE { + return errors.New("key is too large") } - return nil -} -func (this *attribute) Deserialize(r io.Reader) error { - k, err := serialization.ReadVarBytes(r) - if err != nil { - return err + vt, _, irregular, eof := source.NextVarBytes() + if irregular { + return common.ErrIrregularData } - vt, err := serialization.ReadVarBytes(r) - if err != nil { - return err + if eof { + return io.ErrUnexpectedEOF } - v, err := serialization.ReadVarBytes(r) - if err != nil { - return err + if len(vt) > MAX_TYPE_SIZE { + return errors.New("type is too large") + } + + v, _, irregular, eof := source.NextVarBytes() + if irregular { + return common.ErrIrregularData + } + if eof { + return io.ErrUnexpectedEOF } + if len(v) > MAX_VALUE_SIZE { + return errors.New("value is too large") + } + this.key = k this.value = v this.valueType = vt @@ -99,11 +118,8 @@ func (this *attribute) Deserialize(r io.Reader) error { func insertOrUpdateAttr(srvc *native.NativeService, encID []byte, attr *attribute) error { key := append(encID, FIELD_ATTR) - val, err := attr.Value() - if err != nil { - return errors.New("serialize attribute value error: " + err.Error()) - } - err = utils.LinkedlistInsert(srvc, key, attr.key, val) + val := attr.Value() + err := utils.LinkedlistInsert(srvc, key, attr.key, val) if err != nil { return errors.New("store attribute error: " + err.Error()) } @@ -116,18 +132,36 @@ func findAttr(srvc *native.NativeService, encID, item []byte) (*utils.Linkedlist } func batchInsertAttr(srvc *native.NativeService, encID []byte, attr []attribute) error { - res := make([][]byte, len(attr)) for i, v := range attr { err := insertOrUpdateAttr(srvc, encID, &v) if err != nil { - return errors.New("store attributes error: " + err.Error()) + return fmt.Errorf("store attribute %d error: %s", i, err) } - res[i] = v.key + } + + key := append(encID, FIELD_ATTR) + n, err := utils.LinkedlistGetNumOfItems(srvc, key) + if err != nil { + return err + } + if n > MAX_NUM { + return fmt.Errorf("too many attributes, max is %d", MAX_NUM) } return nil } +func deleteAttr(srvc *native.NativeService, encID, path []byte) error { + key := append(encID, FIELD_ATTR) + ok, err := utils.LinkedlistDelete(srvc, key, path) + if err != nil { + return err + } else if !ok { + return errors.New("attribute not exist") + } + return nil +} + func getAllAttr(srvc *native.NativeService, encID []byte) ([]byte, error) { key := append(encID, FIELD_ATTR) item, err := utils.LinkedlistGetHead(srvc, key) @@ -138,7 +172,7 @@ func getAllAttr(srvc *native.NativeService, encID []byte) ([]byte, error) { return nil, nil } - var res bytes.Buffer + res := common.NewZeroCopySink(nil) var i uint16 = 0 for len(item) > 0 { node, err := utils.LinkedlistGetItem(srvc, key, item) @@ -154,13 +188,23 @@ func getAllAttr(srvc *native.NativeService, encID []byte) ([]byte, error) { return nil, fmt.Errorf("parse attribute failed, %s", err) } attr.key = item - err = attr.Serialize(&res) - if err != nil { - return nil, fmt.Errorf("serialize error, %s", err) - } + attr.Serialization(res) i += 1 item = node.GetNext() } return res.Bytes(), nil } + +func getAttrKeys(attr []attribute) [][]byte { + var paths = make([][]byte, 0) + for _, v := range attr { + paths = append(paths, v.key) + } + return paths +} + +func deleteAllAttr(srvc *native.NativeService, encID []byte) error { + key := append(encID, FIELD_ATTR) + return utils.LinkedlistDeleteAll(srvc, key) +} diff --git a/smartcontract/service/native/ontid/controller.go b/smartcontract/service/native/ontid/controller.go new file mode 100644 index 0000000000..e94ac81de3 --- /dev/null +++ b/smartcontract/service/native/ontid/controller.go @@ -0,0 +1,385 @@ +/* + * Copyright (C) 2018 The ontology Authors + * This file is part of The ontology library. + * + * The ontology is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * The ontology is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with The ontology. If not, see . + */ +package ontid + +import ( + "errors" + "fmt" + + "github.com/ontio/ontology-crypto/keypair" + "github.com/ontio/ontology/account" + "github.com/ontio/ontology/common" + "github.com/ontio/ontology/core/states" + "github.com/ontio/ontology/smartcontract/service/native" + "github.com/ontio/ontology/smartcontract/service/native/utils" +) + +func regIdWithController(srvc *native.NativeService) ([]byte, error) { + source := common.NewZeroCopySource(srvc.Input) + // arg0: ID + arg0, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("argument 0 error") + } + + if !account.VerifyID(string(arg0)) { + return utils.BYTE_FALSE, fmt.Errorf("invalid ID") + } + + encId, err := encodeID(arg0) + if err != nil { + return utils.BYTE_FALSE, err + } + + if checkIDExistence(srvc, encId) { + return utils.BYTE_FALSE, fmt.Errorf("%s already registered", string(arg0)) + } + + // arg1: controller + arg1, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("argument 1 error") + } + + if account.VerifyID(string(arg1)) { + err = verifySingleController(srvc, arg1, source) + if err != nil { + return utils.BYTE_FALSE, err + } + } else { + controller, err := deserializeGroup(arg1) + if err != nil { + return utils.BYTE_FALSE, errors.New("deserialize controller error") + } + err = verifyGroupController(srvc, controller, source) + if err != nil { + return utils.BYTE_FALSE, err + } + } + + key := append(encId, FIELD_CONTROLLER) + utils.PutBytes(srvc, key, arg1) + + srvc.CacheDB.Put(encId, states.GenRawStorageItem([]byte{flag_exist})) + triggerRegisterEvent(srvc, arg0) + return utils.BYTE_TRUE, nil +} + +func revokeIDByController(srvc *native.NativeService) ([]byte, error) { + source := common.NewZeroCopySource(srvc.Input) + // arg0: id + arg0, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("argument 0 error") + } + + encID, err := encodeID(arg0) + if err != nil { + return utils.BYTE_FALSE, err + } + + if !checkIDExistence(srvc, encID) { + return utils.BYTE_FALSE, fmt.Errorf("%s is not registered or already revoked", string(arg0)) + } + + err = verifyControllerSignature(srvc, encID, source) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("authorization failed") + } + + err = deleteID(srvc, encID) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("delete id error, %s", err) + } + + newEvent(srvc, []interface{}{"Revoke", string(arg0)}) + return utils.BYTE_TRUE, nil +} + +func verifyController(srvc *native.NativeService) ([]byte, error) { + source := common.NewZeroCopySource(srvc.Input) + // arg0: ID + arg0, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("argument 0 error, %s", err) + } + + key, err := encodeID(arg0) + if err != nil { + return utils.BYTE_FALSE, err + } + + err = verifyControllerSignature(srvc, key, source) + if err == nil { + return utils.BYTE_TRUE, nil + } else { + return utils.BYTE_FALSE, fmt.Errorf("verification failed, %s", err) + } +} + +func removeController(srvc *native.NativeService) ([]byte, error) { + source := common.NewZeroCopySource(srvc.Input) + // arg0: id + arg0, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("argument 0 error") + } + // arg1: public key index + arg1, err := utils.DecodeVarUint(source) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("argument 1 error") + } + encId, err := encodeID(arg0) + if err != nil { + return utils.BYTE_FALSE, err + } + pk, err := getPk(srvc, encId, uint32(arg1)) + if err != nil { + return utils.BYTE_FALSE, err + } + if pk.revoked { + return utils.BYTE_FALSE, fmt.Errorf("authentication failed, public key is removed") + } + err = checkWitness(srvc, pk.key) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("checkWitness failed") + } + key := append(encId, FIELD_CONTROLLER) + srvc.CacheDB.Delete(key) + + newEvent(srvc, []interface{}{"RemoveController", string(arg0)}) + return utils.BYTE_TRUE, nil +} + +func addKeyByController(srvc *native.NativeService) ([]byte, error) { + source := common.NewZeroCopySource(srvc.Input) + // arg0: id + arg0, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("argument 0 error") + } + + // arg1: public key + arg1, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("argument 1 error") + } + _, err = keypair.DeserializePublicKey(arg1) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("invalid key") + } + + encId, err := encodeID(arg0) + if err != nil { + return utils.BYTE_FALSE, err + } + + err = verifyControllerSignature(srvc, encId, source) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("verification failed, %s", err) + } + + index, err := insertPk(srvc, encId, arg1) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("insertion failed, %s", err) + } + + triggerPublicEvent(srvc, "add", arg0, arg1, index) + return utils.BYTE_TRUE, nil +} + +func removeKeyByController(srvc *native.NativeService) ([]byte, error) { + source := common.NewZeroCopySource(srvc.Input) + // arg0: id + arg0, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, errors.New("argument 0") + } + + // arg1: public key index + arg1, err := utils.DecodeVarUint(source) + if err != nil { + return utils.BYTE_FALSE, errors.New("argument 1") + } + + encId, err := encodeID(arg0) + if err != nil { + return utils.BYTE_FALSE, errors.New(err.Error()) + } + + err = verifyControllerSignature(srvc, encId, source) + if err != nil { + return utils.BYTE_FALSE, errors.New("verifying signature failed") + } + + pk, err := revokePkByIndex(srvc, encId, uint32(arg1)) + if err != nil { + return utils.BYTE_FALSE, err + } + + triggerPublicEvent(srvc, "remove", arg0, pk, uint32(arg1)) + return utils.BYTE_TRUE, nil +} + +func addAttributesByController(srvc *native.NativeService) ([]byte, error) { + source := common.NewZeroCopySource(srvc.Input) + // arg0: id + arg0, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("argument 0 error") + } + + // arg1: attributes + num, err := utils.DecodeVarUint(source) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("argument 1 error: %s", err) + } + var arg1 = make([]attribute, 0) + for i := 0; i < int(num); i++ { + var v attribute + err = v.Deserialization(source) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("argument 1 error: %s", err) + } + arg1 = append(arg1, v) + } + + encId, err := encodeID(arg0) + if err != nil { + return utils.BYTE_FALSE, err + } + + err = verifyControllerSignature(srvc, encId, source) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("verification failed, %s", err) + } + + err = batchInsertAttr(srvc, encId, arg1) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("insert attributes error, %s", err) + } + + paths := getAttrKeys(arg1) + triggerAttributeEvent(srvc, "add", arg0, paths) + return utils.BYTE_TRUE, nil +} + +func removeAttributeByController(srvc *native.NativeService) ([]byte, error) { + source := common.NewZeroCopySource(srvc.Input) + // arg0: id + arg0, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, errors.New("argument 0 error") + } + + // arg1: path + arg1, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, errors.New("argument 1 error") + } + + encId, err := encodeID(arg0) + if err != nil { + return utils.BYTE_FALSE, err + } + + err = verifyControllerSignature(srvc, encId, source) + if err != nil { + return utils.BYTE_FALSE, errors.New("verifying signature failed") + } + + err = deleteAttr(srvc, encId, arg1) + if err != nil { + return utils.BYTE_FALSE, err + } + + triggerAttributeEvent(srvc, "remove", arg0, [][]byte{arg1}) + return utils.BYTE_TRUE, nil +} + +func getController(srvc *native.NativeService, encId []byte) (interface{}, error) { + key := append(encId, FIELD_CONTROLLER) + item, err := utils.GetStorageItem(srvc, key) + if err != nil { + return nil, err + } else if item == nil { + return nil, errors.New("empty controller storage") + } + + if account.VerifyID(string(item.Value)) { + return item.Value, nil + } else { + return deserializeGroup(item.Value) + } +} + +func verifySingleController(srvc *native.NativeService, id []byte, args *common.ZeroCopySource) error { + // public key index + index, err := utils.DecodeVarUint(args) + if err != nil { + return fmt.Errorf("index error, %s", err) + } + encId, err := encodeID(id) + if err != nil { + return err + } + pk, err := getPk(srvc, encId, uint32(index)) + if err != nil { + return err + } + if pk.revoked { + return fmt.Errorf("revoked key") + } + err = checkWitness(srvc, pk.key) + if err != nil { + return err + } + return nil +} + +func verifyGroupController(srvc *native.NativeService, group *Group, args *common.ZeroCopySource) error { + // signers + buf, err := utils.DecodeVarBytes(args) + if err != nil { + return fmt.Errorf("signers error, %s", err) + } + signers, err := deserializeSigners(buf) + if err != nil { + return fmt.Errorf("signers error, %s", err) + } + if !verifyGroupSignature(srvc, group, signers) { + return fmt.Errorf("verification failed") + } + return nil +} + +func verifyControllerSignature(srvc *native.NativeService, encId []byte, args *common.ZeroCopySource) error { + ctrl, err := getController(srvc, encId) + if err != nil { + return err + } + + switch t := ctrl.(type) { + case []byte: + return verifySingleController(srvc, t, args) + case *Group: + return verifyGroupController(srvc, t, args) + default: + return fmt.Errorf("unknown controller type") + } +} diff --git a/smartcontract/service/native/ontid/event.go b/smartcontract/service/native/ontid/event.go index 6bf09df710..d082353b65 100644 --- a/smartcontract/service/native/ontid/event.go +++ b/smartcontract/service/native/ontid/event.go @@ -20,7 +20,6 @@ package ontid import ( "encoding/hex" - "github.com/ontio/ontology/common" "github.com/ontio/ontology/smartcontract/event" "github.com/ontio/ontology/smartcontract/service/native" ) @@ -56,8 +55,3 @@ func triggerAttributeEvent(srvc *native.NativeService, op string, id []byte, pat st := []interface{}{"Attribute", op, string(id), attr} newEvent(srvc, st) } - -func triggerRecoveryEvent(srvc *native.NativeService, op string, id []byte, addr common.Address) { - st := []string{"Recovery", op, string(id), addr.ToHexString()} - newEvent(srvc, st) -} diff --git a/smartcontract/service/native/ontid/group.go b/smartcontract/service/native/ontid/group.go new file mode 100644 index 0000000000..1f5e9123ca --- /dev/null +++ b/smartcontract/service/native/ontid/group.go @@ -0,0 +1,200 @@ +/* + * Copyright (C) 2018 The ontology Authors + * This file is part of The ontology library. + * + * The ontology is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * The ontology is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with The ontology. If not, see . + */ +package ontid + +import ( + "bytes" + "encoding/json" + "fmt" + + "github.com/ontio/ontology/common" + "github.com/ontio/ontology/smartcontract/service/native" + "github.com/ontio/ontology/smartcontract/service/native/utils" +) + +const MAX_DEPTH = 8 + +// Group defines a group control logic +type Group struct { + Members []interface{} `json:"members"` + Threshold uint `json:"threshold"` +} + +func (g *Group) ToJson() []byte { + j, _ := json.Marshal(g) + return j +} + +func rDeserialize(data []byte, depth uint) (*Group, error) { + if depth == MAX_DEPTH { + return nil, fmt.Errorf("recursion is too deep") + } + + g := Group{} + buf := common.NewZeroCopySource(data) + + // parse members + num, err := utils.DecodeVarUint(buf) + if err != nil { + return nil, fmt.Errorf("error parsing number: %s", err) + } + + for i := uint64(0); i < num; i++ { + m, err := utils.DecodeVarBytes(buf) + if err != nil { + return nil, fmt.Errorf("error parsing group members: %s", err) + } + if len(m) > 8 && bytes.Equal(m[:8], []byte("did:ont:")) { + g.Members = append(g.Members, m) + } else { + // parse recursively + g1, err := rDeserialize(m, depth+1) + if err != nil { + return nil, fmt.Errorf("error parsing subgroup: %s", err) + } + g.Members = append(g.Members, g1) + } + } + + // parse threshold + t, err := utils.DecodeVarUint(buf) + if err != nil { + return nil, fmt.Errorf("error parsing group threshold: %s", err) + } + if t > uint64(len(g.Members)) { + return nil, fmt.Errorf("invalid threshold") + } + + g.Threshold = uint(t) + + return &g, nil +} + +func deserializeGroup(data []byte) (*Group, error) { + return rDeserialize(data, 0) +} + +func validateMembers(srvc *native.NativeService, g *Group) error { + for _, m := range g.Members { + switch t := m.(type) { + case []byte: + key, err := encodeID(t) + if err != nil { + return fmt.Errorf("invalid id: %s", string(t)) + } + // ID must exists + if !checkIDExistence(srvc, key) { + return fmt.Errorf("id %s not registered", string(t)) + } + // Group member must have its own public key + pk, err := getPk(srvc, key, 1) + if err != nil || pk == nil { + return fmt.Errorf("id %s has no public keys", string(t)) + } + case *Group: + if err := validateMembers(srvc, t); err != nil { + return err + } + default: + panic("group member type error") + } + } + return nil +} + +type Signer struct { + id []byte + index uint32 +} + +func deserializeSigners(data []byte) ([]Signer, error) { + buf := common.NewZeroCopySource(data) + num, err := utils.DecodeVarUint(buf) + if err != nil { + return nil, err + } + + signers := []Signer{} + for i := uint64(0); i < num; i++ { + id, err := utils.DecodeVarBytes(buf) + if err != nil { + return nil, err + } + index, err := utils.DecodeVarUint(buf) + if err != nil { + return nil, err + } + + signer := Signer{id, uint32(index)} + signers = append(signers, signer) + } + + return signers, nil +} + +func findSigner(id []byte, signers []Signer) bool { + for _, signer := range signers { + if bytes.Equal(signer.id, id) { + return true + } + } + return false +} + +func verifyThreshold(g *Group, signers []Signer) bool { + var signed uint = 0 + for _, member := range g.Members { + switch t := member.(type) { + case []byte: + if findSigner(t, signers) { + signed += 1 + } + case *Group: + if verifyThreshold(t, signers) { + signed += 1 + } + default: + panic("invalid group member type") + } + } + return signed >= g.Threshold +} + +func verifyGroupSignature(srvc *native.NativeService, g *Group, signers []Signer) bool { + if !verifyThreshold(g, signers) { + return false + } + + for _, signer := range signers { + key, err := encodeID(signer.id) + if err != nil { + return false + } + pk, err := getPk(srvc, key, signer.index) + if err != nil { + return false + } + if pk.revoked { + return false + } + if checkWitness(srvc, pk.key) != nil { + return false + } + } + return true +} diff --git a/smartcontract/service/native/ontid/group_test.go b/smartcontract/service/native/ontid/group_test.go new file mode 100644 index 0000000000..8a5c02a31e --- /dev/null +++ b/smartcontract/service/native/ontid/group_test.go @@ -0,0 +1,172 @@ +/* + * Copyright (C) 2018 The ontology Authors + * This file is part of The ontology library. + * + * The ontology is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * The ontology is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with The ontology. If not, see . + */ +package ontid + +import ( + "bytes" + "encoding/hex" + "fmt" + "testing" + + "github.com/ontio/ontology/core/states" + "github.com/ontio/ontology/core/store/leveldbstore" + "github.com/ontio/ontology/core/store/overlaydb" + "github.com/ontio/ontology/smartcontract/service/native" + "github.com/ontio/ontology/smartcontract/storage" +) + +func TestDeserializeGroup(t *testing.T) { + id0 := []byte("did:ont:ARY2ekof1eCSetcimGdjqyzUYaVDDPVWmw") + id1 := []byte("did:ont:ASbxtSqrpmydpjqCUGDiQp2mzsfd4zFArs") + id2 := []byte("did:ont:AGxc3cdeB6QFvmZXzWhGwzuvohNtqaaaDw") + g_ := &Group{ + Threshold: 1, + Members: []interface{}{ + id0, + &Group{ + Threshold: 2, + Members: []interface{}{ + id1, + id2, + }, + }, + }, + } + + data, _ := hex.DecodeString("01022a6469643a6f6e743a41525932656b6f6631654353657463696d47646a71797a5559615644445056576d775a01022a6469643a6f6e743a4153627874537172706d7964706a7143554744695170326d7a736664347a464172732a6469643a6f6e743a414778633363646542365146766d5a587a576847777a75766f684e7471616161447701020101") + + g, err := deserializeGroup(data) + if err != nil { + t.Fatal(err) + } + + err = groupCmp(g_, g) + if err != nil { + t.Fatal(err) + } + + memback, _ := leveldbstore.NewMemLevelDBStore() + overlay := overlaydb.NewOverlayDB(memback) + cache := storage.NewCacheDB(overlay) + + srvc := new(native.NativeService) + srvc.CacheDB = cache + + key, _ := encodeID(id0) + insertPk(srvc, key, []byte("test pk")) + cache.Put(key, states.GenRawStorageItem([]byte{flag_exist})) + key, _ = encodeID(id1) + insertPk(srvc, key, []byte("test pk")) + cache.Put(key, states.GenRawStorageItem([]byte{flag_exist})) + key, _ = encodeID(id2) + insertPk(srvc, key, []byte("test pk")) + cache.Put(key, states.GenRawStorageItem([]byte{flag_exist})) + + err = validateMembers(srvc, g) + if err != nil { + t.Fatal("validateMembers failed") + } +} + +func groupCmp(a, b *Group) error { + if a.Threshold != b.Threshold { + return fmt.Errorf("error threshold") + } + if len(a.Members) != len(b.Members) { + return fmt.Errorf("error number of members") + } + for i := 0; i < len(a.Members); i++ { + switch ma := a.Members[i].(type) { + case []byte: + mb, ok := b.Members[i].([]byte) + if !ok { + return fmt.Errorf("m%d: type error, ont id expected", i) + } + if !bytes.Equal(ma, mb) { + return fmt.Errorf("m%d: mismatched id", i) + } + case *Group: + mb, ok := b.Members[i].(*Group) + if !ok { + return fmt.Errorf("m%d: type error, subgroup expected", i) + } + err := groupCmp(ma, mb) + if err != nil { + return fmt.Errorf("m%d:%s", i, err) + } + default: + return fmt.Errorf("error type") + } + } + return nil +} + +func TestDeserializeGroup1(t *testing.T) { + data, _ := hex.DecodeString("01022a6469643a6f6e743a4153627874537172706d7964706a7143554744695170326d7a736664347a464172732a6469643a6f6e743a414778633363646542365146766d5a587a576847777a75766f684e747161616144770103") + _, err := deserializeGroup(data) + if err == nil { + t.Fatal("deserializeGroup should fail due to the invalid threshold") + } +} + +func TestDeserializeGroup2(t *testing.T) { + data, _ := hex.DecodeString("010203646964086469643a6f6e740101") + _, err := deserializeGroup(data) + if err == nil { + t.Fatal("deserializeGroup should fail due to invalid member data") + } +} + +func TestSigners(t *testing.T) { + id0 := []byte("did:ont:ARY2ekof1eCSetcimGdjqyzUYaVDDPVWmw") + id1 := []byte("did:ont:ASbxtSqrpmydpjqCUGDiQp2mzsfd4zFArs") + id2 := []byte("did:ont:AGxc3cdeB6QFvmZXzWhGwzuvohNtqaaaDw") + g := &Group{ + Threshold: 1, + Members: []interface{}{ + id0, + &Group{ + Threshold: 2, + Members: []interface{}{ + id1, + id2, + }, + }, + }, + } + + data, _ := hex.DecodeString("01022a6469643a6f6e743a4153627874537172706d7964706a7143554744695170326d7a736664347a4641727301012a6469643a6f6e743a414778633363646542365146766d5a587a576847777a75766f684e747161616144770101") + signers, err := deserializeSigners(data) + if err != nil { + t.Fatal(err) + } + + if !verifyThreshold(g, signers) { + t.Fatal("verifyThreshold failed") + } + + data, _ = hex.DecodeString("01012a6469643a6f6e743a4153627874537172706d7964706a7143554744695170326d7a736664347a464172730101") + signers, err = deserializeSigners(data) + if err != nil { + t.Fatal(err) + } + + if verifyThreshold(g, signers) { + t.Fatal("verifyThreshold should fail") + } +} diff --git a/smartcontract/service/native/ontid/init.go b/smartcontract/service/native/ontid/init.go index 661dbed84a..6cc3762c46 100644 --- a/smartcontract/service/native/ontid/init.go +++ b/smartcontract/service/native/ontid/init.go @@ -28,14 +28,25 @@ func Init() { func RegisterIDContract(srvc *native.NativeService) { srvc.Register("regIDWithPublicKey", regIdWithPublicKey) - srvc.Register("addKey", addKey) - srvc.Register("removeKey", removeKey) + srvc.Register("regIDWithController", regIdWithController) + srvc.Register("revokeID", revokeID) + srvc.Register("revokeIDByController", revokeIDByController) + srvc.Register("removeController", removeController) srvc.Register("addRecovery", addRecovery) srvc.Register("changeRecovery", changeRecovery) + srvc.Register("addKey", addKey) + srvc.Register("removeKey", removeKey) + srvc.Register("addKeyByController", addKeyByController) + srvc.Register("removeKeyByController", removeKeyByController) + srvc.Register("addKeyByRecovery", addKeyByRecovery) + srvc.Register("removeKeyByRecovery", removeKeyByRecovery) srvc.Register("regIDWithAttributes", regIdWithAttributes) srvc.Register("addAttributes", addAttributes) srvc.Register("removeAttribute", removeAttribute) + srvc.Register("addAttributesByController", addAttributesByController) + srvc.Register("removeAttributeByController", removeAttributeByController) srvc.Register("verifySignature", verifySignature) + srvc.Register("verifyController", verifyController) srvc.Register("getPublicKeys", GetPublicKeys) srvc.Register("getKeyState", GetKeyState) srvc.Register("getAttributes", GetAttributes) diff --git a/smartcontract/service/native/ontid/method.go b/smartcontract/service/native/ontid/method.go index 8292708b39..ff0dc79c5d 100644 --- a/smartcontract/service/native/ontid/method.go +++ b/smartcontract/service/native/ontid/method.go @@ -18,15 +18,14 @@ package ontid import ( - "bytes" "encoding/hex" "errors" "fmt" "github.com/ontio/ontology-crypto/keypair" "github.com/ontio/ontology/account" + "github.com/ontio/ontology/common" "github.com/ontio/ontology/common/log" - "github.com/ontio/ontology/common/serialization" "github.com/ontio/ontology/core/states" "github.com/ontio/ontology/core/types" "github.com/ontio/ontology/smartcontract/service/native" @@ -37,21 +36,21 @@ func regIdWithPublicKey(srvc *native.NativeService) ([]byte, error) { log.Debug("registerIdWithPublicKey") log.Debug("srvc.Input:", srvc.Input) // parse arguments - args := bytes.NewBuffer(srvc.Input) + source := common.NewZeroCopySource(srvc.Input) // arg0: ID - arg0, err := serialization.ReadVarBytes(args) + arg0, err := utils.DecodeVarBytes(source) if err != nil { return utils.BYTE_FALSE, errors.New("register ONT ID error: parsing argument 0 failed") } else if len(arg0) == 0 { return utils.BYTE_FALSE, errors.New("register ONT ID error: invalid length of argument 0") } + log.Debug("arg 0:", hex.EncodeToString(arg0), string(arg0)) // arg1: public key - arg1, err := serialization.ReadVarBytes(args) + arg1, err := utils.DecodeVarBytes(source) if err != nil { return utils.BYTE_FALSE, errors.New("register ONT ID error: parsing argument 1 failed") } - log.Debug("arg 0:", hex.EncodeToString(arg0), string(arg0)) log.Debug("arg 1:", hex.EncodeToString(arg1)) if len(arg0) == 0 || len(arg1) == 0 { @@ -96,9 +95,10 @@ func regIdWithPublicKey(srvc *native.NativeService) ([]byte, error) { func regIdWithAttributes(srvc *native.NativeService) ([]byte, error) { // parse arguments - args := bytes.NewBuffer(srvc.Input) + source := common.NewZeroCopySource(srvc.Input) // arg0: ID - arg0, err := serialization.ReadVarBytes(args) + arg0, err := utils.DecodeVarBytes(source) + if err != nil { return utils.BYTE_FALSE, errors.New("register ID with attributes error: argument 0 error, " + err.Error()) } else if len(arg0) == 0 { @@ -109,21 +109,22 @@ func regIdWithAttributes(srvc *native.NativeService) ([]byte, error) { } // arg1: public key - arg1, err := serialization.ReadVarBytes(args) + arg1, err := utils.DecodeVarBytes(source) if err != nil { - return utils.BYTE_FALSE, errors.New("register ID with attributes error: argument 1 error, " + err.Error()) + return utils.BYTE_FALSE, errors.New("register ID with attributes error: argument 1 error," + err.Error()) } // arg2: attributes // first get number - num, err := utils.ReadVarUint(args) + num, err := utils.DecodeVarUint(source) if err != nil { return utils.BYTE_FALSE, errors.New("register ID with attributes error: argument 2 error, " + err.Error()) } + // next parse each attribute var arg2 = make([]attribute, 0) for i := 0; i < int(num); i++ { var v attribute - err = v.Deserialize(args) + err = v.Deserialization(source) if err != nil { return utils.BYTE_FALSE, errors.New("register ID with attributes error: argument 2 error, " + err.Error()) } @@ -164,23 +165,27 @@ func regIdWithAttributes(srvc *native.NativeService) ([]byte, error) { func addKey(srvc *native.NativeService) ([]byte, error) { log.Debug("ID contract: AddKey") - args := bytes.NewBuffer(srvc.Input) + source := common.NewZeroCopySource(srvc.Input) // arg0: id - arg0, err := serialization.ReadVarBytes(args) + arg0, err := utils.DecodeVarBytes(source) if err != nil { return utils.BYTE_FALSE, errors.New("add key failed: argument 0 error, " + err.Error()) } log.Debug("arg 0:", hex.EncodeToString(arg0)) // arg1: public key - arg1, err := serialization.ReadVarBytes(args) + arg1, err := utils.DecodeVarBytes(source) if err != nil { return utils.BYTE_FALSE, errors.New("add key failed: argument 1 error, " + err.Error()) } log.Debug("arg 1:", hex.EncodeToString(arg1)) + _, err = keypair.DeserializePublicKey(arg1) + if err != nil { + return utils.BYTE_FALSE, errors.New("add key error: invalid key") + } - // arg2: operator's public key / address - arg2, err := serialization.ReadVarBytes(args) + // arg2: operator's public key + arg2, err := utils.DecodeVarBytes(source) if err != nil { return utils.BYTE_FALSE, errors.New("add key failed: argument 2 error, " + err.Error()) } @@ -194,18 +199,8 @@ func addKey(srvc *native.NativeService) ([]byte, error) { if err != nil { return utils.BYTE_FALSE, errors.New("add key failed: " + err.Error()) } - if !checkIDExistence(srvc, key) { - return utils.BYTE_FALSE, errors.New("add key failed: ID not registered") - } - var auth bool = false - rec, _ := getRecovery(srvc, key) - if len(rec) > 0 { - auth = bytes.Equal(rec, arg2) - } - if !auth { - if !isOwner(srvc, key, arg2) { - return utils.BYTE_FALSE, errors.New("add key failed: operator has no authorization") - } + if !isOwner(srvc, key, arg2) { + return utils.BYTE_FALSE, errors.New("add key failed: operator has no authorization") } item, _, err := findPk(srvc, key, arg1) @@ -224,21 +219,21 @@ func addKey(srvc *native.NativeService) ([]byte, error) { } func removeKey(srvc *native.NativeService) ([]byte, error) { - args := bytes.NewBuffer(srvc.Input) + source := common.NewZeroCopySource(srvc.Input) // arg0: id - arg0, err := serialization.ReadVarBytes(args) + arg0, err := utils.DecodeVarBytes(source) if err != nil { return utils.BYTE_FALSE, fmt.Errorf("remove key failed: argument 0 error, %s", err) } // arg1: public key - arg1, err := serialization.ReadVarBytes(args) + arg1, err := utils.DecodeVarBytes(source) if err != nil { return utils.BYTE_FALSE, fmt.Errorf("remove key failed: argument 1 error, %s", err) } - // arg2: operator's public key / address - arg2, err := serialization.ReadVarBytes(args) + // arg2: operator's public key + arg2, err := utils.DecodeVarBytes(source) if err != nil { return utils.BYTE_FALSE, fmt.Errorf("remove key failed: argument 2 error, %s", err) } @@ -253,15 +248,8 @@ func removeKey(srvc *native.NativeService) ([]byte, error) { if !checkIDExistence(srvc, key) { return utils.BYTE_FALSE, errors.New("remove key failed: ID not registered") } - var auth = false - rec, err := getRecovery(srvc, key) - if len(rec) > 0 { - auth = bytes.Equal(rec, arg2) - } - if !auth { - if !isOwner(srvc, key, arg2) { - return utils.BYTE_FALSE, errors.New("remove key failed: operator has no authorization") - } + if !isOwner(srvc, key, arg2) { + return utils.BYTE_FALSE, errors.New("remove key failed: operator has no authorization") } keyID, err := revokePk(srvc, key, arg1) @@ -274,110 +262,16 @@ func removeKey(srvc *native.NativeService) ([]byte, error) { return utils.BYTE_TRUE, nil } -func addRecovery(srvc *native.NativeService) ([]byte, error) { - args := bytes.NewBuffer(srvc.Input) - // arg0: ID - arg0, err := serialization.ReadVarBytes(args) - if err != nil { - return utils.BYTE_FALSE, errors.New("add recovery failed: argument 0 error") - } - // arg1: recovery address - arg1, err := utils.ReadAddress(args) - if err != nil { - return utils.BYTE_FALSE, errors.New("add recovery failed: argument 1 error") - } - // arg2: operator's public key - arg2, err := serialization.ReadVarBytes(args) - if err != nil { - return utils.BYTE_FALSE, errors.New("add recovery failed: argument 2 error") - } - - err = checkWitness(srvc, arg2) - if err != nil { - return utils.BYTE_FALSE, errors.New("add recovery failed: " + err.Error()) - } - - key, err := encodeID(arg0) - if err != nil { - return utils.BYTE_FALSE, errors.New("add recovery failed: " + err.Error()) - } - if !checkIDExistence(srvc, key) { - return utils.BYTE_FALSE, errors.New("add recovery failed: ID not registered") - } - if !isOwner(srvc, key, arg2) { - return utils.BYTE_FALSE, errors.New("add recovery failed: not authorized") - } - - re, err := getRecovery(srvc, key) - if err == nil && len(re) > 0 { - return utils.BYTE_FALSE, errors.New("add recovery failed: already set recovery") - } - - err = setRecovery(srvc, key, arg1) - if err != nil { - return utils.BYTE_FALSE, errors.New("add recovery failed: " + err.Error()) - } - - triggerRecoveryEvent(srvc, "add", arg0, arg1) - - return utils.BYTE_TRUE, nil -} - -func changeRecovery(srvc *native.NativeService) ([]byte, error) { - args := bytes.NewBuffer(srvc.Input) - // arg0: ID - arg0, err := serialization.ReadVarBytes(args) - if err != nil { - return utils.BYTE_FALSE, errors.New("change recovery failed: argument 0 error") - } - // arg1: new recovery address - arg1, err := utils.ReadAddress(args) - if err != nil { - return utils.BYTE_FALSE, errors.New("change recovery failed: argument 1 error") - } - // arg2: operator's address, who should be the old recovery - arg2, err := utils.ReadAddress(args) - if err != nil { - return utils.BYTE_FALSE, errors.New("change recovery failed: argument 2 error") - } - - key, err := encodeID(arg0) - if err != nil { - return utils.BYTE_FALSE, errors.New("change recovery failed: " + err.Error()) - } - re, err := getRecovery(srvc, key) - if err != nil { - return utils.BYTE_FALSE, errors.New("change recovery failed: recovery not set") - } - if !bytes.Equal(re, arg2[:]) { - return utils.BYTE_FALSE, errors.New("change recovery failed: operator is not the recovery") - } - err = checkWitness(srvc, arg2[:]) - if err != nil { - return utils.BYTE_FALSE, errors.New("change recovery failed: " + err.Error()) - } - if !checkIDExistence(srvc, key) { - return utils.BYTE_FALSE, errors.New("change recovery failed: ID not registered") - } - err = setRecovery(srvc, key, arg1) - if err != nil { - return utils.BYTE_FALSE, errors.New("change recovery failed: " + err.Error()) - } - - triggerRecoveryEvent(srvc, "change", arg0, arg1) - return utils.BYTE_TRUE, nil -} - func addAttributes(srvc *native.NativeService) ([]byte, error) { - args := bytes.NewBuffer(srvc.Input) + source := common.NewZeroCopySource(srvc.Input) // arg0: ID - arg0, err := serialization.ReadVarBytes(args) + arg0, err := utils.DecodeVarBytes(source) if err != nil { - return utils.BYTE_FALSE, fmt.Errorf("add attributes failed, argument 0 error: %s", err) + return utils.BYTE_FALSE, fmt.Errorf("add attributes failed, argument 0, error: %s", err) } // arg1: attributes // first get number - num, err := utils.ReadVarUint(args) + num, err := utils.DecodeVarUint(source) if err != nil { return utils.BYTE_FALSE, fmt.Errorf("add attributes failed, argument 1 error: %s", err) } @@ -385,16 +279,16 @@ func addAttributes(srvc *native.NativeService) ([]byte, error) { var arg1 = make([]attribute, 0) for i := 0; i < int(num); i++ { var v attribute - err = v.Deserialize(args) + err = v.Deserialization(source) if err != nil { return utils.BYTE_FALSE, fmt.Errorf("add attributes failed, argument 1 error: %s", err) } arg1 = append(arg1, v) } // arg2: opperator's public key - arg2, err := serialization.ReadVarBytes(args) + arg2, err := utils.DecodeVarBytes(source) if err != nil { - return utils.BYTE_FALSE, fmt.Errorf("add attributes failed, argument 2 error: %s", err) + return utils.BYTE_FALSE, fmt.Errorf("add attributes failed, argument 2, error: %s", err) } key, err := encodeID(arg0) @@ -417,28 +311,25 @@ func addAttributes(srvc *native.NativeService) ([]byte, error) { return utils.BYTE_FALSE, fmt.Errorf("add attributes failed, %s", err) } - var paths = make([][]byte, 0) - for _, v := range arg1 { - paths = append(paths, v.key) - } + paths := getAttrKeys(arg1) triggerAttributeEvent(srvc, "add", arg0, paths) return utils.BYTE_TRUE, nil } func removeAttribute(srvc *native.NativeService) ([]byte, error) { - args := bytes.NewBuffer(srvc.Input) + args := common.NewZeroCopySource(srvc.Input) // arg0: ID - arg0, err := serialization.ReadVarBytes(args) + arg0, err := utils.DecodeVarBytes(args) if err != nil { return utils.BYTE_FALSE, errors.New("remove attribute failed: argument 0 error") } // arg1: path - arg1, err := serialization.ReadVarBytes(args) + arg1, err := utils.DecodeVarBytes(args) if err != nil { return utils.BYTE_FALSE, errors.New("remove attribute failed: argument 1 error") } // arg2: operator's public key - arg2, err := serialization.ReadVarBytes(args) + arg2, err := utils.DecodeVarBytes(args) if err != nil { return utils.BYTE_FALSE, errors.New("remove attribute failed: argument 2 error") } @@ -458,12 +349,9 @@ func removeAttribute(srvc *native.NativeService) ([]byte, error) { return utils.BYTE_FALSE, errors.New("remove attribute failed: no authorization") } - key1 := append(key, FIELD_ATTR) - ok, err := utils.LinkedlistDelete(srvc, key1, arg1) + err = deleteAttr(srvc, key, arg1) if err != nil { - return utils.BYTE_FALSE, errors.New("remove attribute failed: delete error, " + err.Error()) - } else if !ok { - return utils.BYTE_FALSE, errors.New("remove attribute failed: attribute not exist") + return utils.BYTE_FALSE, errors.New("remove attribute failed: " + err.Error()) } triggerAttributeEvent(srvc, "remove", arg0, [][]byte{arg1}) @@ -471,14 +359,14 @@ func removeAttribute(srvc *native.NativeService) ([]byte, error) { } func verifySignature(srvc *native.NativeService) ([]byte, error) { - args := bytes.NewBuffer(srvc.Input) + source := common.NewZeroCopySource(srvc.Input) // arg0: ID - arg0, err := serialization.ReadVarBytes(args) + arg0, err := utils.DecodeVarBytes(source) if err != nil { - return utils.BYTE_FALSE, errors.New("verify signature error: argument 0 error, " + err.Error()) + return utils.BYTE_FALSE, errors.New("verify signature error: argument 0 error, error: " + err.Error()) } // arg1: index of public key - arg1, err := utils.ReadVarUint(args) + arg1, err := utils.DecodeVarUint(source) if err != nil { return utils.BYTE_FALSE, errors.New("verify signature error: argument 1 error, " + err.Error()) } @@ -492,6 +380,8 @@ func verifySignature(srvc *native.NativeService) ([]byte, error) { return utils.BYTE_FALSE, errors.New("verify signature error: get key failed, " + err.Error()) } else if owner == nil { return utils.BYTE_FALSE, errors.New("verify signature error: public key not found") + } else if owner.revoked { + return utils.BYTE_FALSE, errors.New("verify signature error: revoked key") } err = checkWitness(srvc, owner.key) @@ -501,3 +391,44 @@ func verifySignature(srvc *native.NativeService) ([]byte, error) { return utils.BYTE_TRUE, nil } + +func revokeID(srvc *native.NativeService) ([]byte, error) { + source := common.NewZeroCopySource(srvc.Input) + // arg0: id + arg0, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("argument 0 error") + } + // arg1: index of public key + arg1, err := utils.DecodeVarUint(source) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("argument 1 error") + } + + encID, err := encodeID(arg0) + if err != nil { + return utils.BYTE_FALSE, err + } + + if !checkIDExistence(srvc, encID) { + return utils.BYTE_FALSE, fmt.Errorf("%s is not registered or already revoked", string(arg0)) + } + + pk, err := getPk(srvc, encID, uint32(arg1)) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("get public key error: %s", err) + } else if pk.revoked { + return utils.BYTE_FALSE, fmt.Errorf("revoked key") + } + + if checkWitness(srvc, pk.key) != nil { + return utils.BYTE_FALSE, fmt.Errorf("authorization failed") + } + + err = deleteID(srvc, encID) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("delete id error, %s", err) + } + newEvent(srvc, []interface{}{"Revoke", string(arg0)}) + return utils.BYTE_TRUE, nil +} diff --git a/smartcontract/service/native/ontid/owner.go b/smartcontract/service/native/ontid/owner.go index 7c387d5c11..0f1cc66cc6 100644 --- a/smartcontract/service/native/ontid/owner.go +++ b/smartcontract/service/native/ontid/owner.go @@ -21,36 +21,32 @@ import ( "bytes" "errors" "fmt" - "io" + "github.com/ontio/ontology/common" "github.com/ontio/ontology/common/log" - "github.com/ontio/ontology/common/serialization" "github.com/ontio/ontology/core/states" "github.com/ontio/ontology/smartcontract/service/native" "github.com/ontio/ontology/smartcontract/service/native/utils" ) +const OWNER_TOTAL_SIZE = 1024 * 1024 // 1MB + type owner struct { key []byte revoked bool } -func (this *owner) Serialize(w io.Writer) error { - if err := serialization.WriteVarBytes(w, this.key); err != nil { - return err - } - if err := serialization.WriteBool(w, this.revoked); err != nil { - return err - } - return nil +func (this *owner) Serialization(sink *common.ZeroCopySink) { + sink.WriteVarBytes(this.key) + sink.WriteBool(this.revoked) } -func (this *owner) Deserialize(r io.Reader) error { - v1, err := serialization.ReadVarBytes(r) +func (this *owner) Deserialization(source *common.ZeroCopySource) error { + v1, err := utils.DecodeVarBytes(source) if err != nil { return err } - v2, err := serialization.ReadBool(r) + v2, err := utils.DecodeBool(source) if err != nil { return err } @@ -67,11 +63,11 @@ func getAllPk(srvc *native.NativeService, key []byte) ([]*owner, error) { if val == nil { return nil, nil } - buf := bytes.NewBuffer(val.Value) + source := common.NewZeroCopySource(val.Value) owners := make([]*owner, 0) - for buf.Len() > 0 { + for source.Len() > 0 { var t = new(owner) - err = t.Deserialize(buf) + err = t.Deserialization(source) if err != nil { return nil, fmt.Errorf("deserialize owners error, %s", err) } @@ -81,15 +77,15 @@ func getAllPk(srvc *native.NativeService, key []byte) ([]*owner, error) { } func putAllPk(srvc *native.NativeService, key []byte, val []*owner) error { - var buf bytes.Buffer + sink := common.NewZeroCopySink(nil) for _, i := range val { - err := i.Serialize(&buf) - if err != nil { - return fmt.Errorf("serialize owner error, %s", err) - } + i.Serialization(sink) } var v states.StorageItem - v.Value = buf.Bytes() + v.Value = sink.Bytes() + if len(v.Value) > OWNER_TOTAL_SIZE { + return errors.New("total key size is out of range") + } srvc.CacheDB.Put(key, v.ToArray()) return nil } @@ -101,11 +97,6 @@ func insertPk(srvc *native.NativeService, encID, pk []byte) (uint32, error) { owners = make([]*owner, 0) } size := len(owners) - if size >= 0xFFFFFFFF { - //FIXME currently the limit is for all the keys, including the - // revoked ones. - return 0, errors.New("reach the max limit, cannot add more keys") - } owners = append(owners, &owner{pk, false}) err = putAllPk(srvc, key, owners) if err != nil { @@ -159,11 +150,26 @@ func revokePk(srvc *native.NativeService, encID, pub []byte) (uint32, error) { if index == 0 { return 0, errors.New("revoke failed, public key not found") } - err = putAllPk(srvc, key, owners) + putAllPk(srvc, key, owners) + return index, nil +} + +func revokePkByIndex(srvc *native.NativeService, encID []byte, index uint32) ([]byte, error) { + key := append(encID, FIELD_PK) + owners, err := getAllPk(srvc, key) if err != nil { - return 0, err + return nil, err } - return index, nil + if uint32(len(owners)) < index { + return nil, errors.New("no such key") + } + index -= 1 + if owners[index].revoked { + return nil, errors.New("already revoked") + } + owners[index].revoked = true + putAllPk(srvc, key, owners) + return owners[index].key, nil } func isOwner(srvc *native.NativeService, encID, pub []byte) bool { diff --git a/smartcontract/service/native/ontid/query.go b/smartcontract/service/native/ontid/query.go index f0c31647f6..42a429e2f7 100644 --- a/smartcontract/service/native/ontid/query.go +++ b/smartcontract/service/native/ontid/query.go @@ -18,26 +18,25 @@ package ontid import ( - "bytes" "encoding/hex" "errors" "fmt" + "github.com/ontio/ontology/common" "github.com/ontio/ontology/common/log" - "github.com/ontio/ontology/common/serialization" "github.com/ontio/ontology/smartcontract/service/native" "github.com/ontio/ontology/smartcontract/service/native/utils" ) func GetPublicKeyByID(srvc *native.NativeService) ([]byte, error) { - args := bytes.NewBuffer(srvc.Input) + args := common.NewZeroCopySource(srvc.Input) // arg0: ID - arg0, err := serialization.ReadVarBytes(args) + arg0, err := utils.DecodeVarBytes(args) if err != nil { return nil, errors.New("get public key failed: argument 0 error") } // arg1: key ID - arg1, err := serialization.ReadUint32(args) + arg1, err := utils.DecodeUint32(args) if err != nil { return nil, errors.New("get public key failed: argument 1 error") } @@ -61,34 +60,62 @@ func GetPublicKeyByID(srvc *native.NativeService) ([]byte, error) { func GetDDO(srvc *native.NativeService) ([]byte, error) { log.Debug("GetDDO") + // keys var0, err := GetPublicKeys(srvc) if err != nil { return nil, fmt.Errorf("get DDO error: %s", err) - } else if var0 == nil { - log.Debug("DDO: null") - return nil, nil } - var buf bytes.Buffer - serialization.WriteVarBytes(&buf, var0) + sink := common.NewZeroCopySink(nil) + sink.WriteVarBytes(var0) + + // attributes var1, err := GetAttributes(srvc) - serialization.WriteVarBytes(&buf, var1) + if err != nil { + return nil, fmt.Errorf("get attribute error, %s", err) + } + sink.WriteVarBytes(var1) - args := bytes.NewBuffer(srvc.Input) - did, _ := serialization.ReadVarBytes(args) - key, _ := encodeID(did) - var2, err := getRecovery(srvc, key) - serialization.WriteVarBytes(&buf, var2) + source := common.NewZeroCopySource(srvc.Input) + did, err := utils.DecodeVarBytes(source) + if err != nil { + return nil, fmt.Errorf("get id error, %s", err) + } + key, err := encodeID(did) + if err != nil { + return nil, err + } + + // controller + con, err := getController(srvc, key) + var2 := []byte{} + if err == nil { + switch t := con.(type) { + case []byte: + var2 = t + case *Group: + var2 = t.ToJson() + } + } + sink.WriteVarBytes(var2) + + //recovery + var3 := []byte{} + rec, err := getRecovery(srvc, key) + if rec != nil && err == nil { + var3 = rec.ToJson() + } + sink.WriteVarBytes(var3) - res := buf.Bytes() + res := sink.Bytes() log.Debug("DDO:", hex.EncodeToString(res)) return res, nil } func GetPublicKeys(srvc *native.NativeService) ([]byte, error) { log.Debug("GetPublicKeys") - args := bytes.NewBuffer(srvc.Input) - did, err := serialization.ReadVarBytes(args) + args := common.NewZeroCopySource(srvc.Input) + did, err := utils.DecodeVarBytes(args) if err != nil { return nil, fmt.Errorf("get public keys error: invalid argument, %s", err) } @@ -107,28 +134,22 @@ func GetPublicKeys(srvc *native.NativeService) ([]byte, error) { return nil, nil } - var res bytes.Buffer + sink := common.NewZeroCopySink(nil) for i, v := range list { if v.revoked { continue } - err = serialization.WriteUint32(&res, uint32(i+1)) - if err != nil { - return nil, fmt.Errorf("get public keys error: %s", err) - } - err = serialization.WriteVarBytes(&res, v.key) - if err != nil { - return nil, fmt.Errorf("get public keys error: %s", err) - } + sink.WriteUint32(uint32(i + 1)) + sink.WriteVarBytes(v.key) } - return res.Bytes(), nil + return sink.Bytes(), nil } func GetAttributes(srvc *native.NativeService) ([]byte, error) { log.Debug("GetAttributes") - args := bytes.NewBuffer(srvc.Input) - did, err := serialization.ReadVarBytes(args) + source := common.NewZeroCopySource(srvc.Input) + did, err := utils.DecodeVarBytes(source) if err != nil { return nil, fmt.Errorf("get public keys error: invalid argument, %s", err) } @@ -149,14 +170,14 @@ func GetAttributes(srvc *native.NativeService) ([]byte, error) { func GetKeyState(srvc *native.NativeService) ([]byte, error) { log.Debug("GetKeyState") - args := bytes.NewBuffer(srvc.Input) + source := common.NewZeroCopySource(srvc.Input) // arg0: ID - arg0, err := serialization.ReadVarBytes(args) - if err != nil { - return nil, fmt.Errorf("get key state failed: argument 0 error, %s", err) + arg0, _, irregular, eof := source.NextVarBytes() + if irregular || eof { + return nil, fmt.Errorf("get key state failed: argument 0 error") } // arg1: public key ID - arg1, err := utils.ReadVarUint(args) + arg1, err := utils.DecodeVarUint(source) if err != nil { return nil, fmt.Errorf("get key state failed: argument 1 error, %s", err) } diff --git a/smartcontract/service/native/ontid/recovery.go b/smartcontract/service/native/ontid/recovery.go new file mode 100644 index 0000000000..8a9eff13a3 --- /dev/null +++ b/smartcontract/service/native/ontid/recovery.go @@ -0,0 +1,235 @@ +/* + * Copyright (C) 2018 The ontology Authors + * This file is part of The ontology library. + * + * The ontology is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * The ontology is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with The ontology. If not, see . + */ +package ontid + +import ( + "errors" + "fmt" + + "github.com/ontio/ontology/common" + "github.com/ontio/ontology/smartcontract/service/native" + "github.com/ontio/ontology/smartcontract/service/native/utils" +) + +func addRecovery(srvc *native.NativeService) ([]byte, error) { + source := common.NewZeroCopySource(srvc.Input) + // arg0: ID + arg0, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, errors.New("add recovery failed: argument 0 error") + } + // arg1: recovery struct + arg1, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, errors.New("add recovery failed: argument 1 error") + } + // arg2: operator's public key index + arg2, err := utils.DecodeVarUint(source) + if err != nil { + return utils.BYTE_FALSE, errors.New("add recovery failed: argument 2 error") + } + + encId, err := encodeID(arg0) + if err != nil { + return utils.BYTE_FALSE, errors.New("add recovery failed: " + err.Error()) + } + pk, err := getPk(srvc, encId, uint32(arg2)) + if err != nil { + return utils.BYTE_FALSE, err + } + if pk.revoked { + return utils.BYTE_FALSE, errors.New("authentication failed, public key is revoked") + } + err = checkWitness(srvc, pk.key) + if err != nil { + return utils.BYTE_FALSE, errors.New("checkWitness failed") + } + + re, err := getRecovery(srvc, encId) + if err == nil && re != nil { + return utils.BYTE_FALSE, errors.New("recovery is already set") + } + + re, err = setRecovery(srvc, encId, arg1) + if err != nil { + return utils.BYTE_FALSE, errors.New("add recovery failed: " + err.Error()) + } + + newEvent(srvc, []interface{}{"recovery", "add", string(arg0), re.ToJson()}) + return utils.BYTE_TRUE, nil +} + +func changeRecovery(srvc *native.NativeService) ([]byte, error) { + source := common.NewZeroCopySource(srvc.Input) + // arg0: ID + arg0, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, errors.New("argument 0 error") + } + // arg1: new recovery + arg1, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, errors.New("argument 1 error") + } + // arg2: signers + arg2, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, errors.New("argument 2 error") + } + + key, err := encodeID(arg0) + if err != nil { + return utils.BYTE_FALSE, errors.New("change recovery failed: " + err.Error()) + } + re, err := getRecovery(srvc, key) + if err != nil { + return utils.BYTE_FALSE, errors.New("change recovery failed: recovery not set") + } + signers, err := deserializeSigners(arg2) + if err != nil { + return utils.BYTE_FALSE, errors.New("signers error: " + err.Error()) + } + + if !verifyGroupSignature(srvc, re, signers) { + return utils.BYTE_FALSE, errors.New("verification failed") + } + re, err = setRecovery(srvc, key, arg1) + if err != nil { + return utils.BYTE_FALSE, errors.New("change recovery failed: " + err.Error()) + } + + newEvent(srvc, []interface{}{"Recovery", "change", string(arg0), re.ToJson()}) + return utils.BYTE_TRUE, nil +} + +func addKeyByRecovery(srvc *native.NativeService) ([]byte, error) { + source := common.NewZeroCopySource(srvc.Input) + // arg0: id + arg0, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, errors.New("argument 0 error") + } + // arg1: public key + arg1, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, errors.New("argument 1 error") + } + // arg2: signers + arg2, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, errors.New("argument 2 error") + } + + encId, err := encodeID(arg0) + if err != nil { + return utils.BYTE_FALSE, err + } + + signers, err := deserializeSigners(arg2) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("signers error, %s", err) + } + + rec, err := getRecovery(srvc, encId) + if err != nil { + return utils.BYTE_FALSE, err + } + + if !verifyGroupSignature(srvc, rec, signers) { + return utils.BYTE_FALSE, errors.New("verification failed") + } + + index, err := insertPk(srvc, encId, arg1) + if err != nil { + return utils.BYTE_FALSE, err + } + + triggerPublicEvent(srvc, "add", arg0, arg1, index) + return utils.BYTE_TRUE, nil +} + +func removeKeyByRecovery(srvc *native.NativeService) ([]byte, error) { + source := common.NewZeroCopySource(srvc.Input) + // arg0: id + arg0, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, errors.New("argument 0 error") + } + // arg1: public key index + arg1, err := utils.DecodeVarUint(source) + if err != nil { + return utils.BYTE_FALSE, errors.New("argument 1 error") + } + // arg2: signers + arg2, err := utils.DecodeVarBytes(source) + if err != nil { + return utils.BYTE_FALSE, errors.New("argument 2 error") + } + + encId, err := encodeID(arg0) + if err != nil { + return utils.BYTE_FALSE, err + } + + signers, err := deserializeSigners(arg2) + if err != nil { + return utils.BYTE_FALSE, fmt.Errorf("signers error, %s", err) + } + + rec, err := getRecovery(srvc, encId) + if err != nil { + return utils.BYTE_FALSE, err + } + + if !verifyGroupSignature(srvc, rec, signers) { + return utils.BYTE_FALSE, errors.New("verification failed") + } + + pk, err := revokePkByIndex(srvc, encId, uint32(arg1)) + if err != nil { + return utils.BYTE_FALSE, err + } + + triggerPublicEvent(srvc, "remove", arg0, pk, uint32(arg1)) + return utils.BYTE_TRUE, nil +} + +func setRecovery(srvc *native.NativeService, encID, data []byte) (*Group, error) { + rec, err := deserializeGroup(data) + if err != nil { + return nil, err + } + err = validateMembers(srvc, rec) + if err != nil { + return nil, fmt.Errorf("invalid recovery member, %s", err) + } + key := append(encID, FIELD_RECOVERY) + utils.PutBytes(srvc, key, data) + return rec, nil +} + +func getRecovery(srvc *native.NativeService, encID []byte) (*Group, error) { + key := append(encID, FIELD_RECOVERY) + item, err := utils.GetStorageItem(srvc, key) + if err != nil { + return nil, err + } else if item == nil { + return nil, errors.New("empty storage item") + } + return deserializeGroup(item.Value) +} diff --git a/smartcontract/service/native/ontid/utils.go b/smartcontract/service/native/ontid/utils.go index 9a5fd60014..3336a19324 100644 --- a/smartcontract/service/native/ontid/utils.go +++ b/smartcontract/service/native/ontid/utils.go @@ -22,7 +22,6 @@ import ( "errors" "github.com/ontio/ontology-crypto/keypair" - com "github.com/ontio/ontology/common" "github.com/ontio/ontology/core/states" "github.com/ontio/ontology/core/types" "github.com/ontio/ontology/smartcontract/service/native" @@ -43,14 +42,16 @@ func checkIDExistence(srvc *native.NativeService, encID []byte) bool { } const ( - flag_exist = 0x01 + flag_exist byte = 0x01 + flag_revoke byte = 0x02 FIELD_VERSION byte = 0 FLAG_VERSION byte = 0x01 - FIELD_PK byte = 1 - FIELD_ATTR byte = 2 - FIELD_RECOVERY byte = 3 + FIELD_PK byte = 1 + FIELD_ATTR byte = 2 + FIELD_RECOVERY byte = 3 + FIELD_CONTROLLER byte = 4 ) func encodeID(id []byte) ([]byte, error) { @@ -73,26 +74,7 @@ func decodeID(data []byte) ([]byte, error) { return data[prefix+1:], nil } -func setRecovery(srvc *native.NativeService, encID []byte, recovery com.Address) error { - key := append(encID, FIELD_RECOVERY) - val := states.StorageItem{Value: recovery[:]} - srvc.CacheDB.Put(key, val.ToArray()) - return nil -} - -func getRecovery(srvc *native.NativeService, encID []byte) ([]byte, error) { - key := append(encID, FIELD_RECOVERY) - item, err := utils.GetStorageItem(srvc, key) - if err != nil { - return nil, errors.New("get recovery error: " + err.Error()) - } else if item == nil { - return nil, nil - } - return item.Value, nil -} - func checkWitness(srvc *native.NativeService, key []byte) error { - // try as if key is a public key pk, err := keypair.DeserializePublicKey(key) if err == nil { addr := types.AddressFromPubKey(pk) @@ -101,11 +83,25 @@ func checkWitness(srvc *native.NativeService, key []byte) error { } } - // try as if key is an address - addr, err := com.AddressParseFromBytes(key) - if srvc.ContextRef.CheckWitness(addr) { - return nil + return errors.New("check witness failed, " + hex.EncodeToString(key)) +} + +func deleteID(srvc *native.NativeService, encID []byte) error { + key := append(encID, FIELD_PK) + srvc.CacheDB.Delete(key) + + key = append(encID, FIELD_CONTROLLER) + srvc.CacheDB.Delete(key) + + key = append(encID, FIELD_RECOVERY) + srvc.CacheDB.Delete(key) + + err := deleteAllAttr(srvc, encID) + if err != nil { + return err } - return errors.New("check witness failed, " + hex.EncodeToString(key)) + //set flag to revoke + utils.PutBytes(srvc, encID, []byte{flag_revoke}) + return nil } diff --git a/smartcontract/service/native/utils/linked_list.go b/smartcontract/service/native/utils/linked_list.go index eb812f689a..6404e9397b 100644 --- a/smartcontract/service/native/utils/linked_list.go +++ b/smartcontract/service/native/utils/linked_list.go @@ -18,10 +18,8 @@ package utils import ( - "bytes" - "fmt" - "github.com/ontio/ontology/common/serialization" + "github.com/ontio/ontology/common" cstates "github.com/ontio/ontology/core/states" "github.com/ontio/ontology/errors" "github.com/ontio/ontology/smartcontract/service/native" @@ -47,37 +45,31 @@ func (this *LinkedlistNode) GetPayload() []byte { func makeLinkedlistNode(next []byte, prev []byte, payload []byte) ([]byte, error) { node := &LinkedlistNode{next: next, prev: prev, payload: payload} - node_bytes, err := node.Serialize() + node_bytes, err := node.Serialization() if err != nil { return nil, err } return node_bytes, nil } -func (this *LinkedlistNode) Serialize() ([]byte, error) { - bf := new(bytes.Buffer) - if err := serialization.WriteVarBytes(bf, this.next); err != nil { - return nil, errors.NewDetailErr(err, errors.ErrNoCode, "[linked list] serialize next error!") - } - if err := serialization.WriteVarBytes(bf, this.prev); err != nil { - return nil, errors.NewDetailErr(err, errors.ErrNoCode, "[linked list] serialize prev error!") - } - if err := serialization.WriteVarBytes(bf, this.payload); err != nil { - return nil, errors.NewDetailErr(err, errors.ErrNoCode, "[linked list] serialize payload error!") - } - return bf.Bytes(), nil +func (this *LinkedlistNode) Serialization() ([]byte, error) { + sink := common.NewZeroCopySink(nil) + sink.WriteVarBytes(this.next) + sink.WriteVarBytes(this.prev) + sink.WriteVarBytes(this.payload) + return sink.Bytes(), nil } -func (this *LinkedlistNode) Deserialize(r []byte) error { - bf := bytes.NewReader(r) - next, err := serialization.ReadVarBytes(bf) +func (this *LinkedlistNode) Deserialization(r []byte) error { + source := common.NewZeroCopySource(r) + next, err := DecodeVarBytes(source) if err != nil { return errors.NewDetailErr(err, errors.ErrNoCode, "[linked list] deserialize next error!") } - prev, err := serialization.ReadVarBytes(bf) + prev, err := DecodeVarBytes(source) if err != nil { return errors.NewDetailErr(err, errors.ErrNoCode, "[linked list] deserialize prev error!") } - payload, err := serialization.ReadVarBytes(bf) + payload, err := DecodeVarBytes(source) if err != nil { return errors.NewDetailErr(err, errors.ErrNoCode, "[linked list] deserialize payload error!") } @@ -118,7 +110,7 @@ func getListNode(native *native.NativeService, index []byte, item []byte) (*Link if len(rawNode) == 0 { return nil, nil } - err = node.Deserialize(rawNode) + err = node.Deserialization(rawNode) if err != nil { //log.Tracef("[index: %s, item: %s] error %s", hex.EncodeToString(index), hex.EncodeToString(item), err) return nil, err @@ -271,6 +263,7 @@ func LinkedlistGetHead(native *native.NativeService, index []byte) ([]byte, erro } return head, nil } + func LinkedlistGetNumOfItems(native *native.NativeService, index []byte) (int, error) { n := 0 head, err := getListHead(native, index) @@ -288,3 +281,21 @@ func LinkedlistGetNumOfItems(native *native.NativeService, index []byte) (int, e } return n, nil } + +func LinkedlistDeleteAll(native *native.NativeService, index []byte) error { + head, err := getListHead(native, index) + if err != nil { + return err + } + q := head + for q != nil { + qnode, err := getListNode(native, index, q) + if err != nil { + return err + } + native.CacheDB.Delete(append(index, q...)) + q = qnode.next + } + native.CacheDB.Delete(index) + return nil +} diff --git a/smartcontract/service/native/utils/serialization.go b/smartcontract/service/native/utils/serialization.go index c69ac3db90..bfcc2feb4a 100644 --- a/smartcontract/service/native/utils/serialization.go +++ b/smartcontract/service/native/utils/serialization.go @@ -24,43 +24,8 @@ import ( "math/big" "github.com/ontio/ontology/common" - "github.com/ontio/ontology/common/serialization" ) -func WriteVarUint(w io.Writer, value uint64) error { - if err := serialization.WriteVarBytes(w, common.BigIntToNeoBytes(big.NewInt(int64(value)))); err != nil { - return fmt.Errorf("serialize value error:%v", err) - } - return nil -} - -func ReadVarUint(r io.Reader) (uint64, error) { - value, err := serialization.ReadVarBytes(r) - if err != nil { - return 0, fmt.Errorf("deserialize value error:%v", err) - } - v := common.BigIntFromNeoBytes(value) - if v.Cmp(big.NewInt(0)) < 0 { - return 0, fmt.Errorf("%s", "value should not be a negative number.") - } - return v.Uint64(), nil -} - -func WriteAddress(w io.Writer, address common.Address) error { - if err := serialization.WriteVarBytes(w, address[:]); err != nil { - return fmt.Errorf("serialize value error:%v", err) - } - return nil -} - -func ReadAddress(r io.Reader) (common.Address, error) { - from, err := serialization.ReadVarBytes(r) - if err != nil { - return common.Address{}, fmt.Errorf("[State] deserialize from error:%v", err) - } - return common.AddressParseFromBytes(from) -} - func EncodeAddress(sink *common.ZeroCopySink, addr common.Address) (size uint64) { return sink.WriteVarBytes(addr[:]) } @@ -95,3 +60,49 @@ func DecodeAddress(source *common.ZeroCopySource) (common.Address, error) { return common.AddressParseFromBytes(from) } +func DecodeVarBytes(source *common.ZeroCopySource) ([]byte, error) { + data, _, irregular, eof := source.NextVarBytes() + if eof { + return nil, io.ErrUnexpectedEOF + } + if irregular { + return nil, common.ErrIrregularData + } + + return data, nil +} +func DecodeUint64(source *common.ZeroCopySource) (uint64, error) { + data, eof := source.NextUint64() + if eof { + return 0, io.ErrUnexpectedEOF + } + return data, nil +} +func DecodeUint32(source *common.ZeroCopySource) (uint32, error) { + data, eof := source.NextUint32() + if eof { + return 0, io.ErrUnexpectedEOF + } + return data, nil +} +func DecodeBool(source *common.ZeroCopySource) (bool, error) { + data, irregular, eof := source.NextBool() + if eof { + return false, io.ErrUnexpectedEOF + } + if irregular { + return false, common.ErrIrregularData + } + return data, nil +} +func DecodeString(source *common.ZeroCopySource) (string, error) { + data, _, irregular, eof := source.NextString() + if eof { + return "", io.ErrUnexpectedEOF + } + if irregular { + return "", common.ErrIrregularData + } + + return data, nil +} diff --git a/smartcontract/service/native/utils/store.go b/smartcontract/service/native/utils/store.go index f3aa44ccf8..db26dd261d 100644 --- a/smartcontract/service/native/utils/store.go +++ b/smartcontract/service/native/utils/store.go @@ -20,6 +20,7 @@ package utils import ( "bytes" + "github.com/ontio/ontology/common" "github.com/ontio/ontology/common/serialization" cstates "github.com/ontio/ontology/core/states" "github.com/ontio/ontology/errors" @@ -35,7 +36,7 @@ func GetStorageItem(native *native.NativeService, key []byte) (*cstates.StorageI return nil, nil } item := new(cstates.StorageItem) - err = item.Deserialize(bytes.NewBuffer(store)) + err = item.Deserialization(common.NewZeroCopySource(store)) if err != nil { return nil, errors.NewDetailErr(err, errors.ErrNoCode, "[GetStorageItem] instance doesn't StorageItem!") } @@ -73,9 +74,9 @@ func GetStorageUInt32(native *native.NativeService, key []byte) (uint32, error) } func GenUInt64StorageItem(value uint64) *cstates.StorageItem { - bf := new(bytes.Buffer) - serialization.WriteUint64(bf, value) - return &cstates.StorageItem{Value: bf.Bytes()} + sink := common.NewZeroCopySink(nil) + sink.WriteUint64(value) + return &cstates.StorageItem{Value: sink.Bytes()} } func GenUInt32StorageItem(value uint32) *cstates.StorageItem { diff --git a/smartcontract/service/wasmvm/runtime.go b/smartcontract/service/wasmvm/runtime.go index 818014ca48..a289fdf545 100644 --- a/smartcontract/service/wasmvm/runtime.go +++ b/smartcontract/service/wasmvm/runtime.go @@ -27,7 +27,6 @@ import ( "github.com/go-interpreter/wagon/wasm" "github.com/ontio/ontology/common" "github.com/ontio/ontology/common/log" - "github.com/ontio/ontology/common/serialization" "github.com/ontio/ontology/core/payload" "github.com/ontio/ontology/core/types" "github.com/ontio/ontology/errors" @@ -38,6 +37,7 @@ import ( "github.com/ontio/ontology/smartcontract/states" "github.com/ontio/ontology/vm/crossvm_codec" neotypes "github.com/ontio/ontology/vm/neovm/types" + "io" ) type ContractType byte @@ -259,20 +259,25 @@ func CallContract(proc *exec.Process, contractAddr uint32, inputPtr uint32, inpu switch contracttype { case NATIVE_CONTRACT: - bf := bytes.NewBuffer(inputs) - ver, err := serialization.ReadByte(bf) - if err != nil { - panic(err) + source := common.NewZeroCopySource(inputs) + ver, eof := source.NextByte() + if eof { + panic(io.ErrUnexpectedEOF) } - - method, err := serialization.ReadString(bf) - if err != nil { - panic(err) + method, _, irregular, eof := source.NextString() + if irregular { + panic(common.ErrIrregularData) + } + if eof { + panic(io.ErrUnexpectedEOF) } - args, err := serialization.ReadVarBytes(bf) - if err != nil { - panic(err) + args, _, irregular, eof := source.NextVarBytes() + if irregular { + panic(common.ErrIrregularData) + } + if eof { + panic(io.ErrUnexpectedEOF) } contract := states.ContractInvokeParam{ @@ -302,10 +307,9 @@ func CallContract(proc *exec.Process, contractAddr uint32, inputPtr uint32, inpu case WASMVM_CONTRACT: conParam := states.WasmContractParam{Address: contractAddress, Args: inputs} - sink := common.NewZeroCopySink(nil) - conParam.Serialization(sink) + param := common.SerializeToBytes(&conParam) - newservice, err := self.Service.ContextRef.NewExecuteEngine(sink.Bytes(), types.InvokeWasm) + newservice, err := self.Service.ContextRef.NewExecuteEngine(param, types.InvokeWasm) if err != nil { panic(err) } diff --git a/smartcontract/states/contract.go b/smartcontract/states/contract.go index 87acd7e5ad..bec34b03f4 100644 --- a/smartcontract/states/contract.go +++ b/smartcontract/states/contract.go @@ -22,8 +22,6 @@ import ( "io" "github.com/ontio/ontology/common" - "github.com/ontio/ontology/common/serialization" - "github.com/ontio/ontology/errors" "github.com/ontio/ontology/smartcontract/event" ) @@ -39,23 +37,6 @@ type ContractInvokeParam struct { Args []byte } -// Serialize contract -func (this *ContractInvokeParam) Serialize(w io.Writer) error { - if err := serialization.WriteByte(w, this.Version); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "[ContractInvokeParam] Version serialize error!") - } - if err := this.Address.Serialize(w); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "[ContractInvokeParam] Address serialize error!") - } - if err := serialization.WriteVarBytes(w, []byte(this.Method)); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "[ContractInvokeParam] Method serialize error!") - } - if err := serialization.WriteVarBytes(w, this.Args); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "[ContractInvokeParam] Args serialize error!") - } - return nil -} - func (this *ContractInvokeParam) Serialization(sink *common.ZeroCopySink) { sink.WriteByte(this.Version) sink.WriteAddress(this.Address) @@ -63,31 +44,6 @@ func (this *ContractInvokeParam) Serialization(sink *common.ZeroCopySink) { sink.WriteVarBytes([]byte(this.Args)) } -// Deserialize contract -func (this *ContractInvokeParam) Deserialize(r io.Reader) error { - var err error - this.Version, err = serialization.ReadByte(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "[ContractInvokeParam] Version deserialize error!") - } - - if err := this.Address.Deserialize(r); err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "[ContractInvokeParam] Address deserialize error!") - } - - method, err := serialization.ReadVarBytes(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "[ContractInvokeParam] Method deserialize error!") - } - this.Method = string(method) - - this.Args, err = serialization.ReadVarBytes(r) - if err != nil { - return errors.NewDetailErr(err, errors.ErrNoCode, "[ContractInvokeParam] Args deserialize error!") - } - return nil -} - // `ContractInvokeParam.Args` has reference of `source` func (this *ContractInvokeParam) Deserialization(source *common.ZeroCopySource) error { var irregular, eof bool diff --git a/smartcontract/states/contract_test.go b/smartcontract/states/contract_test.go index 041abaac0e..d47f60dc28 100644 --- a/smartcontract/states/contract_test.go +++ b/smartcontract/states/contract_test.go @@ -18,7 +18,6 @@ package states import ( - "bytes" "testing" "github.com/ontio/ontology/common" @@ -33,13 +32,12 @@ func TestContract_Serialize_Deserialize(t *testing.T) { Method: "init", Args: []byte{2}, } - bf := new(bytes.Buffer) - if err := c.Serialize(bf); err != nil { - t.Fatalf("ContractInvokeParam serialize error: %v", err) - } + sink := common.NewZeroCopySink(nil) + c.Serialization(sink) v := new(ContractInvokeParam) - if err := v.Deserialize(bf); err != nil { + source := common.NewZeroCopySource(sink.Bytes()) + if err := v.Deserialization(source); err != nil { t.Fatalf("ContractInvokeParam deserialize error: %v", err) } } diff --git a/smartcontract/test/panic_test.go b/smartcontract/test/panic_test.go index 3d6ee6e283..4e95b2ee31 100644 --- a/smartcontract/test/panic_test.go +++ b/smartcontract/test/panic_test.go @@ -25,8 +25,8 @@ import ( "os" "testing" + "github.com/ontio/ontology/common" "github.com/ontio/ontology/common/log" - "github.com/ontio/ontology/common/serialization" "github.com/ontio/ontology/core/types" . "github.com/ontio/ontology/smartcontract" neovm2 "github.com/ontio/ontology/smartcontract/service/neovm" @@ -110,19 +110,19 @@ func TestOpReadMemAttack(t *testing.T) { bf := new(bytes.Buffer) builder := neovm.NewParamsBuilder(bf) builder.Emit(neovm.SYSCALL) - bs := bytes.NewBuffer(builder.ToArray()) + sink := common.NewZeroCopySink(builder.ToArray()) builder.EmitPushByteArray([]byte(neovm2.NATIVE_INVOKE_NAME)) l := 0X7fffffc7 - 1 - serialization.WriteVarUint(bs, uint64(l)) + sink.WriteVarUint(uint64(l)) b := make([]byte, 4) - bs.Write(b) + sink.WriteBytes(b) sc := SmartContract{ Config: config, Gas: 100000, CacheDB: nil, } - engine, _ := sc.NewExecuteEngine(bs.Bytes(), types.InvokeNeo) + engine, _ := sc.NewExecuteEngine(sink.Bytes(), types.InvokeNeo) _, err := engine.Invoke() assert.NotNil(t, err) diff --git a/txnpool/proc/txnpool_server.go b/txnpool/proc/txnpool_server.go index ae5edfaf71..e9cc6755b2 100644 --- a/txnpool/proc/txnpool_server.go +++ b/txnpool/proc/txnpool_server.go @@ -21,7 +21,6 @@ package proc import ( - "bytes" "encoding/hex" "fmt" "github.com/ontio/ontology-eventbus/actor" @@ -117,7 +116,7 @@ func getGlobalGasPrice() (uint64, error) { return 0, fmt.Errorf("decode result error %v", err) } - err = queriedParams.Deserialize(bytes.NewBuffer([]byte(data))) + err = queriedParams.Deserialization(common.NewZeroCopySource([]byte(data))) if err != nil { return 0, fmt.Errorf("deserialize result error %v", err) } diff --git a/validator/db/store.go b/validator/db/store.go index aa020f16e9..80c5f10b54 100644 --- a/validator/db/store.go +++ b/validator/db/store.go @@ -24,7 +24,6 @@ import ( "sync" "github.com/ontio/ontology/common" - "github.com/ontio/ontology/common/serialization" storcomm "github.com/ontio/ontology/core/store/common" leveldb "github.com/ontio/ontology/core/store/leveldbstore" "github.com/ontio/ontology/core/types" @@ -162,11 +161,10 @@ func (self *Store) saveTransaction(tx *types.Transaction, height uint32) error { // generate key with DATA_TRANSACTION prefix key := GenDataTransactionKey(tx.Hash()) defer keyPool.Put(key) - value := valuePool.Get() - defer valuePool.Put(value) - serialization.WriteUint32(value, height) - tx.Serialize(value) + value := common.NewZeroCopySink(nil) + value.WriteUint32(height) + tx.Serialization(value) // put value self.db.BatchPut(key.Bytes(), value.Bytes()) diff --git a/vm/neovm/params_builder.go b/vm/neovm/params_builder.go index 763cc2af6f..8db72a1223 100644 --- a/vm/neovm/params_builder.go +++ b/vm/neovm/params_builder.go @@ -47,15 +47,16 @@ func (p *ParamsBuilder) EmitPushBool(data bool) { } func (p *ParamsBuilder) EmitPushInteger(data *big.Int) { - if data.Int64() == -1 { + if data.Cmp(big.NewInt(int64(-1))) == 0 { p.Emit(PUSHM1) return } - if data.Int64() == 0 { + if data.Sign() == 0 { p.Emit(PUSH0) return } - if data.Int64() > 0 && data.Int64() < 16 { + + if data.Cmp(big.NewInt(int64(0))) == 1 && data.Cmp(big.NewInt(int64(16))) == -1 { p.Emit(OpCode((int(PUSH1) - 1 + int(data.Int64())))) return } diff --git a/wasmtest/common/common.go b/wasmtest/common/common.go index 60390975e0..0411dcd5bb 100644 --- a/wasmtest/common/common.go +++ b/wasmtest/common/common.go @@ -107,10 +107,7 @@ func GenWasmTransaction(testCase TestCase, contract common.Address, testConext * contextParam := buildTestConext(testConext) contract.Args = append(contract.Args, contextParam...) - sink := common.NewZeroCopySink(nil) - contract.Serialization(sink) - - tx.Payload.(*payload.InvokeCode).Code = sink.Bytes() + tx.Payload.(*payload.InvokeCode).Code = common.SerializeToBytes(contract) } imt, err := tx.IntoImmutable()