diff --git a/app/abci.go b/app/abci.go index fd4a45c78e..591c850f56 100644 --- a/app/abci.go +++ b/app/abci.go @@ -121,4 +121,4 @@ func hasWirePayForMessage(tx sdk.Tx) bool { // hardcoded. todo(evan): don't hardcode the square size func (app *App) SquareSize() uint64 { return consts.MaxSquareSize -} +} \ No newline at end of file diff --git a/x/payment/client/cli/wirepayformessage.go b/x/payment/client/cli/wirepayformessage.go index f54666b5f1..5257270473 100644 --- a/x/payment/client/cli/wirepayformessage.go +++ b/x/payment/client/cli/wirepayformessage.go @@ -15,6 +15,8 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" ) +const FlagSquareSizes = "square-sizes" + func CmdWirePayForMessage() *cobra.Command { cmd := &cobra.Command{ Use: "payForMessage [hexNamespace] [hexMessage]", @@ -52,7 +54,12 @@ func CmdWirePayForMessage() *cobra.Command { } // create the MsgPayForMessage - pfmMsg, err := types.NewWirePayForMessage(namespace, message, consts.MaxSquareSize) + squareSizes, err := cmd.Flags().GetUintSlice(FlagSquareSizes) + if err != nil { + return err + } + squareSizes64 := parseSquareSizes(squareSizes) + pfmMsg, err := types.NewWirePayForMessage(namespace, message, squareSizes64...) if err != nil { return err } @@ -102,6 +109,15 @@ func CmdWirePayForMessage() *cobra.Command { } flags.AddTxFlagsToCmd(cmd) + cmd.Flags().UintSlice(FlagSquareSizes, []uint{consts.MaxSquareSize, 128, 64}, "Specify the square sizes, must be power of 2") return cmd } + +func parseSquareSizes(squareSizes []uint) []uint64 { + squareSizes64 := make([]uint64, len(squareSizes)) + for i := range squareSizes { + squareSizes64[i] = uint64(squareSizes[i]) + } + return squareSizes64 +} \ No newline at end of file diff --git a/x/payment/client/testutil/integration_test.go b/x/payment/client/testutil/integration_test.go index 7e728f7e93..93d274ca3a 100644 --- a/x/payment/client/testutil/integration_test.go +++ b/x/payment/client/testutil/integration_test.go @@ -82,6 +82,32 @@ func (s *IntegrationTestSuite) TestSubmitWirePayForMessage() { }, false, 0, &sdk.TxResponse{}, }, + { + "valid transaction list of square sizes", + []string{ + hexNS, + hexMsg, + fmt.Sprintf("--from=%s", username), + fmt.Sprintf("--%s=%s", flags.FlagBroadcastMode, flags.BroadcastBlock), + fmt.Sprintf("--%s=%s", flags.FlagFees, sdk.NewCoins(sdk.NewCoin(s.cfg.BondDenom, sdk.NewInt(2))).String()), + fmt.Sprintf("--%s=true", flags.FlagSkipConfirmation), + fmt.Sprintf("--%s=%s", paycli.FlagSquareSizes, "256,128,64"), + }, + false, 0, &sdk.TxResponse{}, + }, + { + "invalid transaction list of square sizes", + []string{ + hexNS, + hexMsg, + fmt.Sprintf("--from=%s", username), + fmt.Sprintf("--%s=%s", flags.FlagBroadcastMode, flags.BroadcastBlock), + fmt.Sprintf("--%s=%s", flags.FlagFees, sdk.NewCoins(sdk.NewCoin(s.cfg.BondDenom, sdk.NewInt(2))).String()), + fmt.Sprintf("--%s=true", flags.FlagSkipConfirmation), + fmt.Sprintf("--%s=%s", paycli.FlagSquareSizes, "256,123,64"), + }, + true, 0, &sdk.TxResponse{}, + }, } for _, tc := range testCases { diff --git a/x/payment/types/payformessage.go b/x/payment/types/payformessage.go index 10b298fd22..10e8a2a868 100644 --- a/x/payment/types/payformessage.go +++ b/x/payment/types/payformessage.go @@ -226,3 +226,12 @@ func nextPowerOf2(v uint64) uint64 { // return the next lowest power return v / 2 } + +// Check if number is power of 2 +func powerOf2(v uint64) bool { + if v & (v-1) == 0 && v != 0 { + return true + } else { + return false + } +} diff --git a/x/payment/types/payformessage_test.go b/x/payment/types/payformessage_test.go index e071c7bd8e..24dd9f49bb 100644 --- a/x/payment/types/payformessage_test.go +++ b/x/payment/types/payformessage_test.go @@ -75,6 +75,43 @@ func TestNextPowerOf2(t *testing.T) { } } +func TestPowerOf2 (t *testing.T) { + type test struct { + input uint64 + expected bool + } + tests := []test { + { + input: 1, + expected: true, + }, + { + input: 2, + expected: true, + }, + { + input: 256, + expected: true, + }, + { + input: 3, + expected: false, + }, + { + input: 79, + expected: false, + }, + { + input: 0, + expected: false, + }, + } + for _, tt := range tests { + res := powerOf2(tt.input) + assert.Equal(t, tt.expected, res) + } +} + // TestCreateCommit only shows if something changed, it doesn't actually show // the commit is being created correctly todo(evan): fix me. func TestCreateCommitment(t *testing.T) { @@ -251,6 +288,14 @@ func TestWirePayForMessage_ValidateBasic(t *testing.T) { badCommitMsg := validWirePayForMessage(t) badCommitMsg.MessageShareCommitment[0].ShareCommitment = []byte{1, 2, 3, 4} + // pfm that has invalid square size (not power of 2) + invalidSquareSizeMsg := validWirePayForMessage(t) + invalidSquareSizeMsg.MessageShareCommitment[0].K = 15 + + // pfm that has a different power of 2 square size + badSquareSizeMsg := validWirePayForMessage(t) + badSquareSizeMsg.MessageShareCommitment[0].K = 4 + tests := []test{ { name: "valid msg", @@ -286,6 +331,18 @@ func TestWirePayForMessage_ValidateBasic(t *testing.T) { expectErr: true, errStr: "invalid commit for square size", }, + { + name: "invalid square size", + msg: invalidSquareSizeMsg, + expectErr: true, + errStr: fmt.Sprintf("invalid square size, the size must be power of 2: %d", invalidSquareSizeMsg.MessageShareCommitment[0].K), + }, + { + name: "wrong but valid square size", + msg: badSquareSizeMsg, + expectErr: true, + errStr: fmt.Sprintf("invalid commit for square size %d", badSquareSizeMsg.MessageShareCommitment[0].K), + }, } for _, tt := range tests { @@ -371,7 +428,7 @@ func TestProcessMessage(t *testing.T) { func validWirePayForMessage(t *testing.T) *MsgWirePayForMessage { msg, err := NewWirePayForMessage( []byte{1, 2, 3, 4, 5, 6, 7, 8}, - bytes.Repeat([]byte{1}, 1000), + bytes.Repeat([]byte{1}, 2000), 16, 32, 64, ) if err != nil { diff --git a/x/payment/types/wirepayformessage.go b/x/payment/types/wirepayformessage.go index 5ad1a0761d..ff1eac534b 100644 --- a/x/payment/types/wirepayformessage.go +++ b/x/payment/types/wirepayformessage.go @@ -30,6 +30,9 @@ func NewWirePayForMessage(namespace, message []byte, sizes ...uint64) (*MsgWireP // generate the share commitments for i, size := range sizes { + if !powerOf2(size) { + return nil, fmt.Errorf("Invalid square size, the size must be power of 2: %d", size) + } commit, err := CreateCommitment(size, namespace, message) if err != nil { return nil, err @@ -101,6 +104,10 @@ func (msg *MsgWirePayForMessage) ValidateBasic() error { for _, commit := range msg.MessageShareCommitment { // check that each commit is valid + if !powerOf2(commit.K) { + return fmt.Errorf("invalid square size, the size must be power of 2: %d", commit.K) + } + calculatedCommit, err := CreateCommitment(commit.K, msg.GetMessageNameSpaceId(), msg.Message) if err != nil { return err @@ -205,4 +212,4 @@ func ProcessWirePayForMessage(msg *MsgWirePayForMessage, squareSize uint64) (*tm } return &coreMsg, pfm, shareCommit.Signature, nil -} +} \ No newline at end of file