diff --git a/prover/server/prover/batch_address_append_circuit.go b/prover/server/prover/batch_address_append_circuit.go index 458da3652..743486def 100644 --- a/prover/server/prover/batch_address_append_circuit.go +++ b/prover/server/prover/batch_address_append_circuit.go @@ -156,23 +156,6 @@ func (params *BatchAddressAppendParameters) CreateWitness() (*BatchAddressTreeAp return nil, fmt.Errorf("tree height cannot be 0") } - // Validate array sizes match BatchSize - if len(params.LowElementValues) != int(params.BatchSize) || - len(params.LowElementNextIndices) != int(params.BatchSize) || - len(params.LowElementNextValues) != int(params.BatchSize) || - len(params.LowElementIndices) != int(params.BatchSize) || - len(params.NewElementValues) != int(params.BatchSize) { - return nil, fmt.Errorf("array lengths must match BatchSize") - } - - // Validate proof lengths match TreeHeight - for i := 0; i < int(params.BatchSize); i++ { - if len(params.LowElementProofs[i]) != int(params.TreeHeight) || - len(params.NewElementProofs[i]) != int(params.TreeHeight) { - return nil, fmt.Errorf("proof lengths must match TreeHeight") - } - } - circuit := &BatchAddressTreeAppendCircuit{ BatchSize: params.BatchSize, TreeHeight: params.TreeHeight, @@ -211,6 +194,13 @@ func (params *BatchAddressAppendParameters) CreateWitness() (*BatchAddressTreeAp return circuit, nil } func (p *BatchAddressAppendParameters) ValidateShape() error { + if p.BatchSize == 0 { + return fmt.Errorf("batch size cannot be 0") + } + if p.TreeHeight == 0 { + return fmt.Errorf("tree height cannot be 0") + } + expectedArrayLen := int(p.BatchSize) expectedProofLen := int(p.TreeHeight) diff --git a/prover/server/prover/batch_address_append_circuit_test.go b/prover/server/prover/batch_address_append_circuit_test.go index 71a994e69..67ebd164e 100644 --- a/prover/server/prover/batch_address_append_circuit_test.go +++ b/prover/server/prover/batch_address_append_circuit_test.go @@ -85,6 +85,7 @@ func TestBatchAddressAppendCircuit(t *testing.T) { batchSize uint32 startIndex uint64 modifyParams func(*BatchAddressAppendParameters) + wantPanic bool }{ { name: "Invalid OldRoot", @@ -139,6 +140,7 @@ func TestBatchAddressAppendCircuit(t *testing.T) { modifyParams: func(p *BatchAddressAppendParameters) { p.LowElementValues = p.LowElementValues[:len(p.LowElementValues)-1] }, + wantPanic: true, }, { name: "Invalid proof length", @@ -148,6 +150,7 @@ func TestBatchAddressAppendCircuit(t *testing.T) { modifyParams: func(p *BatchAddressAppendParameters) { p.LowElementProofs[0] = p.LowElementProofs[0][:len(p.LowElementProofs[0])-1] }, + wantPanic: true, }, { name: "Empty arrays", @@ -182,6 +185,14 @@ func TestBatchAddressAppendCircuit(t *testing.T) { tc.modifyParams(params) + if tc.wantPanic { + assert.Panics(func() { + witness, _ := params.CreateWitness() + test.IsSolved(&circuit, witness, ecc.BN254.ScalarField()) + }) + return + } + witness, err := params.CreateWitness() if err != nil { return