diff --git a/go.mod b/go.mod index cb5c1b6a85..7a04932c4c 100644 --- a/go.mod +++ b/go.mod @@ -33,4 +33,4 @@ require ( golang.org/x/sys v0.11.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect rsc.io/tmplfunc v0.0.3 // indirect -) +) \ No newline at end of file diff --git a/std/compress/io.go b/std/compress/io.go new file mode 100644 index 0000000000..88cc9bcf64 --- /dev/null +++ b/std/compress/io.go @@ -0,0 +1,141 @@ +package compress + +import ( + "errors" + "github.com/consensys/compress/lzss" + realHash "hash" + "math/big" + + "github.com/consensys/compress" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/hash" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash/mimc" + test_vector_utils "github.com/consensys/gnark/std/utils/test_vectors_utils" +) + +// Checksum packs the words into as few field elements as possible, and returns the hash of the packed words +func Checksum(api frontend.API, words []frontend.Variable, nbWords frontend.Variable, wordLen int) frontend.Variable { + packed := pack(api, words, wordLen) + hsh, err := mimc.NewMiMC(api) + if err != nil { + return err + } + hsh.Write(packed...) + hsh.Write(nbWords) + return hsh.Sum() +} + +func pack(api frontend.API, words []frontend.Variable, wordLen int) []frontend.Variable { + wordsPerElem := (api.Compiler().FieldBitLen() - 1) / wordLen + packed := make([]frontend.Variable, (len(words)+wordsPerElem-1)/wordsPerElem) + radix := 1 << wordLen + for i := range packed { + packed[i] = 0 + for j := wordsPerElem - 1; j >= 0; j-- { + absJ := i*wordsPerElem + j + if absJ >= len(words) { + continue + } + packed[i] = api.Add(words[absJ], api.Mul(packed[i], radix)) + } + } + return packed +} + +type NumReader struct { + api frontend.API + c []frontend.Variable + stepCoeff int + maxCoeff int + nbWords int + nxt frontend.Variable +} + +func NewNumReader(api frontend.API, c []frontend.Variable, numNbBits, wordNbBits int) *NumReader { + nbWords := numNbBits / wordNbBits + stepCoeff := 1 << wordNbBits + nxt := ReadNum(api, c, nbWords, stepCoeff) + return &NumReader{ + api: api, + c: c, + stepCoeff: stepCoeff, + maxCoeff: 1 << numNbBits, + nxt: nxt, + nbWords: nbWords, + } +} + +func ReadNum(api frontend.API, c []frontend.Variable, nbWords, stepCoeff int) frontend.Variable { + res := frontend.Variable(0) + for i := 0; i < nbWords && i < len(c); i++ { + res = api.Add(c[i], api.Mul(res, stepCoeff)) + } + return res +} + +// Next returns the next number in the sequence. assumes bits past the end of the slice are 0 +func (nr *NumReader) Next() frontend.Variable { + res := nr.nxt + + if len(nr.c) != 0 { + nr.nxt = nr.api.Sub(nr.api.Mul(nr.nxt, nr.stepCoeff), nr.api.Mul(nr.c[0], nr.maxCoeff)) + + if nr.nbWords < len(nr.c) { + nr.nxt = nr.api.Add(nr.nxt, nr.c[nr.nbWords]) + } + + nr.c = nr.c[1:] + } + + return res +} + +// ToSnarkData breaks a stream up into words of the right size for snark consumption, and computes the checksum of that data in a way congruent with Checksum +func ToSnarkData(curveId ecc.ID, s compress.Stream, paddedNbBits int, level lzss.Level) (words []frontend.Variable, checksum []byte, err error) { + + wordNbBits := int(level) + + paddedNbWords := paddedNbBits / wordNbBits + + if paddedNbWords*wordNbBits != paddedNbBits { + return nil, nil, errors.New("the padded size must divide the word length") + } + + wStream := s.BreakUp(1 << wordNbBits) + wPadded := wStream + + if contentNbBits := wStream.Len() * wordNbBits; contentNbBits != paddedNbBits { + wPadded.D = make([]int, paddedNbWords) + copy(wPadded.D, wStream.D) + } + + words = test_vector_utils.ToVariableSlice(wPadded.D) + + var hsh realHash.Hash + switch curveId { + case ecc.BLS12_377: + hsh = hash.MIMC_BLS12_377.New() + case ecc.BN254: + hsh = hash.MIMC_BN254.New() + default: + return nil, nil, errors.New("TODO Add switch-case for curve") + } + + fieldNbBits := curveId.ScalarField().BitLen() + fieldNbBytes := (fieldNbBits + 7) / 8 + packed := wPadded.Pack(fieldNbBits) + byts := make([]byte, fieldNbBytes) + + for _, w := range packed { + w.FillBytes(byts) + hsh.Write(byts) + } + + big.NewInt(int64(wStream.Len())).FillBytes(byts) + hsh.Write(byts) + + checksum = hsh.Sum(nil) + + return +} diff --git a/std/compress/lzss/snark.go b/std/compress/lzss/decompress.go similarity index 100% rename from std/compress/lzss/snark.go rename to std/compress/lzss/decompress.go diff --git a/std/compress/lzss/snark_test.go b/std/compress/lzss/decompress_test.go similarity index 63% rename from std/compress/lzss/snark_test.go rename to std/compress/lzss/decompress_test.go index 9de006de57..fc7370a1e6 100644 --- a/std/compress/lzss/snark_test.go +++ b/std/compress/lzss/decompress_test.go @@ -14,19 +14,19 @@ import ( "testing" ) -func Test1ZeroSnark(t *testing.T) { - testCompressionRoundTripSnark(t, []byte{0}, nil) +func Test1Zero(t *testing.T) { + testCompressionRoundTrip(t, []byte{0}, nil) } -func TestGoodCompressionSnark(t *testing.T) { - testCompressionRoundTripSnark(t, []byte{1, 2}, nil, withLevel(lzss.GoodCompression)) +func TestGoodCompression(t *testing.T) { + testCompressionRoundTrip(t, []byte{1, 2}, nil, withLevel(lzss.GoodCompression)) } -func Test0To10ExplicitSnark(t *testing.T) { - testCompressionRoundTripSnark(t, []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, nil) +func Test0To10Explicit(t *testing.T) { + testCompressionRoundTrip(t, []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, nil) } -func TestNoCompressionSnark(t *testing.T) { +func TestNoCompression(t *testing.T) { d, err := os.ReadFile("./testdata/3c2943/data.bin") assert.NoError(t, err) @@ -58,21 +58,21 @@ func TestNoCompressionSnark(t *testing.T) { test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithBackends(backend.PLONK), test.WithCurves(ecc.BN254)) } -func Test255_254_253Snark(t *testing.T) { - testCompressionRoundTripSnark(t, []byte{255, 254, 253}, nil) +func Test255_254_253(t *testing.T) { + testCompressionRoundTrip(t, []byte{255, 254, 253}, nil) } -func Test3c2943Snark(t *testing.T) { +func Test3c2943(t *testing.T) { d, err := os.ReadFile("./testdata/3c2943/data.bin") assert.NoError(t, err) dict := getDictionary() - testCompressionRoundTripSnark(t, d, dict) + testCompressionRoundTrip(t, d, dict) } // Fuzz test the decompression -func FuzzSnark(f *testing.F) { // TODO This is always skipped +func Fuzz(f *testing.F) { // TODO This is always skipped f.Fuzz(func(t *testing.T, input, dict []byte) { if len(input) > lzss.MaxInputSize { t.Skip("input too large") @@ -83,7 +83,7 @@ func FuzzSnark(f *testing.F) { // TODO This is always skipped if len(input) == 0 { t.Skip("input too small") } - testCompressionRoundTripSnark(t, input, dict) + testCompressionRoundTrip(t, input, dict) }) } @@ -95,7 +95,7 @@ func withLevel(level lzss.Level) testCompressionRoundTripOption { } } -func testCompressionRoundTripSnark(t *testing.T, d, dict []byte, options ...testCompressionRoundTripOption) { +func testCompressionRoundTrip(t *testing.T, d, dict []byte, options ...testCompressionRoundTripOption) { level := lzss.BestCompression @@ -126,36 +126,6 @@ func testCompressionRoundTripSnark(t *testing.T, d, dict []byte, options ...test test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithBackends(backend.PLONK), test.WithCurves(ecc.BN254)) } -func TestReadBytes(t *testing.T) { - expected := []byte{254, 0, 0, 0} - circuit := &readBytesCircuit{ - Words: make([]frontend.Variable, 8*len(expected)), - WordNbBits: 1, - Expected: expected, - } - words, err := goCompress.NewStream(expected, 8) - assert.NoError(t, err) - words = words.BreakUp(2) - assignment := &readBytesCircuit{ - Words: test_vector_utils.ToVariableSlice(words.D), - } - test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithBackends(backend.PLONK), test.WithCurves(ecc.BN254)) -} - -type readBytesCircuit struct { - Words []frontend.Variable - WordNbBits int - Expected []byte -} - -func (c *readBytesCircuit) Define(api frontend.API) error { - byts := combineIntoBytes(api, c.Words, c.WordNbBits) - for i := range c.Expected { - api.AssertIsEqual(c.Expected[i], byts[i*8]) - } - return nil -} - func getDictionary() []byte { d, err := os.ReadFile("./testdata/dict_naive") if err != nil { diff --git a/std/compress/lzss/e2e_test.go b/std/compress/lzss/e2e_test.go index 2a303d3d94..602155e108 100644 --- a/std/compress/lzss/e2e_test.go +++ b/std/compress/lzss/e2e_test.go @@ -3,13 +3,14 @@ package lzss import ( goCompress "github.com/consensys/compress" "github.com/consensys/compress/lzss" + "github.com/consensys/gnark/std/compress" + "math/bits" "os" "testing" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" - test_vector_utils "github.com/consensys/gnark/std/utils/test_vectors_utils" "github.com/consensys/gnark/test" "github.com/stretchr/testify/assert" ) @@ -33,27 +34,30 @@ func testCompressionE2E(t *testing.T, d, dict []byte, name string) { // compress level := lzss.GoodCompression + wordNbBits := int(level) + const curveId = ecc.BLS12_377 + compressor, err := lzss.NewCompressor(dict, level) assert.NoError(t, err) c, err := compressor.Compress(d) assert.NoError(t, err) - cStream, err := goCompress.NewStream(c, uint8(level)) + cStream, err := goCompress.NewStream(c, uint8(wordNbBits)) assert.NoError(t, err) - cSum, err := check(cStream, cStream.Len()) + cWords, cSum, err := compress.ToSnarkData(curveId, cStream, wordNbBits*cStream.Len(), level) assert.NoError(t, err) dStream, err := goCompress.NewStream(d, 8) assert.NoError(t, err) - dSum, err := check(dStream, len(d)) + dWords, dSum, err := compress.ToSnarkData(curveId, dStream, len(d)*8, 8) assert.NoError(t, err) circuit := compressionCircuit{ - C: make([]frontend.Variable, cStream.Len()), - D: make([]frontend.Variable, len(d)), + C: make([]frontend.Variable, len(cWords)), + D: make([]frontend.Variable, len(dWords)), Dict: make([]byte, len(dict)), Level: level, } @@ -63,45 +67,60 @@ func testCompressionE2E(t *testing.T, d, dict []byte, name string) { assignment := compressionCircuit{ CChecksum: cSum, DChecksum: dSum, - C: test_vector_utils.ToVariableSlice(cStream.D), - D: test_vector_utils.ToVariableSlice(d), + C: cWords, + D: dWords, Dict: dict, CLen: cStream.Len(), DLen: len(d), } - test.NewAssert(t).SolvingSucceeded(&circuit, &assignment, test.WithBackends(backend.PLONK), test.WithCurves(ecc.BN254)) + + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithBackends(backend.PLONK), test.WithCurves(curveId)) +} + +func TestChecksumNothing(t *testing.T) { + testChecksum(t, goCompress.Stream{D: []int{}, NbSymbs: 256}, 0, lzss.BestSnarkDecompression) +} + +func TestChecksumOne(t *testing.T) { + testChecksum(t, goCompress.Stream{D: []int{1}, NbSymbs: 256}, 8, lzss.BestSnarkDecompression) } -func TestChecksum0(t *testing.T) { - testChecksum(t, goCompress.Stream{D: []int{}, NbSymbs: 256}) +func TestChecksumOneWithBits(t *testing.T) { + testChecksum(t, goCompress.Stream{D: []int{1}, NbSymbs: 256}, 9, lzss.BestCompression) } -func testChecksum(t *testing.T, d goCompress.Stream) { +func testChecksum(t *testing.T, d goCompress.Stream, paddedNbBits int, level lzss.Level) { + const curveId = ecc.BLS12_377 + + words, checksum, err := compress.ToSnarkData(curveId, d, paddedNbBits, level) + assert.NoError(t, err) + circuit := checksumTestCircuit{ - Inputs: make([]frontend.Variable, d.Len()), - InputLen: d.Len(), + Inputs: make([]frontend.Variable, len(words)), + WordLen: int(level), } - sum, err := check(d, d.Len()) - assert.NoError(t, err) + dWordLen := 63 - bits.LeadingZeros64(uint64(d.NbSymbs)) + assert.Equal(t, 1<= len(words) { - break - } - res[elemI] = api.Add(res[elemI], api.Mul(words[absWordI], 1<