Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Simplify computation of checksum and words #959

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ require (
golang.org/x/sys v0.11.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
rsc.io/tmplfunc v0.0.3 // indirect
)
)
141 changes: 141 additions & 0 deletions std/compress/io.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package compress

import (
"errors"
"github.com/consensys/compress/lzss"
realHash "hash"
"math/big"

"github.com/consensys/compress"
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark-crypto/hash"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/hash/mimc"
test_vector_utils "github.com/consensys/gnark/std/utils/test_vectors_utils"
)

// Checksum packs the words into as few field elements as possible, and returns the hash of the packed words
func Checksum(api frontend.API, words []frontend.Variable, nbWords frontend.Variable, wordLen int) frontend.Variable {
packed := pack(api, words, wordLen)
hsh, err := mimc.NewMiMC(api)
if err != nil {
return err
}
hsh.Write(packed...)
hsh.Write(nbWords)
return hsh.Sum()
}

func pack(api frontend.API, words []frontend.Variable, wordLen int) []frontend.Variable {
wordsPerElem := (api.Compiler().FieldBitLen() - 1) / wordLen
packed := make([]frontend.Variable, (len(words)+wordsPerElem-1)/wordsPerElem)
radix := 1 << wordLen
for i := range packed {
packed[i] = 0
for j := wordsPerElem - 1; j >= 0; j-- {
absJ := i*wordsPerElem + j
if absJ >= len(words) {
continue
}
packed[i] = api.Add(words[absJ], api.Mul(packed[i], radix))
}
}
return packed
}

type NumReader struct {
api frontend.API
c []frontend.Variable
stepCoeff int
maxCoeff int
nbWords int
nxt frontend.Variable
}

func NewNumReader(api frontend.API, c []frontend.Variable, numNbBits, wordNbBits int) *NumReader {
nbWords := numNbBits / wordNbBits
stepCoeff := 1 << wordNbBits
nxt := ReadNum(api, c, nbWords, stepCoeff)
return &NumReader{
api: api,
c: c,
stepCoeff: stepCoeff,
maxCoeff: 1 << numNbBits,
nxt: nxt,
nbWords: nbWords,
}
}

func ReadNum(api frontend.API, c []frontend.Variable, nbWords, stepCoeff int) frontend.Variable {
res := frontend.Variable(0)
for i := 0; i < nbWords && i < len(c); i++ {
res = api.Add(c[i], api.Mul(res, stepCoeff))
}
return res
}

// Next returns the next number in the sequence. assumes bits past the end of the slice are 0
func (nr *NumReader) Next() frontend.Variable {
res := nr.nxt

if len(nr.c) != 0 {
nr.nxt = nr.api.Sub(nr.api.Mul(nr.nxt, nr.stepCoeff), nr.api.Mul(nr.c[0], nr.maxCoeff))

if nr.nbWords < len(nr.c) {
nr.nxt = nr.api.Add(nr.nxt, nr.c[nr.nbWords])
}

nr.c = nr.c[1:]
}

return res
}

// ToSnarkData breaks a stream up into words of the right size for snark consumption, and computes the checksum of that data in a way congruent with Checksum
func ToSnarkData(curveId ecc.ID, s compress.Stream, paddedNbBits int, level lzss.Level) (words []frontend.Variable, checksum []byte, err error) {

wordNbBits := int(level)

paddedNbWords := paddedNbBits / wordNbBits

if paddedNbWords*wordNbBits != paddedNbBits {
return nil, nil, errors.New("the padded size must divide the word length")
}

wStream := s.BreakUp(1 << wordNbBits)
wPadded := wStream

if contentNbBits := wStream.Len() * wordNbBits; contentNbBits != paddedNbBits {
wPadded.D = make([]int, paddedNbWords)
copy(wPadded.D, wStream.D)
}

words = test_vector_utils.ToVariableSlice(wPadded.D)

var hsh realHash.Hash
switch curveId {
case ecc.BLS12_377:
hsh = hash.MIMC_BLS12_377.New()
case ecc.BN254:
hsh = hash.MIMC_BN254.New()
default:
return nil, nil, errors.New("TODO Add switch-case for curve")
}

fieldNbBits := curveId.ScalarField().BitLen()
fieldNbBytes := (fieldNbBits + 7) / 8
packed := wPadded.Pack(fieldNbBits)
byts := make([]byte, fieldNbBytes)

for _, w := range packed {
w.FillBytes(byts)
hsh.Write(byts)
}

big.NewInt(int64(wStream.Len())).FillBytes(byts)
hsh.Write(byts)

checksum = hsh.Sum(nil)

return
}
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@ import (
"testing"
)

func Test1ZeroSnark(t *testing.T) {
testCompressionRoundTripSnark(t, []byte{0}, nil)
func Test1Zero(t *testing.T) {
testCompressionRoundTrip(t, []byte{0}, nil)
}

func TestGoodCompressionSnark(t *testing.T) {
testCompressionRoundTripSnark(t, []byte{1, 2}, nil, withLevel(lzss.GoodCompression))
func TestGoodCompression(t *testing.T) {
testCompressionRoundTrip(t, []byte{1, 2}, nil, withLevel(lzss.GoodCompression))
}

func Test0To10ExplicitSnark(t *testing.T) {
testCompressionRoundTripSnark(t, []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, nil)
func Test0To10Explicit(t *testing.T) {
testCompressionRoundTrip(t, []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, nil)
}

func TestNoCompressionSnark(t *testing.T) {
func TestNoCompression(t *testing.T) {

d, err := os.ReadFile("./testdata/3c2943/data.bin")
assert.NoError(t, err)
Expand Down Expand Up @@ -58,21 +58,21 @@ func TestNoCompressionSnark(t *testing.T) {
test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithBackends(backend.PLONK), test.WithCurves(ecc.BN254))
}

func Test255_254_253Snark(t *testing.T) {
testCompressionRoundTripSnark(t, []byte{255, 254, 253}, nil)
func Test255_254_253(t *testing.T) {
testCompressionRoundTrip(t, []byte{255, 254, 253}, nil)
}

func Test3c2943Snark(t *testing.T) {
func Test3c2943(t *testing.T) {
d, err := os.ReadFile("./testdata/3c2943/data.bin")
assert.NoError(t, err)

dict := getDictionary()

testCompressionRoundTripSnark(t, d, dict)
testCompressionRoundTrip(t, d, dict)
}

// Fuzz test the decompression
func FuzzSnark(f *testing.F) { // TODO This is always skipped
func Fuzz(f *testing.F) { // TODO This is always skipped
f.Fuzz(func(t *testing.T, input, dict []byte) {
if len(input) > lzss.MaxInputSize {
t.Skip("input too large")
Expand All @@ -83,7 +83,7 @@ func FuzzSnark(f *testing.F) { // TODO This is always skipped
if len(input) == 0 {
t.Skip("input too small")
}
testCompressionRoundTripSnark(t, input, dict)
testCompressionRoundTrip(t, input, dict)
})
}

Expand All @@ -95,7 +95,7 @@ func withLevel(level lzss.Level) testCompressionRoundTripOption {
}
}

func testCompressionRoundTripSnark(t *testing.T, d, dict []byte, options ...testCompressionRoundTripOption) {
func testCompressionRoundTrip(t *testing.T, d, dict []byte, options ...testCompressionRoundTripOption) {

level := lzss.BestCompression

Expand Down Expand Up @@ -126,36 +126,6 @@ func testCompressionRoundTripSnark(t *testing.T, d, dict []byte, options ...test
test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithBackends(backend.PLONK), test.WithCurves(ecc.BN254))
}

func TestReadBytes(t *testing.T) {
expected := []byte{254, 0, 0, 0}
circuit := &readBytesCircuit{
Words: make([]frontend.Variable, 8*len(expected)),
WordNbBits: 1,
Expected: expected,
}
words, err := goCompress.NewStream(expected, 8)
assert.NoError(t, err)
words = words.BreakUp(2)
assignment := &readBytesCircuit{
Words: test_vector_utils.ToVariableSlice(words.D),
}
test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithBackends(backend.PLONK), test.WithCurves(ecc.BN254))
}

type readBytesCircuit struct {
Words []frontend.Variable
WordNbBits int
Expected []byte
}

func (c *readBytesCircuit) Define(api frontend.API) error {
byts := combineIntoBytes(api, c.Words, c.WordNbBits)
for i := range c.Expected {
api.AssertIsEqual(c.Expected[i], byts[i*8])
}
return nil
}

func getDictionary() []byte {
d, err := os.ReadFile("./testdata/dict_naive")
if err != nil {
Expand Down
65 changes: 42 additions & 23 deletions std/compress/lzss/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ package lzss
import (
goCompress "github.com/consensys/compress"
"github.com/consensys/compress/lzss"
"github.com/consensys/gnark/std/compress"
"math/bits"
"os"
"testing"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/frontend"
test_vector_utils "github.com/consensys/gnark/std/utils/test_vectors_utils"
"github.com/consensys/gnark/test"
"github.com/stretchr/testify/assert"
)
Expand All @@ -33,27 +34,30 @@ func testCompressionE2E(t *testing.T, d, dict []byte, name string) {
// compress

level := lzss.GoodCompression
wordNbBits := int(level)
const curveId = ecc.BLS12_377

compressor, err := lzss.NewCompressor(dict, level)
assert.NoError(t, err)

c, err := compressor.Compress(d)
assert.NoError(t, err)

cStream, err := goCompress.NewStream(c, uint8(level))
cStream, err := goCompress.NewStream(c, uint8(wordNbBits))
assert.NoError(t, err)

cSum, err := check(cStream, cStream.Len())
cWords, cSum, err := compress.ToSnarkData(curveId, cStream, wordNbBits*cStream.Len(), level)
assert.NoError(t, err)

dStream, err := goCompress.NewStream(d, 8)
assert.NoError(t, err)

dSum, err := check(dStream, len(d))
dWords, dSum, err := compress.ToSnarkData(curveId, dStream, len(d)*8, 8)
assert.NoError(t, err)

circuit := compressionCircuit{
C: make([]frontend.Variable, cStream.Len()),
D: make([]frontend.Variable, len(d)),
C: make([]frontend.Variable, len(cWords)),
D: make([]frontend.Variable, len(dWords)),
Dict: make([]byte, len(dict)),
Level: level,
}
Expand All @@ -63,45 +67,60 @@ func testCompressionE2E(t *testing.T, d, dict []byte, name string) {
assignment := compressionCircuit{
CChecksum: cSum,
DChecksum: dSum,
C: test_vector_utils.ToVariableSlice(cStream.D),
D: test_vector_utils.ToVariableSlice(d),
C: cWords,
D: dWords,
Dict: dict,
CLen: cStream.Len(),
DLen: len(d),
}
test.NewAssert(t).SolvingSucceeded(&circuit, &assignment, test.WithBackends(backend.PLONK), test.WithCurves(ecc.BN254))

test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithBackends(backend.PLONK), test.WithCurves(curveId))
}

func TestChecksumNothing(t *testing.T) {
testChecksum(t, goCompress.Stream{D: []int{}, NbSymbs: 256}, 0, lzss.BestSnarkDecompression)
}

func TestChecksumOne(t *testing.T) {
testChecksum(t, goCompress.Stream{D: []int{1}, NbSymbs: 256}, 8, lzss.BestSnarkDecompression)
}

func TestChecksum0(t *testing.T) {
testChecksum(t, goCompress.Stream{D: []int{}, NbSymbs: 256})
func TestChecksumOneWithBits(t *testing.T) {
testChecksum(t, goCompress.Stream{D: []int{1}, NbSymbs: 256}, 9, lzss.BestCompression)
}

func testChecksum(t *testing.T, d goCompress.Stream) {
func testChecksum(t *testing.T, d goCompress.Stream, paddedNbBits int, level lzss.Level) {
const curveId = ecc.BLS12_377

words, checksum, err := compress.ToSnarkData(curveId, d, paddedNbBits, level)
assert.NoError(t, err)

circuit := checksumTestCircuit{
Inputs: make([]frontend.Variable, d.Len()),
InputLen: d.Len(),
Inputs: make([]frontend.Variable, len(words)),
WordLen: int(level),
}

sum, err := check(d, d.Len())
assert.NoError(t, err)
dWordLen := 63 - bits.LeadingZeros64(uint64(d.NbSymbs))
assert.Equal(t, 1<<dWordLen, d.NbSymbs)

assignment := checksumTestCircuit{
Inputs: test_vector_utils.ToVariableSlice(d.D),
InputLen: d.Len(),
Sum: sum,
Inputs: words,
InputLen: d.Len() * dWordLen / int(level),
Sum: checksum,
}
test.NewAssert(t).SolvingSucceeded(&circuit, &assignment, test.WithBackends(backend.PLONK), test.WithCurves(ecc.BN254))

test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithBackends(backend.PLONK), test.WithCurves(curveId))
}

type checksumTestCircuit struct {
Inputs []frontend.Variable
InputLen frontend.Variable
Sum frontend.Variable
WordLen int
}

func (c *checksumTestCircuit) Define(api frontend.API) error {
if err := checkSnark(api, c.Inputs, len(c.Inputs), c.Sum); err != nil {
return err
}
sum := compress.Checksum(api, c.Inputs, c.InputLen, c.WordLen)
api.AssertIsEqual(c.Sum, sum)
return nil
}
Loading