Skip to content

Commit

Permalink
recompute and cache axes independently
Browse files Browse the repository at this point in the history
  • Loading branch information
walldiss committed Nov 30, 2023
1 parent 46067e7 commit 196cd3a
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 49 deletions.
163 changes: 117 additions & 46 deletions share/eds/cache_file.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ import (
"github.com/celestiaorg/celestia-node/share/ipld"
)

// TODO: allow concurrency safety fpr CacheFile methods
type CacheFile struct {
File

// TODO(@walldiss): add columns support
rowCache map[int]inMemoryAxis
codec rsmt2d.Codec
axisCache []map[int]inMemoryAxis
// disableCache disables caching of rows for testing purposes
disableCache bool
}
Expand All @@ -32,10 +33,11 @@ type inMemoryAxis struct {
proofs blockservice.BlockGetter
}

func NewCacheFile(f File) *CacheFile {
func NewCacheFile(f File, codec rsmt2d.Codec) *CacheFile {
return &CacheFile{
File: f,
rowCache: make(map[int]inMemoryAxis),
File: f,
codec: codec,
axisCache: []map[int]inMemoryAxis{make(map[int]inMemoryAxis), make(map[int]inMemoryAxis)},
}
}

Expand All @@ -51,72 +53,134 @@ func (f *CacheFile) ShareWithProof(
axsIdx, shrIdx = shrIdx, axsIdx
}

row := f.rowCache[axsIdx]
if row.proofs == nil {
shrs, err := f.Axis(axsIdx, axis)
if err != nil {
return nil, err
}
ax, err := f.axisWithProofs(axsIdx, axis)
if err != nil {
return nil, err
}

// calculate proofs
adder := ipld.NewProofsAdder(sqrLn*2, ipld.CollectShares)
tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(sqrLn/2), uint(axsIdx),
nmt.NodeVisitor(adder.VisitFn()))
for _, shr := range shrs {
err = tree.Push(shr)
if err != nil {
return nil, err
}
}
// TODO(@walldiss): add proper calc to prealloc size for proofs
proof := make([]cid.Cid, 0, 16)
rootCid := ipld.MustCidFromNamespacedSha256(axisRoot)
proofs, err := ipld.GetProof(ctx, ax.proofs, rootCid, proof, shrIdx, sqrLn)
if err != nil {
return nil, fmt.Errorf("bulding proof from cache: %w", err)
}

if _, err := tree.Root(); err != nil {
return nil, err
}
return byzantine.NewShareWithProof(shrIdx, ax.shares[shrIdx], proofs), nil
}

row = inMemoryAxis{
shares: shrs,
proofs: newRowProofsGetter(adder.Proofs()),
}
func (f *CacheFile) axisWithProofs(idx int, axis rsmt2d.Axis) (inMemoryAxis, error) {
ax := f.axisCache[axis][idx]
if ax.proofs != nil {
return ax, nil
}

if !f.disableCache {
f.rowCache[axsIdx] = row
// build proofs from shares and cache them
shrs, err := f.Axis(idx, axis)
if err != nil {
return inMemoryAxis{}, err
}

fmt.Println("building proofs for axis", idx, axis)
// calculate proofs
adder := ipld.NewProofsAdder(f.Size(), ipld.CollectShares)
tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(f.Size()/2), uint(idx),
nmt.NodeVisitor(adder.VisitFn()))
for _, shr := range shrs {
err = tree.Push(shr)
if err != nil {
return inMemoryAxis{}, err
}
}

// TODO(@walldiss): find prealloc size for proofs
proof := make([]cid.Cid, 0, 8)
rootCid := ipld.MustCidFromNamespacedSha256(axisRoot)
proofs, err := ipld.GetProof(ctx, row.proofs, rootCid, proof, shrIdx, sqrLn)
if err != nil {
return nil, fmt.Errorf("bulding proof from cache: %w", err)
// build the tree
if _, err := tree.Root(); err != nil {
return inMemoryAxis{}, err
}

return byzantine.NewShareWithProof(shrIdx, row.shares[shrIdx], proofs), nil
ax = f.axisCache[axis][idx]
ax.proofs = newRowProofsGetter(adder.Proofs())

if !f.disableCache {
f.axisCache[axis][idx] = ax
}
return ax, nil
}

func (f *CacheFile) Axis(idx int, axis rsmt2d.Axis) ([]share.Share, error) {
row, ok := f.rowCache[idx]
// return axis from cache if possible
ax, ok := f.axisCache[axis][idx]
if ok {
return row.shares, nil
return ax.shares, nil
}

shrs, err := f.File.Axis(idx, axis)
// recompute axis from half
original, err := f.AxisHalf(idx, axis)
if err != nil {
return nil, err
}

// cache row shares
parity, err := f.codec.Encode(original)
if err != nil {
return nil, err
}

shares := make([]share.Share, 0, len(original)+len(parity))
shares = append(shares, original...)
shares = append(shares, parity...)

// cache axis shares
if !f.disableCache {
f.rowCache[idx] = inMemoryAxis{
shares: shrs,
f.axisCache[axis][idx] = inMemoryAxis{
shares: shares,
}
}
return shares, nil
}

func (f *CacheFile) AxisHalf(idx int, axis rsmt2d.Axis) ([]share.Share, error) {
// return axis from cache if possible
ax, ok := f.axisCache[axis][idx]
if ok {
return ax.shares[:f.Size()/2], nil
}

// read axis from file if axis is in the first quadrant
if idx < f.Size()/2 {
return f.File.AxisHalf(idx, axis)
}

shares := make([]share.Share, 0, f.Size()/2)
// extend opposite half of the square while collecting shares for the first half of required axis
//TODO: parallelize this
for i := 0; i < f.Size()/2; i++ {
ax, err := f.Axis(i, oppositeAxis(axis))
if err != nil {
return nil, err
}
shares = append(shares, ax[idx])
}
return shrs, nil
return shares, nil
}

// TODO(@walldiss): needs to be implemented
func (f *CacheFile) EDS() (*rsmt2d.ExtendedDataSquare, error) {
return f.File.EDS()
shares := make([][]byte, 0, f.Size()*f.Size())
for i := 0; i < f.Size(); i++ {
ax, err := f.Axis(i, rsmt2d.Row)
if err != nil {
return nil, err
}
shares = append(shares, ax...)
}

eds, err := rsmt2d.ImportExtendedDataSquare(
shares,
share.DefaultRSMT2DCodec(),
wrapper.NewConstructor(uint64(f.Size())/2))
if err != nil {
return nil, fmt.Errorf("recomputing data square: %w", err)
}
return eds, nil
}

// rowProofsGetter implements blockservice.BlockGetter interface
Expand Down Expand Up @@ -144,3 +208,10 @@ func (r rowProofsGetter) GetBlock(_ context.Context, c cid.Cid) (blocks.Block, e
func (r rowProofsGetter) GetBlocks(_ context.Context, _ []cid.Cid) <-chan blocks.Block {
panic("not implemented")
}

func oppositeAxis(axis rsmt2d.Axis) rsmt2d.Axis {
if axis == rsmt2d.Col {
return rsmt2d.Row
}
return rsmt2d.Col
}
38 changes: 35 additions & 3 deletions share/eds/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,44 @@ func BenchmarkGetShareFromCache(b *testing.B) {
minSize, maxSize := 16, 128
newFile := func(size int) File {
sqr := edstest.RandEDS(b, size)
return NewCacheFile(&MemFile{Eds: sqr})
return NewCacheFile(&MemFile{Eds: sqr}, rsmt2d.NewLeoRSCodec())
}

benchGetShareFromFile(b, newFile, minSize, maxSize)
}

func TestCachedAxis(t *testing.T) {
sqr := edstest.RandEDS(t, 32)
mem := &MemFile{Eds: sqr}
file := NewCacheFile(mem, rsmt2d.NewLeoRSCodec())

for i := 0; i < mem.Size(); i++ {
a1, err := mem.Axis(i, rsmt2d.Row)
require.NoError(t, err)
a2, err := file.Axis(i, rsmt2d.Row)
require.NoError(t, err)
require.Equal(t, a1, a2)
}
}

func TestCachedEDS(t *testing.T) {
sqr := edstest.RandEDS(t, 32)
mem := &MemFile{Eds: sqr}
file := NewCacheFile(mem, rsmt2d.NewLeoRSCodec())

e1, err := mem.EDS()
require.NoError(t, err)
e2, err := file.EDS()
require.NoError(t, err)

r1, err := e1.RowRoots()
require.NoError(t, err)
r2, err := e2.RowRoots()
require.NoError(t, err)

require.Equal(t, r1, r2)
}

// BenchmarkGetShareFromCacheMiss/16 16308 72295 ns/op
// BenchmarkGetShareFromCacheMiss/32 8216 141334 ns/op
// BenchmarkGetShareFromCacheMiss/64 3877 284171 ns/op
Expand All @@ -138,7 +170,7 @@ func BenchmarkGetShareFromCacheMiss(b *testing.B) {
minSize, maxSize := 16, 128
newFile := func(size int) File {
sqr := edstest.RandEDS(b, size)
f := NewCacheFile(&MemFile{Eds: sqr})
f := NewCacheFile(&MemFile{Eds: sqr}, rsmt2d.NewLeoRSCodec())
f.disableCache = true
return f
}
Expand Down Expand Up @@ -171,7 +203,7 @@ func TestCacheMemoryUsage(t *testing.T) {
eds := edstest.RandEDS(t, size)
df, err := CreateFile(dir+"/"+strconv.Itoa(i), eds)
require.NoError(t, err)
f := NewCacheFile(df)
f := NewCacheFile(df, rsmt2d.NewLeoRSCodec())

rows, err := eds.RowRoots()
require.NoError(t, err)
Expand Down

0 comments on commit 196cd3a

Please sign in to comment.