Skip to content

Commit

Permalink
[MySQL Conformance] fix tree size bug (#382)
Browse files Browse the repository at this point in the history
Major changes:
 - MySQL storage read methods return os.ErrNotExist when values aren't found
 - ReadTile returns an error if the user requests more data than we have available
 - Added tests for writing and reading data from tiles
 - Made tests hermetic (though slower) by resetting the DB for each test case

This got a bit bigger than intended. This fixes #364.
  • Loading branch information
mhutchinson authored Dec 5, 2024
1 parent 653c4e0 commit c57fd57
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 30 deletions.
15 changes: 9 additions & 6 deletions cmd/conformance/mysql/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,17 +165,17 @@ func configureTilesReadAPI(mux *http.ServeMux, storage *mysql.Storage) {
}
return
}
impliedSize := (index*256 + width) << (level * 8)
tile, err := storage.ReadTile(r.Context(), level, index, impliedSize)
inferredMinTreeSize := (index*256 + width) << (level * 8)
tile, err := storage.ReadTile(r.Context(), level, index, inferredMinTreeSize)
if err != nil {
if os.IsNotExist(err) {
w.WriteHeader(http.StatusNotFound)
return
}
klog.Errorf("/tile/{level}/{index...}: %v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
if tile == nil {
w.WriteHeader(http.StatusNotFound)
return
}

w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")

Expand Down Expand Up @@ -207,6 +207,9 @@ func configureTilesReadAPI(mux *http.ServeMux, storage *mysql.Storage) {
}

// TODO: Add immutable Cache-Control header.
// Only do this once we're sure we're returning the right number of entries
// Currently a user can request a full tile and we can return a partial tile.
// If cache headers were set then this could cause caches to be poisoned.

if _, err := w.Write(entryBundle); err != nil {
klog.Errorf("/tile/entries/{index...}: %v", err)
Expand Down
54 changes: 33 additions & 21 deletions storage/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ package mysql
import (
"bytes"
"context"
"crypto/sha256"
"database/sql"
"errors"
"fmt"
"os"
"strings"
"time"

Expand Down Expand Up @@ -121,7 +123,7 @@ func (s *Storage) maybeInitTree(ctx context.Context) error {
}()

treeState, err := s.readTreeState(ctx, tx)
if err != nil {
if err != nil && !os.IsNotExist(err) {
klog.Errorf("Failed to read tree state: %v", err)
return err
}
Expand All @@ -142,7 +144,7 @@ func (s *Storage) maybeInitTree(ctx context.Context) error {
}

// ReadCheckpoint returns the latest stored checkpoint.
// If the checkpoint is not found, nil is returned with no error.
// If the checkpoint is not found, it returns os.ErrNotExist.
func (s *Storage) ReadCheckpoint(ctx context.Context) ([]byte, error) {
row := s.db.QueryRowContext(ctx, selectCheckpointByIDSQL, checkpointID)
if err := row.Err(); err != nil {
Expand All @@ -153,7 +155,7 @@ func (s *Storage) ReadCheckpoint(ctx context.Context) ([]byte, error) {
var at int64
if err := row.Scan(&checkpoint, &at); err != nil {
if err == sql.ErrNoRows {
return nil, nil
return nil, os.ErrNotExist
}
return nil, fmt.Errorf("scan checkpoint: %v", err)
}
Expand Down Expand Up @@ -207,7 +209,7 @@ type treeState struct {
}

// readTreeState returns the currently stored tree state information.
// If there is no stored tree state, nil is returned with no error.
// If there is no stored tree state, it returns os.ErrNotExist.
func (s *Storage) readTreeState(ctx context.Context, tx *sql.Tx) (*treeState, error) {
row := tx.QueryRowContext(ctx, selectTreeStateByIDForUpdateSQL, treeStateID)
if err := row.Err(); err != nil {
Expand All @@ -217,7 +219,7 @@ func (s *Storage) readTreeState(ctx context.Context, tx *sql.Tx) (*treeState, er
r := &treeState{}
if err := row.Scan(&r.size, &r.root); err != nil {
if err == sql.ErrNoRows {
return nil, nil
return nil, os.ErrNotExist
}
return nil, fmt.Errorf("scan tree state: %v", err)
}
Expand All @@ -234,16 +236,13 @@ func (s *Storage) writeTreeState(ctx context.Context, tx *sql.Tx, size uint64, r
return nil
}

// ReadTile returns a full tile or a partial tile at the given level, index and width.
// If the tile is not found, nil is returned with no error.
// ReadTile returns a full tile or a partial tile at the given level, index and treeSize.
// If the tile is not found, it returns os.ErrNotExist.
//
// TODO: Handle the following scenarios:
// 1. Full tile request with full tile output: Return full tile.
// 2. Full tile request with partial tile output: Return error.
// 3. Partial tile request with full/larger partial tile output: Return trimmed partial tile with correct tile width.
// 4. Partial tile request with partial tile (same width) output: Return partial tile.
// 5. Partial tile request with smaller partial tile output: Return error.
func (s *Storage) ReadTile(ctx context.Context, level, index, width uint64) ([]byte, error) {
// Note that if a partial tile is requested, but a larger tile is available, this
// will return the largest tile available. This could be trimmed to return only the
// number of entries specifically requested if this behaviour becomes problematic.
func (s *Storage) ReadTile(ctx context.Context, level, index, minTreeSize uint64) ([]byte, error) {
row := s.db.QueryRowContext(ctx, selectSubtreeByLevelAndIndexSQL, level, index)
if err := row.Err(); err != nil {
return nil, err
Expand All @@ -252,20 +251,34 @@ func (s *Storage) ReadTile(ctx context.Context, level, index, width uint64) ([]b
var tile []byte
if err := row.Scan(&tile); err != nil {
if err == sql.ErrNoRows {
return nil, nil
return nil, os.ErrNotExist
}

return nil, fmt.Errorf("scan tile: %v", err)
}

// Return nil when returning a partial tile on a full tile request.
if width == 256 && uint64(len(tile)/32) != width {
return nil, nil
requestedWidth := partialTileSize(level, index, minTreeSize)
numEntries := uint64(len(tile) / sha256.Size)

if requestedWidth > numEntries {
// If the user has requested a size larger than we have, they can't have it
return nil, os.ErrNotExist
}

return tile, nil
}

// partialTileSize returns the expected number of leaves in a tile at the given location within
// a tree of the specified logSize, or 0 if the tile is expected to be fully populated.
func partialTileSize(level, index, logSize uint64) uint64 {
sizeAtLevel := logSize >> (level * 8)
fullTiles := sizeAtLevel / 256
if index < fullTiles {
return 256
}
return sizeAtLevel % 256
}

// writeTile replaces the tile nodes at the given level and index.
func (s *Storage) writeTile(ctx context.Context, tx *sql.Tx, level, index uint64, nodes []byte) error {
if _, err := tx.ExecContext(ctx, replaceSubtreeSQL, level, index, nodes); err != nil {
Expand All @@ -277,7 +290,7 @@ func (s *Storage) writeTile(ctx context.Context, tx *sql.Tx, level, index uint64
}

// ReadEntryBundle returns the log entries at the given index.
// If the entry bundle is not found, nil is returned with no error.
// If the entry bundle is not found, it returns os.ErrNotExist.
//
// TODO: Handle the following scenarios:
// 1. Full tile request with full tile output: Return full tile.
Expand All @@ -294,9 +307,8 @@ func (s *Storage) ReadEntryBundle(ctx context.Context, index, treeSize uint64) (
var entryBundle []byte
if err := row.Scan(&entryBundle); err != nil {
if err == sql.ErrNoRows {
return nil, nil
return nil, os.ErrNotExist
}

return nil, fmt.Errorf("scan entry bundle: %v", err)
}

Expand Down
107 changes: 104 additions & 3 deletions storage/mysql/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ package mysql_test
import (
"bytes"
"context"
"crypto/sha256"
"database/sql"
"flag"
"fmt"
"os"
"testing"
"time"
Expand All @@ -48,8 +50,8 @@ var (
)

const (
// Matching public key: "transparency.dev/tessera/example+ae330e15+ASf4/L1zE859VqlfQgGzKy34l91Gl8W6wfwp+vKP62DW"
testPrivateKey = "PRIVATE+KEY+transparency.dev/tessera/example+ae330e15+AXEwZQ2L6Ga3NX70ITObzyfEIketMr2o9Kc+ed/rt/QR"
testPublicKey = "transparency.dev/tessera/example+ae330e15+ASf4/L1zE859VqlfQgGzKy34l91Gl8W6wfwp+vKP62DW"
)

// TestMain checks whether the test MySQL database is available and starts the tests including database schema initialization.
Expand Down Expand Up @@ -163,6 +165,93 @@ func TestNew(t *testing.T) {
}
}

func TestGetTile(t *testing.T) {
ctx := context.Background()
s := newTestMySQLStorage(t, ctx)

awaiter := tessera.NewIntegrationAwaiter(ctx, s.ReadCheckpoint, 10*time.Millisecond)

treeSize := 258
var lastIndex uint64
for i := range treeSize {
idx, _, err := awaiter.Await(ctx, s.Add(ctx, tessera.NewEntry([]byte(fmt.Sprintf("TestGetTile %d", i)))))
if err != nil {
t.Fatalf("Failed to prep test with entry: %v", err)
}
if idx > lastIndex {
lastIndex = idx
}
}
if got, want := lastIndex, uint64(treeSize-1); got != want {
t.Fatalf("expected only newly created entries in database; tests are not hermetic (got %d, want %d)", got, want)
}

for _, test := range []struct {
name string
level, index, treeSize uint64
wantEntries int
wantNotFound bool
}{
{
name: "requested partial tile for a complete tile",
level: 0, index: 0, treeSize: 10,
wantEntries: 256,
wantNotFound: false,
},
{
name: "too small but that's ok",
level: 0, index: 1, treeSize: uint64(treeSize) - 1,
wantEntries: 2,
wantNotFound: false,
},
{
name: "just right",
level: 0, index: 1, treeSize: uint64(treeSize),
wantEntries: 2,
wantNotFound: false,
},
{
name: "too big",
level: 0, index: 1, treeSize: uint64(treeSize + 1),
wantNotFound: true,
},
{
name: "level 1 too small",
level: 1, index: 0, treeSize: uint64(treeSize - 1),
wantEntries: 1,
wantNotFound: false,
},
{
name: "level 1 just right",
level: 1, index: 0, treeSize: uint64(treeSize),
wantEntries: 1,
wantNotFound: false,
},
{
name: "level 1 too big",
level: 1, index: 0, treeSize: 550,
wantNotFound: true,
},
} {
t.Run(test.name, func(t *testing.T) {
tile, err := s.ReadTile(ctx, test.level, test.index, test.treeSize)
if err != nil {
if notFound, wantNotFound := os.IsNotExist(err), test.wantNotFound; notFound != wantNotFound {
t.Errorf("wantNotFound %v but notFound %v", wantNotFound, notFound)
}
if test.wantNotFound {
return
}
t.Errorf("got err: %v", err)
}
numEntries := len(tile) / sha256.Size
if got, want := numEntries, test.wantEntries; got != want {
t.Errorf("got %d entries, but want %d", got, want)
}
})
}
}

func TestReadMissingTile(t *testing.T) {
ctx := context.Background()
s := newTestMySQLStorage(t, ctx)
Expand All @@ -183,6 +272,10 @@ func TestReadMissingTile(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
tile, err := s.ReadTile(ctx, test.level, test.index, test.width)
if err != nil {
if os.IsNotExist(err) {
// this is success for this test
return
}
t.Errorf("got err: %v", err)
}
if tile != nil {
Expand Down Expand Up @@ -212,6 +305,10 @@ func TestReadMissingEntryBundle(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
entryBundle, err := s.ReadEntryBundle(ctx, test.index, test.index)
if err != nil {
if os.IsNotExist(err) {
// this is success for this test
return
}
t.Errorf("got err: %v", err)
}
if entryBundle != nil {
Expand Down Expand Up @@ -286,7 +383,7 @@ func TestTileRoundTrip(t *testing.T) {
}

tileLevel, tileIndex, _, nodeIndex := layout.NodeCoordsToTileAddress(0, entryIndex)
tileRaw, err := s.ReadTile(ctx, tileLevel, tileIndex, nodeIndex)
tileRaw, err := s.ReadTile(ctx, tileLevel, tileIndex, nodeIndex+1)
if err != nil {
t.Errorf("ReadTile got err: %v", err)
}
Expand Down Expand Up @@ -358,8 +455,12 @@ func TestEntryBundleRoundTrip(t *testing.T) {

func newTestMySQLStorage(t *testing.T, ctx context.Context) *mysql.Storage {
t.Helper()
initDatabaseSchema(ctx)

s, err := mysql.New(ctx, testDB, tessera.WithCheckpointSigner(noteSigner))
s, err := mysql.New(ctx, testDB,
tessera.WithCheckpointSigner(noteSigner),
tessera.WithCheckpointInterval(200*time.Millisecond),
tessera.WithBatching(128, 100*time.Millisecond))
if err != nil {
t.Errorf("Failed to create mysql.Storage: %v", err)
}
Expand Down

0 comments on commit c57fd57

Please sign in to comment.