diff --git a/pkg/proof/proof.go b/pkg/proof/proof.go index 56d9726b81..d6ac95f02a 100644 --- a/pkg/proof/proof.go +++ b/pkg/proof/proof.go @@ -121,9 +121,37 @@ func NewShareInclusionProofFromEDS( rows[i-startRow] = shares } - var shareProofs []*NMTProof //nolint:prealloc + shareProofs, rawShares, err := CreateShareToRowRootProofs(squareSize, rows, rowRoots, startLeaf, endLeaf) + if err != nil { + return ShareProof{}, err + } + return ShareProof{ + RowProof: &RowProof{ + RowRoots: rowRoots, + Proofs: rowProofs, + StartRow: uint32(startRow), + EndRow: uint32(endRow), + }, + Data: rawShares, + ShareProofs: shareProofs, + NamespaceId: namespace.ID, + NamespaceVersion: uint32(namespace.Version), + }, nil +} + +func safeConvertUint64ToInt(val uint64) (int, error) { + if val > math.MaxInt { + return 0, fmt.Errorf("value %d is too large to convert to int", val) + } + return int(val), nil +} + +// CreateShareToRowRootProofs takes a set of shares and their corresponding row roots, and generates +// an NMT inclusion proof of a set of shares, defined by startLeaf and endLeaf, to their corresponding row roots. +func CreateShareToRowRootProofs(squareSize int, rowShares [][]shares.Share, rowRoots [][]byte, startLeaf, endLeaf int) ([]*NMTProof, [][]byte, error) { + shareProofs := make([]*NMTProof, 0, len(rowRoots)) var rawShares [][]byte - for i, row := range rows { + for i, row := range rowShares { // create an nmt to generate a proof. // we have to re-create the tree as the eds one is not accessible. tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(squareSize), uint(i)) @@ -132,17 +160,17 @@ func NewShareInclusionProofFromEDS( share.ToBytes(), ) if err != nil { - return ShareProof{}, err + return nil, nil, err } } // make sure that the generated root is the same as the eds row root. root, err := tree.Root() if err != nil { - return ShareProof{}, err + return nil, nil, err } if !bytes.Equal(rowRoots[i], root) { - return ShareProof{}, errors.New("eds row root is different than tree root") + return nil, nil, errors.New("eds row root is different than tree root") } startLeafPos := startLeaf @@ -153,14 +181,14 @@ func NewShareInclusionProofFromEDS( startLeafPos = 0 } // if this is not the last row, then select for the rest of the row - if i != (len(rows) - 1) { + if i != (len(rowShares) - 1) { endLeafPos = squareSize - 1 } rawShares = append(rawShares, shares.ToBytes(row[startLeafPos:endLeafPos+1])...) proof, err := tree.ProveRange(startLeafPos, endLeafPos+1) if err != nil { - return ShareProof{}, err + return nil, nil, err } shareProofs = append(shareProofs, &NMTProof{ @@ -170,24 +198,5 @@ func NewShareInclusionProofFromEDS( LeafHash: proof.LeafHash(), }) } - - return ShareProof{ - RowProof: &RowProof{ - RowRoots: rowRoots, - Proofs: rowProofs, - StartRow: uint32(startRow), - EndRow: uint32(endRow), - }, - Data: rawShares, - ShareProofs: shareProofs, - NamespaceId: namespace.ID, - NamespaceVersion: uint32(namespace.Version), - }, nil -} - -func safeConvertUint64ToInt(val uint64) (int, error) { - if val > math.MaxInt { - return 0, fmt.Errorf("value %d is too large to convert to int", val) - } - return int(val), nil + return shareProofs, rawShares, nil }