diff --git a/cmd/contract_cmd.go b/cmd/contract_cmd.go index 0cb8b02681..576b1516e2 100644 --- a/cmd/contract_cmd.go +++ b/cmd/contract_cmd.go @@ -125,7 +125,12 @@ func deployContract(ctx *cli.Context) error { return nil } - vmtype := ctx.Uint(utils.GetFlagName(utils.ContractVmTypeFlag)) + vmtypeFlag := ctx.Uint(utils.GetFlagName(utils.ContractVmTypeFlag)) + vmtype, err := payload.VmTypeFromByte(byte(vmtypeFlag)) + if err != nil { + return err + } + codeFile := ctx.String(utils.GetFlagName(utils.ContractCodeFileFlag)) if "" == codeFile { return fmt.Errorf("please specific code file") @@ -154,7 +159,7 @@ func deployContract(ctx *cli.Context) error { cversion := fmt.Sprintf("%s", version) if ctx.IsSet(utils.GetFlagName(utils.ContractPrepareDeployFlag)) { - preResult, err := utils.PrepareDeployContract(byte(vmtype), code, name, cversion, author, email, desc) + preResult, err := utils.PrepareDeployContract(vmtype, code, name, cversion, author, email, desc) if err != nil { return fmt.Errorf("PrepareDeployContract error:%s", err) } @@ -171,7 +176,7 @@ func deployContract(ctx *cli.Context) error { return fmt.Errorf("get signer account error:%s", err) } - txHash, err := utils.DeployContract(gasPrice, gasLimit, signer, byte(vmtype), code, name, cversion, author, email, desc) + txHash, err := utils.DeployContract(gasPrice, gasLimit, signer, vmtype, code, name, cversion, author, email, desc) if err != nil { return fmt.Errorf("DeployContract error:%s", err) } @@ -289,9 +294,10 @@ func invokeContract(ctx *cli.Context) error { if err != nil { return fmt.Errorf("invalid contract address error:%s", err) } - vmtype := ctx.Uint(utils.GetFlagName(utils.ContractVmTypeFlag)) - if byte(vmtype) != payload.NEOVM_TYPE && byte(vmtype) != payload.WASMVM_TYPE { - return fmt.Errorf("invalid vmtype") + vmtypeFlag := ctx.Uint(utils.GetFlagName(utils.ContractVmTypeFlag)) + vmtype, err := payload.VmTypeFromByte(byte(vmtypeFlag)) + if err != nil { + return err } paramsStr := ctx.String(utils.GetFlagName(utils.ContractParamsFlag)) params, err := utils.ParseParams(paramsStr) @@ -304,11 +310,11 @@ func invokeContract(ctx *cli.Context) error { if ctx.IsSet(utils.GetFlagName(utils.ContractPrepareInvokeFlag)) { var preResult *states.PreExecResult - if byte(vmtype) == payload.NEOVM_TYPE { + if vmtype == payload.NEOVM_TYPE { preResult, err = utils.PrepareInvokeNeoVMContract(contractAddr, params) } - if byte(vmtype) == payload.WASMVM_TYPE { + if vmtype == payload.WASMVM_TYPE { preResult, err = utils.PrepareInvokeWasmVMContract(contractAddr, params) } @@ -327,7 +333,7 @@ func invokeContract(ctx *cli.Context) error { PrintInfoMsg(" Return:%s (raw value)", preResult.Result) return nil } - values, err := utils.ParseReturnValue(preResult.Result, rawReturnTypes, byte(vmtype)) + values, err := utils.ParseReturnValue(preResult.Result, rawReturnTypes, vmtype) if err != nil { return fmt.Errorf("parseReturnValue values:%+v types:%s error:%s", values, rawReturnTypes, err) } @@ -356,13 +362,13 @@ func invokeContract(ctx *cli.Context) error { } var txHash string - if byte(vmtype) == payload.NEOVM_TYPE { + if vmtype == payload.NEOVM_TYPE { txHash, err = utils.InvokeNeoVMContract(gasPrice, gasLimit, signer, contractAddr, params) if err != nil { return fmt.Errorf("invoke NeoVM contract error:%s", err) } } - if byte(vmtype) == payload.WASMVM_TYPE { + if vmtype == payload.WASMVM_TYPE { txHash, err = utils.InvokeWasmVMContract(gasPrice, gasLimit, signer, contractAddr, params) if err != nil { return fmt.Errorf("invoke NeoVM contract error:%s", err) diff --git a/cmd/utils/ont.go b/cmd/utils/ont.go index bc80d6b5bf..72f076b86e 100644 --- a/cmd/utils/ont.go +++ b/cmd/utils/ont.go @@ -593,7 +593,7 @@ func DeployContract( gasPrice, gasLimit uint64, signer *account.Account, - vmtype byte, + vmtype payload.VmType, code, cname, cversion, @@ -623,7 +623,7 @@ func DeployContract( } func PrepareDeployContract( - needStorage byte, + vmtype payload.VmType, code, cname, cversion, @@ -634,7 +634,7 @@ func PrepareDeployContract( if err != nil { return nil, fmt.Errorf("hex.DecodeString error:%s", err) } - mutable := NewDeployCodeTransaction(0, 0, c, needStorage, cname, cversion, cauthor, cemail, cdesc) + mutable := NewDeployCodeTransaction(0, 0, c, vmtype, cname, cversion, cauthor, cemail, cdesc) tx, _ := mutable.IntoImmutable() var buffer bytes.Buffer err = tx.Serialize(&buffer) @@ -775,18 +775,18 @@ func PrepareInvokeNativeContract( } //NewDeployCodeTransaction return a smart contract deploy transaction instance -func NewDeployCodeTransaction(gasPrice, gasLimit uint64, code []byte, vmType byte, +func NewDeployCodeTransaction(gasPrice, gasLimit uint64, code []byte, vmType payload.VmType, cname, cversion, cauthor, cemail, cdesc string) *types.MutableTransaction { deployPayload := &payload.DeployCode{ Code: code, - VmType: vmType, Name: cname, Version: cversion, Author: cauthor, Email: cemail, Description: cdesc, } + deployPayload.SetVmType(vmType) tx := &types.MutableTransaction{ Version: VERSION_TRANSACTION, TxType: types.Deploy, diff --git a/cmd/utils/params.go b/cmd/utils/params.go index d6857e47ba..e55a107ee2 100644 --- a/cmd/utils/params.go +++ b/cmd/utils/params.go @@ -209,7 +209,7 @@ func parseRawParamValue(pType string, pValue string) (interface{}, error) { //Return type can be: bytearray, string, int, bool. //Types can be split with "," each other, such as int,string,bool //Type array can be express with "[]", such [int,string], param array can be nested, such as [int,[int,bool]] -func ParseReturnValue(rawValue interface{}, rawReturnTypeStr string, vmtype byte) ([]interface{}, error) { +func ParseReturnValue(rawValue interface{}, rawReturnTypeStr string, vmtype payload.VmType) ([]interface{}, error) { returnTypes, _, err := parseRawParamsString(rawReturnTypeStr) if err != nil { return nil, fmt.Errorf("parse raw return types:%s error:%s", rawReturnTypeStr, err) diff --git a/core/payload/deploy_code.go b/core/payload/deploy_code.go index 1c8f3c7663..ee08591e79 100644 --- a/core/payload/deploy_code.go +++ b/core/payload/deploy_code.go @@ -28,16 +28,27 @@ import ( "github.com/ontio/ontology/errors" ) +type VmType byte + const ( - NEOVM_TYPE byte = 1 - WASMVM_TYPE byte = 3 + NEOVM_TYPE VmType = 1 + WASMVM_TYPE VmType = 3 ) +func VmTypeFromByte(ty byte) (VmType, error) { + switch ty { + case 1, 3: + return VmType(ty), nil + default: + return VmType(0), fmt.Errorf("can not convert byte:%d to vm type", ty) + } +} + // DeployCode is an implementation of transaction payload for deploy smartcontract type DeployCode struct { Code []byte - //modify for define contract type - VmType byte + //0, 1 means NEOVM_TYPE, 3 means WASMVM_TYPE + vmFlags byte Name string Version string Author string @@ -54,6 +65,30 @@ func (dc *DeployCode) Address() common.Address { return dc.address } +func (dc *DeployCode) SetVmType(ty VmType) { + dc.vmFlags = byte(ty) +} + +func checkVmFlags(vmFlags byte) error { + switch vmFlags { + case 0, 1, 3: + return nil + default: + return fmt.Errorf("invalid vm flags: %d", vmFlags) + } +} + +func (dc *DeployCode) VmType() VmType { + switch dc.vmFlags { + case 0, 1: + return NEOVM_TYPE + case 3: + return WASMVM_TYPE + default: + panic("unreachable") + } +} + func (dc *DeployCode) Serialize(w io.Writer) error { var err error @@ -62,7 +97,7 @@ func (dc *DeployCode) Serialize(w io.Writer) error { return fmt.Errorf("DeployCode Code Serialize failed: %s", err) } - err = serialization.WriteByte(w, dc.VmType) + err = serialization.WriteByte(w, dc.vmFlags) if err != nil { return fmt.Errorf("DeployCode NeedStorage Serialize failed: %s", err) } @@ -102,7 +137,7 @@ func (dc *DeployCode) Deserialize(r io.Reader) error { } dc.Code = code - dc.VmType, err = serialization.ReadByte(r) + dc.vmFlags, err = serialization.ReadByte(r) if err != nil { return fmt.Errorf("DeployCode NeedStorage Deserialize failed: %s", err) } @@ -148,7 +183,7 @@ func (dc *DeployCode) ToArray() []byte { func (dc *DeployCode) Serialization(sink *common.ZeroCopySink) error { sink.WriteVarBytes(dc.Code) - sink.WriteByte(dc.VmType) + sink.WriteByte(dc.vmFlags) sink.WriteString(dc.Name) sink.WriteString(dc.Version) sink.WriteString(dc.Author) @@ -166,8 +201,7 @@ func (dc *DeployCode) Deserialization(source *common.ZeroCopySource) error { return common.ErrIrregularData } - dc.VmType, eof = source.NextByte() - + dc.vmFlags, eof = source.NextByte() dc.Name, _, irregular, eof = source.NextString() if irregular { return common.ErrIrregularData @@ -208,38 +242,39 @@ func (dc *DeployCode) Deserialization(source *common.ZeroCopySource) error { const maxWasmCodeSize = 512 * 1024 func validateDeployCode(dep *DeployCode) error { - if dep.VmType == WASMVM_TYPE { + err := checkVmFlags(dep.vmFlags) + if err != nil { + return err + } + + if dep.VmType() == WASMVM_TYPE { if len(dep.Code) > maxWasmCodeSize { - return errors.NewErr("[Contract] Code too long!") + return errors.NewErr("[contract] Code too long!") } } else { if len(dep.Code) > 1024*1024 { - return errors.NewErr("[Contract] Code too long!") + return errors.NewErr("[contract] Code too long!") } } if len(dep.Name) > 252 { - return errors.NewErr("[Contract] name too long!") + return errors.NewErr("[contract] name too long!") } if len(dep.Version) > 252 { - return errors.NewErr("[Contract] version too long!") + return errors.NewErr("[contract] version too long!") } if len(dep.Author) > 252 { - return errors.NewErr("[author] version too long!") + return errors.NewErr("[contract] version too long!") } if len(dep.Email) > 252 { - return errors.NewErr("[author] emailPtr too long!") + return errors.NewErr("[contract] email too long!") } if len(dep.Description) > 65536 { - return errors.NewErr("[descPtr] emailPtr too long!") - } - - if dep.VmType != WASMVM_TYPE && dep.VmType != NEOVM_TYPE { - return errors.NewErr("[descPtr] VmType invalid!") + return errors.NewErr("[contract] description too long!") } return nil @@ -252,10 +287,13 @@ func CreateDeployCode(code []byte, author []byte, email []byte, desc []byte) (*DeployCode, error) { + if vmType > 255 { + return nil, fmt.Errorf("wrong vm flags: %d", vmType) + } contract := &DeployCode{ Code: code, - VmType: byte(vmType), + vmFlags: byte(vmType), Name: string(name), Version: string(version), Author: string(author), diff --git a/core/store/ledgerstore/tx_handler.go b/core/store/ledgerstore/tx_handler.go index c3f7f47ea0..d119fc981b 100644 --- a/core/store/ledgerstore/tx_handler.go +++ b/core/store/ledgerstore/tx_handler.go @@ -56,7 +56,7 @@ func (self *StateStore) HandleDeployTransaction(store store.LedgerStore, overlay ) _, err = wasmvm.ReadWasmModule(deploy, true) - if deploy.VmType == payload.WASMVM_TYPE && err != nil { + if deploy.VmType() == payload.WASMVM_TYPE && err != nil { return err } diff --git a/core/utils/transaction_builder.go b/core/utils/transaction_builder.go index 5f1fe13c93..82788dcb11 100644 --- a/core/utils/transaction_builder.go +++ b/core/utils/transaction_builder.go @@ -36,21 +36,21 @@ import ( const NATIVE_INVOKE_NAME = "Ontology.Native.Invoke" // copy from smartcontract/service/neovm/config.go to avoid cycle dependences // NewDeployTransaction returns a deploy Transaction -func NewDeployTransaction(code []byte, name, version, author, email, desp string, vmType byte) *types.MutableTransaction { +func NewDeployTransaction(code []byte, name, version, author, email, desp string, vmType payload.VmType) *types.MutableTransaction { //TODO: check arguments - DeployCodePayload := &payload.DeployCode{ + depCode := &payload.DeployCode{ Code: code, - VmType: vmType, Name: name, Version: version, Author: author, Email: email, Description: desp, } + depCode.SetVmType(vmType) return &types.MutableTransaction{ TxType: types.Deploy, - Payload: DeployCodePayload, + Payload: depCode, } } diff --git a/core/validation/transaction_validator.go b/core/validation/transaction_validator.go index f9b482e948..e79ddae310 100644 --- a/core/validation/transaction_validator.go +++ b/core/validation/transaction_validator.go @@ -117,7 +117,7 @@ func checkTransactionPayload(tx *types.Transaction) error { case *payload.DeployCode: deploy := tx.Payload.(*payload.DeployCode) _, err := wasmvm.ReadWasmModule(deploy, true) - if deploy.VmType == payload.WASMVM_TYPE && err != nil { + if deploy.VmType() == payload.WASMVM_TYPE && err != nil { return err } return nil diff --git a/http/base/common/payload_to_hex.go b/http/base/common/payload_to_hex.go index f538d1c54d..bce849b41a 100644 --- a/http/base/common/payload_to_hex.go +++ b/http/base/common/payload_to_hex.go @@ -100,7 +100,7 @@ func TransPayloadToHex(p types.Payload) PayloadInfo { case *payload.DeployCode: obj := new(DeployCodeInfo) obj.Code = common.ToHexString(object.Code) - obj.VmType = object.VmType + obj.VmType = byte(object.VmType()) obj.Name = object.Name obj.CodeVersion = object.Version obj.Author = object.Author diff --git a/smartcontract/service/neovm/wasmvm.go b/smartcontract/service/neovm/wasmvm.go index 0f6960019f..7b3a22f999 100644 --- a/smartcontract/service/neovm/wasmvm.go +++ b/smartcontract/service/neovm/wasmvm.go @@ -48,7 +48,7 @@ func WASMInvoke(service *NeoVmService, engine *vm.Executor) error { if dp == nil { return fmt.Errorf("wasm contract does not exist") } - if dp.VmType != payload.WASMVM_TYPE { + if dp.VmType() != payload.WASMVM_TYPE { return fmt.Errorf("not a wasm contract") } diff --git a/smartcontract/service/wasmvm/contract.go b/smartcontract/service/wasmvm/contract.go index 6d4343afaf..25f043a019 100644 --- a/smartcontract/service/wasmvm/contract.go +++ b/smartcontract/service/wasmvm/contract.go @@ -80,7 +80,7 @@ func ContractCreate(proc *exec.Process, } _, err = ReadWasmModule(dep, true) - if dep.VmType == payload.WASMVM_TYPE && err != nil { + if dep.VmType() == payload.WASMVM_TYPE && err != nil { panic(err) } @@ -156,7 +156,7 @@ func ContractMigrate(proc *exec.Process, } _, err = ReadWasmModule(dep, true) - if dep.VmType == payload.WASMVM_TYPE && err != nil { + if dep.VmType() == payload.WASMVM_TYPE && err != nil { panic(err) } diff --git a/smartcontract/service/wasmvm/runtime.go b/smartcontract/service/wasmvm/runtime.go index d8d79e3df2..818014ca48 100644 --- a/smartcontract/service/wasmvm/runtime.go +++ b/smartcontract/service/wasmvm/runtime.go @@ -690,7 +690,7 @@ func (self *Runtime) getContractType(addr common.Address) (ContractType, error) if dep == nil { return UNKOWN_CONTRACT, errors.NewErr("contract is not exist.") } - if dep.VmType == payload.WASMVM_TYPE { + if dep.VmType() == payload.WASMVM_TYPE { return WASMVM_CONTRACT, nil } diff --git a/smartcontract/service/wasmvm/utils.go b/smartcontract/service/wasmvm/utils.go index 49139c50e0..d1ac83786b 100644 --- a/smartcontract/service/wasmvm/utils.go +++ b/smartcontract/service/wasmvm/utils.go @@ -42,7 +42,7 @@ func ReadWasmMemory(proc *exec.Process, ptr uint32, len uint32) ([]byte, error) } func ReadWasmModule(dep *payload.DeployCode, verify bool) (*exec.CompiledModule, error) { - if dep.VmType == payload.NEOVM_TYPE { + if dep.VmType() == payload.NEOVM_TYPE { return nil, errors.New("only wasm contract need verify") }