diff --git a/array.go b/array.go index ec4d156c..f661e9c4 100644 --- a/array.go +++ b/array.go @@ -246,25 +246,29 @@ func newArrayDataSlabFromData( return nil, NewDecodingErrorf("data is too short for array data slab") } - version, flag := data[0], data[1] + h, err := newHeadFromData(data[:versionAndFlagSize]) + if err != nil { + return nil, NewDecodingError(err) + } - if getSlabArrayType(flag) != slabArrayData { + if h.getSlabArrayType() != slabArrayData { return nil, NewDecodingErrorf( - "data has invalid flag 0x%x, want 0x%x", - flag, - maskArrayData, + "data has invalid head 0x%x, want array data slab flag", + h[:], ) } - switch version { + data = data[versionAndFlagSize:] + + switch h.version() { case 0: - return newArrayDataSlabFromDataV0(id, data, decMode, decodeStorable, decodeTypeInfo) + return newArrayDataSlabFromDataV0(id, h, data, decMode, decodeStorable, decodeTypeInfo) case 1: - return newArrayDataSlabFromDataV1(id, data, decMode, decodeStorable, decodeTypeInfo) + return newArrayDataSlabFromDataV1(id, h, data, decMode, decodeStorable, decodeTypeInfo) default: - return nil, NewDecodingErrorf("unexpected version %d for array data slab", version) + return nil, NewDecodingErrorf("unexpected version %d for array data slab", h.version()) } } @@ -289,6 +293,7 @@ func newArrayDataSlabFromData( // See ArrayExtraData.Encode() for extra data section format. func newArrayDataSlabFromDataV0( id SlabID, + h head, data []byte, decMode cbor.DecMode, decodeStorable StorableDecoder, @@ -297,59 +302,51 @@ func newArrayDataSlabFromDataV0( *ArrayDataSlab, error, ) { - // Check minimum data length - if len(data) < versionAndFlagSize { - return nil, NewDecodingErrorf("data is too short for array data slab") - } - - isRootSlab := isRoot(data[1]) + var err error var extraData *ArrayExtraData // Check flag for extra data - if isRootSlab { + if h.isRoot() { // Decode extra data - var err error - extraData, data, err = newArrayExtraDataFromData(data[versionAndFlagSize:], decMode, decodeTypeInfo) + extraData, data, err = newArrayExtraDataFromData(data, decMode, decodeTypeInfo) if err != nil { // err is categorized already by newArrayExtraDataFromData. return nil, err } - } - minDataLength := arrayDataSlabPrefixSize - if isRootSlab { - minDataLength = arrayRootDataSlabPrefixSize - } + // Skip second head (version + flag) here because it is only present in root slab in version 0. + if len(data) < versionAndFlagSize { + return nil, NewDecodingErrorf("data is too short for array data slab") + } - // Check data length (after decoding extra data if present) - if len(data) < minDataLength { - return nil, NewDecodingErrorf("data is too short for array data slab") + data = data[versionAndFlagSize:] } var next SlabID - - var contentOffset int - - if !isRootSlab { + if !h.isRoot() { + // Check data length for next slab ID + if len(data) < slabIDSize { + return nil, NewDecodingErrorf("data is too short for array data slab") + } // Decode next slab ID - const nextSlabIDOffset = versionAndFlagSize - var err error - next, err = NewSlabIDFromRawBytes(data[nextSlabIDOffset:]) + next, err = NewSlabIDFromRawBytes(data) if err != nil { // error returned from NewSlabIDFromRawBytes is categorized already. return nil, err } - contentOffset = nextSlabIDOffset + slabIDSize + data = data[slabIDSize:] + } - } else { - contentOffset = versionAndFlagSize + // Check data length for array element head + if len(data) < arrayDataSlabElementHeadSize { + return nil, NewDecodingErrorf("data is too short for array data slab") } // Decode content (CBOR array) - cborDec := decMode.NewByteStreamDecoder(data[contentOffset:]) + cborDec := decMode.NewByteStreamDecoder(data) elemCount, err := cborDec.DecodeArrayHead() if err != nil { @@ -368,7 +365,7 @@ func newArrayDataSlabFromDataV0( // Compute slab size for version 1. slabSize := versionAndFlagSize + cborDec.NumBytesDecoded() - if !isRootSlab { + if !h.isRoot() { slabSize += slabIDSize } @@ -407,7 +404,8 @@ func newArrayDataSlabFromDataV0( // See ArrayExtraData.Encode() for extra data section format. func newArrayDataSlabFromDataV1( id SlabID, - data []byte, + h head, + data []byte, // data doesn't include head (first two bytes) decMode cbor.DecMode, decodeStorable StorableDecoder, decodeTypeInfo TypeInfoDecoder, @@ -415,29 +413,21 @@ func newArrayDataSlabFromDataV1( *ArrayDataSlab, error, ) { - // Check minimum data length - if len(data) < versionAndFlagSize { - return nil, NewDecodingErrorf("data is too short for array data slab") - } - - isRootSlab := isRoot(data[1]) - - data = data[versionAndFlagSize:] - var err error var extraData *ArrayExtraData var next SlabID - // Decode header - if isRootSlab { - // Decode extra data + // Decode extra data + if h.isRoot() { extraData, data, err = newArrayExtraDataFromData(data, decMode, decodeTypeInfo) if err != nil { // err is categorized already by newArrayExtraDataFromData. return nil, err } - } else { - // Decode next slab ID + } + + // Decode next slab ID + if h.hasNextSlabID() { next, err = NewSlabIDFromRawBytes(data) if err != nil { // error returned from NewSlabIDFromRawBytes is categorized already. @@ -477,7 +467,7 @@ func newArrayDataSlabFromDataV1( // Compute slab size for version 1. slabSize := versionAndFlagSize + cborDec.NumBytesDecoded() - if !isRootSlab { + if !h.isRoot() { slabSize += slabIDSize } @@ -516,23 +506,27 @@ func newArrayDataSlabFromDataV1( // See ArrayExtraData.Encode() for extra data section format. func (a *ArrayDataSlab) Encode(enc *Encoder) error { - flag := maskArrayData + const version = 1 - if a.hasPointer() { - flag = setHasPointers(flag) + h, err := newArraySlabHead(version, slabArrayData) + if err != nil { + return NewEncodingError(err) } - if a.extraData != nil { - flag = setRoot(flag) + if a.hasPointer() { + h.setHasPointers() } - // Encode version - enc.Scratch[0] = 1 + if a.next != SlabIDUndefined { + h.setHasNextSlabID() + } - // Encode flag - enc.Scratch[1] = flag + if a.extraData != nil { + h.setRoot() + } - _, err := enc.Write(enc.Scratch[:versionAndFlagSize]) + // Encode head (version + flag) + _, err = enc.Write(h[:]) if err != nil { return NewEncodingError(err) } @@ -545,8 +539,10 @@ func (a *ArrayDataSlab) Encode(enc *Encoder) error { // err is already categorized by ArrayExtraData.Encode(). return err } - } else { - // Encode next slab ID to scratch + } + + // Encode next slab ID + if a.next != SlabIDUndefined { n, err := a.next.ToRawBytes(enc.Scratch[:]) if err != nil { // Don't need to wrap because err is already categorized by SlabID.ToRawBytes(). @@ -1021,25 +1017,29 @@ func newArrayMetaDataSlabFromData( return nil, NewDecodingErrorf("data is too short for array metadata slab") } - version, flag := data[0], data[1] + h, err := newHeadFromData(data[:versionAndFlagSize]) + if err != nil { + return nil, NewDecodingError(err) + } - if getSlabArrayType(flag) != slabArrayMeta { + if h.getSlabArrayType() != slabArrayMeta { return nil, NewDecodingErrorf( - "data has invalid flag 0x%x, want 0x%x", - flag, - maskArrayMeta, + "data has invalid head 0x%x, want array metadata slab flag", + h[:], ) } - switch version { + data = data[versionAndFlagSize:] + + switch h.version() { case 0: - return newArrayMetaDataSlabFromDataV0(id, data, decMode, decodeTypeInfo) + return newArrayMetaDataSlabFromDataV0(id, h, data, decMode, decodeTypeInfo) case 1: - return newArrayMetaDataSlabFromDataV1(id, data, decMode, decodeTypeInfo) + return newArrayMetaDataSlabFromDataV1(id, h, data, decMode, decodeTypeInfo) default: - return nil, NewDecodingErrorf("unexpected version %d for array metadata slab", version) + return nil, NewDecodingErrorf("unexpected version %d for array metadata slab", h.version()) } } @@ -1064,6 +1064,7 @@ func newArrayMetaDataSlabFromData( // See ArrayExtraData.Encode() for extra data section format. func newArrayMetaDataSlabFromDataV0( id SlabID, + h head, data []byte, decMode cbor.DecMode, decodeTypeInfo TypeInfoDecoder, @@ -1073,41 +1074,42 @@ func newArrayMetaDataSlabFromDataV0( ) { // NOTE: the following encoded sizes are for version 0 only (changed in later version). const ( - // meta data slab prefix size: version (1 byte) + flag (1 byte) + child header count (2 bytes) - arrayMetaDataSlabPrefixSizeV0 = versionAndFlagSize + 2 + // meta data children array head size: 2 bytes + arrayMetaDataArrayHeadSizeV0 = 2 // slab header size: slab id (16 bytes) + count (4 bytes) + size (4 bytes) arraySlabHeaderSizeV0 = slabIDSize + 4 + 4 ) - // Check minimum data length - if len(data) < versionAndFlagSize { - return nil, NewDecodingErrorf("data is too short for array metadata slab") - } - - flag := data[1] - var err error var extraData *ArrayExtraData - if isRoot(flag) { - extraData, data, err = newArrayExtraDataFromData(data[versionAndFlagSize:], decMode, decodeTypeInfo) + if h.isRoot() { + extraData, data, err = newArrayExtraDataFromData(data, decMode, decodeTypeInfo) if err != nil { // Don't need to wrap because err is already categorized by newArrayExtraDataFromData(). return nil, err } + + // Skip second head (version + flag) here because it is only present in root slab in version 0. + if len(data) < versionAndFlagSize { + return nil, NewDecodingErrorf("data is too short for array data slab") + } + + data = data[versionAndFlagSize:] } // Check data length (after decoding extra data if present) - if len(data) < arrayMetaDataSlabPrefixSizeV0 { + if len(data) < arrayMetaDataArrayHeadSizeV0 { return nil, NewDecodingErrorf("data is too short for array metadata slab") } // Decode number of child headers - const childHeaderCountOffset = versionAndFlagSize - childHeaderCount := binary.BigEndian.Uint16(data[childHeaderCountOffset:]) + childHeaderCount := binary.BigEndian.Uint16(data) + + data = data[arrayMetaDataArrayHeadSizeV0:] - expectedDataLength := arrayMetaDataSlabPrefixSizeV0 + arraySlabHeaderSizeV0*int(childHeaderCount) + expectedDataLength := arraySlabHeaderSizeV0 * int(childHeaderCount) if len(data) != expectedDataLength { return nil, NewDecodingErrorf( "data has unexpected length %d, want %d", @@ -1120,7 +1122,7 @@ func newArrayMetaDataSlabFromDataV0( childrenHeaders := make([]ArraySlabHeader, childHeaderCount) childrenCountSum := make([]uint32, childHeaderCount) totalCount := uint32(0) - offset := childHeaderCountOffset + 2 + offset := 0 for i := 0; i < int(childHeaderCount); i++ { slabID, err := NewSlabIDFromRawBytes(data[offset:]) @@ -1185,6 +1187,7 @@ func newArrayMetaDataSlabFromDataV0( // See ArrayExtraData.Encode() for extra data section format. func newArrayMetaDataSlabFromDataV1( id SlabID, + h head, data []byte, decMode cbor.DecMode, decodeTypeInfo TypeInfoDecoder, @@ -1192,19 +1195,10 @@ func newArrayMetaDataSlabFromDataV1( *ArrayMetaDataSlab, error, ) { - // Check minimum data length - if len(data) < versionAndFlagSize { - return nil, NewDecodingErrorf("data is too short for array metadata slab") - } - - isRoot := isRoot(data[1]) - - data = data[versionAndFlagSize:] - var err error var extraData *ArrayExtraData - if isRoot { + if h.isRoot() { extraData, data, err = newArrayExtraDataFromData(data, decMode, decodeTypeInfo) if err != nil { // Don't need to wrap because err is already categorized by newArrayExtraDataFromData(). @@ -1307,20 +1301,19 @@ func newArrayMetaDataSlabFromDataV1( // See ArrayExtraData.Encode() for extra data section format. func (a *ArrayMetaDataSlab) Encode(enc *Encoder) error { - flag := maskArrayMeta + const version = 1 - if a.extraData != nil { - flag = setRoot(flag) + h, err := newArraySlabHead(version, slabArrayMeta) + if err != nil { + return NewEncodingError(err) } - // Encode version - enc.Scratch[0] = 1 - - // Encode flag - enc.Scratch[1] = flag + if a.extraData != nil { + h.setRoot() + } - // Write version and flag - _, err := enc.Write(enc.Scratch[:versionAndFlagSize]) + // Write head (version + flag) + _, err = enc.Write(h[:]) if err != nil { return NewEncodingError(err) } diff --git a/array_debug.go b/array_debug.go index dc4eeaa2..64cf0a07 100644 --- a/array_debug.go +++ b/array_debug.go @@ -446,17 +446,15 @@ func validArraySlabSerialization( } // Extra check: encoded data size == header.size - encodedExtraDataSize, err := getEncodedArrayExtraDataSize(slab.ExtraData(), cborEncMode) + encodedSlabSize, err := computeSlabSize(data) if err != nil { - // Don't need to wrap error as external error because err is already categorized by getEncodedArrayExtraDataSize(). + // Don't need to wrap error as external error because err is already categorized by computeSlabSize(). return err } - // Need to exclude extra data size from encoded data size. - encodedSlabSize := uint32(len(data) - encodedExtraDataSize) - if slab.Header().size != encodedSlabSize { - return NewFatalError(fmt.Errorf("slab %d encoded size %d != header.size %d (encoded extra data size %d)", - id, encodedSlabSize, slab.Header().size, encodedExtraDataSize)) + if slab.Header().size != uint32(encodedSlabSize) { + return NewFatalError(fmt.Errorf("slab %d encoded size %d != header.size %d", + id, encodedSlabSize, slab.Header().size)) } // Compare encoded data of original slab with encoded data of decoded slab @@ -640,25 +638,6 @@ func arrayExtraDataEqual(expected, actual *ArrayExtraData) error { return nil } -func getEncodedArrayExtraDataSize(extraData *ArrayExtraData, cborEncMode cbor.EncMode) (int, error) { - if extraData == nil { - return 0, nil - } - - var buf bytes.Buffer - enc := NewEncoder(&buf, cborEncMode) - - // Normally the flag shouldn't be 0. But in this case we just need the encoded data size - // so the content of the flag doesn't matter. - err := extraData.Encode(enc) - if err != nil { - // Don't need to wrap error as external error because err is already categorized by ArrayExtraData.Encode(). - return 0, err - } - - return len(buf.Bytes()), nil -} - func ValidValueSerialization( value Value, cborDecMode cbor.DecMode, @@ -690,3 +669,47 @@ func ValidValueSerialization( } return nil } + +func computeSlabSize(data []byte) (int, error) { + if len(data) < versionAndFlagSize { + return 0, NewDecodingError(fmt.Errorf("data is too short")) + } + + h, err := newHeadFromData(data[:versionAndFlagSize]) + if err != nil { + return 0, NewDecodingError(err) + } + + slabExtraDataSize, err := getExtraDataSize(h, data[versionAndFlagSize:]) + if err != nil { + return 0, err + } + + // Computed slab size (slab header size): + // - excludes slab extra data size + // - adds next slab ID for non-root data slab if not encoded + size := len(data) - slabExtraDataSize + + isDataSlab := h.getSlabArrayType() == slabArrayData || + h.getSlabMapType() == slabMapData || + h.getSlabMapType() == slabMapCollisionGroup + + if !h.isRoot() && isDataSlab && !h.hasNextSlabID() { + size += slabIDSize + } + + return size, nil +} + +func getExtraDataSize(h head, data []byte) (int, error) { + if h.isRoot() { + dec := cbor.NewStreamDecoder(bytes.NewBuffer(data)) + b, err := dec.DecodeRawBytes() + if err != nil { + return 0, NewDecodingError(err) + } + return len(b), nil + } + + return 0, nil +} diff --git a/array_test.go b/array_test.go index b5eb3614..ad08ac9d 100644 --- a/array_test.go +++ b/array_test.go @@ -1753,7 +1753,7 @@ func TestArrayEncodeDecode(t *testing.T) { expectedData := []byte{ // version - 0x01, + 0x10, // flag 0x80, @@ -1797,7 +1797,7 @@ func TestArrayEncodeDecode(t *testing.T) { expectedData := []byte{ // version - 0x01, + 0x10, // flag 0x80, @@ -1872,7 +1872,7 @@ func TestArrayEncodeDecode(t *testing.T) { // (metadata slab) headers: [{id:2 size:228 count:9} {id:3 size:270 count:11} ] id1: { // version - 0x01, + 0x10, // flag 0x81, @@ -1899,7 +1899,7 @@ func TestArrayEncodeDecode(t *testing.T) { // (data slab) next: 3, data: [aaaaaaaaaaaaaaaaaaaaaa ... aaaaaaaaaaaaaaaaaaaaaa] id2: { // version - 0x01, + 0x12, // array data slab flag 0x00, // next slab id @@ -1921,11 +1921,9 @@ func TestArrayEncodeDecode(t *testing.T) { // (data slab) next: 0, data: [aaaaaaaaaaaaaaaaaaaaaa ... SlabID(...)] id3: { // version - 0x01, + 0x10, // array data slab flag 0x40, - // next slab id - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // CBOR encoded array head (fixed size 3 byte) 0x99, 0x00, 0x0b, // CBOR encoded array elements @@ -1945,7 +1943,7 @@ func TestArrayEncodeDecode(t *testing.T) { // (data slab) next: 0, data: [0] id4: { // version - 0x01, + 0x10, // extra data flag 0x80, diff --git a/basicarray.go b/basicarray.go index 7e806847..b5267e4c 100644 --- a/basicarray.go +++ b/basicarray.go @@ -62,11 +62,16 @@ func newBasicArrayDataSlabFromData( return nil, NewDecodingErrorf("data is too short for basic array slab") } + h, err := newHeadFromData(data[:versionAndFlagSize]) + if err != nil { + return nil, NewDecodingError(err) + } + // Check flag - if getSlabArrayType(data[1]) != slabBasicArray { + if h.getSlabArrayType() != slabBasicArray { return nil, NewDecodingErrorf( - "data has invalid flag 0x%x, want 0x%x", - data[0], + "data has invalid head 0x%x, want 0x%x", + h[:], maskBasicArray, ) } diff --git a/encode.go b/encode.go index fb84d49a..c88fa3a8 100644 --- a/encode.go +++ b/encode.go @@ -62,14 +62,18 @@ func DecodeSlab( return nil, NewDecodingErrorf("data is too short") } - flag := data[1] + h, err := newHeadFromData(data[:versionAndFlagSize]) + if err != nil { + return nil, NewDecodingError(err) + } - dataType := getSlabType(flag) - switch dataType { + switch h.getSlabType() { case slabArray: - switch arrayDataType := getSlabArrayType(flag); arrayDataType { + arrayDataType := h.getSlabArrayType() + + switch arrayDataType { case slabArrayData: return newArrayDataSlabFromData(id, data, decMode, decodeStorable, decodeTypeInfo) case slabArrayMeta: @@ -77,12 +81,14 @@ func DecodeSlab( case slabBasicArray: return newBasicArrayDataSlabFromData(id, data, decMode, decodeStorable) default: - return nil, NewDecodingErrorf("data has invalid flag 0x%x", flag) + return nil, NewDecodingErrorf("data has invalid head 0x%x", h[:]) } case slabMap: - switch mapDataType := getSlabMapType(flag); mapDataType { + mapDataType := h.getSlabMapType() + + switch mapDataType { case slabMapData: return newMapDataSlabFromData(id, data, decMode, decodeStorable, decodeTypeInfo) case slabMapMeta: @@ -90,7 +96,7 @@ func DecodeSlab( case slabMapCollisionGroup: return newMapDataSlabFromData(id, data, decMode, decodeStorable, decodeTypeInfo) default: - return nil, NewDecodingErrorf("data has invalid flag 0x%x", flag) + return nil, NewDecodingErrorf("data has invalid head 0x%x", h[:]) } case slabStorable: @@ -106,7 +112,7 @@ func DecodeSlab( }, nil default: - return nil, NewDecodingErrorf("data has invalid flag 0x%x", flag) + return nil, NewDecodingErrorf("data has invalid head 0x%x", h[:]) } } diff --git a/flag.go b/flag.go index 59b76d3d..44230072 100644 --- a/flag.go +++ b/flag.go @@ -18,6 +18,10 @@ package atree +import ( + "fmt" +) + type slabType int const ( @@ -47,6 +51,16 @@ const ( slabMapCollisionGroup ) +// Version and flag masks for the 1st byte of encoded slab. +// Flags in this group are only for v1 and above. +const ( + maskVersion byte = 0b1111_0000 + maskHasNextSlabID byte = 0b0000_0010 // This flag is only relevant for data slab. + maskHasInlinedSlabs byte = 0b0000_0001 +) + +// Flag masks for the 2nd byte of encoded slab. +// Flags in this group are available for all versions. const ( // Slab flags: 3 high bits maskSlabRoot byte = 0b100_00000 @@ -69,31 +83,136 @@ const ( maskStorable byte = 0b000_11111 ) -func setRoot(f byte) byte { - return f | maskSlabRoot +const ( + maxVersion = 0b0000_1111 +) + +type head [2]byte + +// newArraySlabHead returns an array slab head of given version and slab type. +func newArraySlabHead(version byte, t slabArrayType) (*head, error) { + if version > maxVersion { + return nil, fmt.Errorf("encoding version must be less than %d, got %d", maxVersion+1, version) + } + + var h head + + h[0] = version << 4 + + switch t { + case slabArrayData: + h[1] = maskArrayData + + case slabArrayMeta: + h[1] = maskArrayMeta + + case slabBasicArray: + h[1] = maskBasicArray + + default: + return nil, fmt.Errorf("unsupported array slab type %d", t) + } + + return &h, nil +} + +// newMapSlabHead returns a map slab head of given version and slab type. +func newMapSlabHead(version byte, t slabMapType) (*head, error) { + if version > maxVersion { + return nil, fmt.Errorf("encoding version must be less than %d, got %d", maxVersion+1, version) + } + + var h head + + h[0] = version << 4 + + switch t { + case slabMapData: + h[1] = maskMapData + + case slabMapMeta: + h[1] = maskMapMeta + + case slabMapCollisionGroup: + h[1] = maskCollisionGroup + + default: + return nil, fmt.Errorf("unsupported map slab type %d", t) + } + + return &h, nil +} + +// newStorableSlabHead returns a storable slab head of given version. +func newStorableSlabHead(version byte) (*head, error) { + if version > maxVersion { + return nil, fmt.Errorf("encoding version must be less than %d, got %d", maxVersion+1, version) + } + + var h head + h[0] = version << 4 + h[1] = maskStorable + return &h, nil +} + +// newHeadFromData returns a head with given data. +func newHeadFromData(data []byte) (head, error) { + if len(data) != 2 { + return head{}, fmt.Errorf("head must be 2 bytes, got %d bytes", len(data)) + } + + return head{data[0], data[1]}, nil +} + +func (h *head) version() byte { + return (h[0] & maskVersion) >> 4 } -func setHasPointers(f byte) byte { - return f | maskSlabHasPointers +func (h *head) isRoot() bool { + return h[1]&maskSlabRoot > 0 } -func setNoSizeLimit(f byte) byte { - return f | maskSlabAnySize +func (h *head) setRoot() { + h[1] |= maskSlabRoot } -func isRoot(f byte) bool { - return f&maskSlabRoot > 0 +func (h *head) hasPointers() bool { + return h[1]&maskSlabHasPointers > 0 } -func hasPointers(f byte) bool { - return f&maskSlabHasPointers > 0 +func (h *head) setHasPointers() { + h[1] |= maskSlabHasPointers } -func hasSizeLimit(f byte) bool { - return f&maskSlabAnySize == 0 +func (h *head) hasSizeLimit() bool { + return h[1]&maskSlabAnySize == 0 } -func getSlabType(f byte) slabType { +func (h *head) setNoSizeLimit() { + h[1] |= maskSlabAnySize +} + +func (h *head) hasInlinedSlabs() bool { + return h[0]&maskHasInlinedSlabs > 0 +} + +func (h *head) setHasInlinedSlabs() { + h[0] |= maskHasInlinedSlabs +} + +func (h *head) hasNextSlabID() bool { + if h.version() == 0 { + return !h.isRoot() + } + return h[0]&maskHasNextSlabID > 0 +} + +func (h *head) setHasNextSlabID() { + h[0] |= maskHasNextSlabID +} + +func (h head) getSlabType() slabType { + f := h[1] // Extract 4th and 5th bits for slab type. dataType := (f & byte(0b000_11000)) >> 3 switch dataType { @@ -111,11 +230,13 @@ func getSlabType(f byte) slabType { } } -func getSlabArrayType(f byte) slabArrayType { - if getSlabType(f) != slabArray { +func (h head) getSlabArrayType() slabArrayType { + if h.getSlabType() != slabArray { return slabArrayUndefined } + f := h[1] + // Extract 3 low bits for slab array type. dataType := (f & byte(0b000_00111)) switch dataType { @@ -132,11 +253,13 @@ func getSlabArrayType(f byte) slabArrayType { } } -func getSlabMapType(f byte) slabMapType { - if getSlabType(f) != slabMap { +func (h head) getSlabMapType() slabMapType { + if h.getSlabType() != slabMap { return slabMapUndefined } + f := h[1] + // Extract 3 low bits for slab map type. dataType := (f & byte(0b000_00111)) switch dataType { diff --git a/flag_test.go b/flag_test.go index d7a05817..e4a81564 100644 --- a/flag_test.go +++ b/flag_test.go @@ -25,76 +25,336 @@ import ( ) func TestFlagIsRoot(t *testing.T) { - for i := 0; i <= 255; i++ { - if i >= 0x80 { - require.True(t, isRoot(byte(i))) - } else { - require.False(t, isRoot(byte(i))) - } + testCases := []struct { + name string + h head + }{ + {"v0", head([2]byte{})}, + {"v1", head([2]byte{1 << 4, 0x0})}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for i := 0; i <= 255; i++ { + tc.h[1] = byte(i) + if i >= 1<<7 { + require.True(t, tc.h.isRoot()) + } else { + require.False(t, tc.h.isRoot()) + } + } + }) } } -func TestFlagSetRoot(t *testing.T) { +func TestFlagSetRootV1(t *testing.T) { + var h head + h[0] = 1 << 4 // version 1 + for i := 0; i <= 255; i++ { - require.True(t, isRoot(setRoot(byte(i)))) + h[1] = byte(i) + h.setRoot() + require.True(t, h.isRoot()) } } func TestFlagHasPointers(t *testing.T) { - for i := 0; i <= 255; i++ { - if byte(i)&maskSlabHasPointers != 0 { - require.True(t, hasPointers(byte(i))) - } else { - require.False(t, hasPointers(byte(i))) - } + testCases := []struct { + name string + h head + }{ + {"v0", head([2]byte{})}, + {"v1", head([2]byte{1 << 4, 0x0})}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for i := 0; i <= 255; i++ { + tc.h[1] = byte(i) + + if byte(i)&maskSlabHasPointers != 0 { + require.True(t, tc.h.hasPointers()) + } else { + require.False(t, tc.h.hasPointers()) + } + } + }) } } -func TestFlagSetHasPointers(t *testing.T) { +func TestFlagSetHasPointersV1(t *testing.T) { + var h head + h[0] = 1 << 4 // version 1 + for i := 0; i <= 255; i++ { - require.True(t, hasPointers(setHasPointers(byte(i)))) + h[1] = byte(i) + h.setHasPointers() + + require.True(t, h.hasPointers()) } } func TestFlagHasSizeLimit(t *testing.T) { + testCases := []struct { + name string + h head + }{ + {"v0", head([2]byte{})}, + {"v1", head([2]byte{1 << 4, 0x0})}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for i := 0; i <= 255; i++ { + tc.h[1] = byte(i) + + if byte(i)&maskSlabAnySize == 0 { + require.True(t, tc.h.hasSizeLimit()) + } else { + require.False(t, tc.h.hasSizeLimit()) + } + } + }) + } +} + +func TestFlagSetNoSizeLimitV1(t *testing.T) { + var h head + h[0] = 1 << 4 // version 1 + for i := 0; i <= 255; i++ { - if byte(i)&maskSlabAnySize == 0 { - require.True(t, hasSizeLimit(byte(i))) - } else { - require.False(t, hasSizeLimit(byte(i))) + h[1] = byte(i) + + h.setNoSizeLimit() + require.False(t, h.hasSizeLimit()) + } +} + +func TestFlagHasNextSlabID(t *testing.T) { + var h head + h[0] = 1 << 4 // v1 + + t.Run("has", func(t *testing.T) { + // Flags in the first byte + for i := 0; i < 32; i++ { + h[0] |= byte(i) + h[0] |= maskHasNextSlabID + + // Flags in the second byte + for j := 0; j <= 255; j++ { + h[1] = byte(j) + require.True(t, h.hasNextSlabID()) + } + } + }) + + t.Run("doesn't have", func(t *testing.T) { + // Flags in the first byte + for i := 0; i < 32; i++ { + h[0] |= byte(i) + h[0] &= ^maskHasNextSlabID + + // Flags in the second byte + for j := 0; j <= 255; j++ { + h[1] = byte(j) + require.False(t, h.hasNextSlabID()) + } + } + }) +} + +func TestFlagSetHasNextSlabIDV1(t *testing.T) { + var h head + h[0] = 1 << 4 // version 1 + + // Flags in the first byte + for i := 0; i < 32; i++ { + h[0] |= byte(i) + + // Flags in the second byte + for i := 0; i <= 255; i++ { + h[1] = byte(i) + + h.setHasNextSlabID() + require.True(t, h.hasNextSlabID()) } } } -func TestFlagSetNoSizeLimit(t *testing.T) { - for i := 0; i <= 255; i++ { - f := setNoSizeLimit(byte(i)) - require.False(t, hasSizeLimit(f)) +func TestFlagHasInlinedSlabs(t *testing.T) { + var h head + h[0] = 1 << 4 // v1 + + t.Run("has", func(t *testing.T) { + // Flags in the first byte + for i := 0; i < 32; i++ { + h[0] |= byte(i) + h[0] |= maskHasInlinedSlabs + + // Flags in the second byte + for j := 0; j <= 255; j++ { + h[1] = byte(j) + require.True(t, h.hasInlinedSlabs()) + } + } + }) + + t.Run("doesn't have", func(t *testing.T) { + // Flags in the first byte + for i := 0; i < 32; i++ { + h[0] |= byte(i) + h[0] &= ^maskHasInlinedSlabs + + // Flags in the second byte + for j := 0; j <= 255; j++ { + h[1] = byte(j) + require.False(t, h.hasInlinedSlabs()) + } + } + }) +} + +func TestFlagSetHasInlinedSlabsV1(t *testing.T) { + var h head + h[0] = 1 << 4 // version 1 + + // Flags in the first byte + for i := 0; i < 32; i++ { + h[0] |= byte(i) + + // Flags in the second byte + for i := 0; i <= 255; i++ { + h[1] = byte(i) + + h.setHasInlinedSlabs() + require.True(t, h.hasInlinedSlabs()) + } } } func TestFlagGetSlabType(t *testing.T) { - for i := 0; i <= 255; i++ { - arrayFlag := byte(i) & 0b111_00111 - mapFlag := arrayFlag | 0b000_01000 - storableFlag := mapFlag | 0b000_11111 + testCases := []struct { + name string + h head + }{ + {"v0", head([2]byte{})}, + {"v1", head([2]byte{1 << 4, 0x0})}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for i := 0; i <= 255; i++ { + arrayFlag := byte(i) & 0b111_00111 + tc.h[1] = arrayFlag + require.Equal(t, slabArray, tc.h.getSlabType()) - require.Equal(t, slabArray, getSlabType(arrayFlag)) - require.Equal(t, slabMap, getSlabType(mapFlag)) - require.Equal(t, slabStorable, getSlabType(storableFlag)) + mapFlag := arrayFlag | 0b000_01000 + tc.h[1] = mapFlag + require.Equal(t, slabMap, tc.h.getSlabType()) + + storableFlag := arrayFlag | 0b000_11111 + tc.h[1] = storableFlag + require.Equal(t, slabStorable, tc.h.getSlabType()) + } + }) } } func TestFlagGetSlabArrayType(t *testing.T) { - for i := 0; i <= 255; i++ { - arrayDataFlag := byte(i) & 0b111_00000 - arrayMetaFlag := arrayDataFlag | 0b000_00001 - arrayLargeImmutableArrayFlag := arrayDataFlag | 0b000_00010 - basicArrayFlag := arrayDataFlag | 0b000_00011 + testCases := []struct { + name string + h head + }{ + {"v0", head([2]byte{})}, + {"v1", head([2]byte{1 << 4, 0x0})}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for i := 0; i <= 255; i++ { + arrayDataFlag := byte(i) & 0b111_00000 + tc.h[1] = arrayDataFlag + require.Equal(t, slabArrayData, tc.h.getSlabArrayType()) + + arrayMetaFlag := arrayDataFlag | 0b000_00001 + tc.h[1] = arrayMetaFlag + require.Equal(t, slabArrayMeta, tc.h.getSlabArrayType()) + + arrayLargeImmutableArrayFlag := arrayDataFlag | 0b000_00010 + tc.h[1] = arrayLargeImmutableArrayFlag + require.Equal(t, slabLargeImmutableArray, tc.h.getSlabArrayType()) - require.Equal(t, slabArrayData, getSlabArrayType(arrayDataFlag)) - require.Equal(t, slabArrayMeta, getSlabArrayType(arrayMetaFlag)) - require.Equal(t, slabLargeImmutableArray, getSlabArrayType(arrayLargeImmutableArrayFlag)) - require.Equal(t, slabBasicArray, getSlabArrayType(basicArrayFlag)) + basicArrayFlag := arrayDataFlag | 0b000_00011 + tc.h[1] = basicArrayFlag + require.Equal(t, slabBasicArray, tc.h.getSlabArrayType()) + } + }) } } + +func TestFlagGetSlabMapType(t *testing.T) { + testCases := []struct { + name string + h head + }{ + {"v0", head([2]byte{})}, + {"v1", head([2]byte{1 << 4, 0x0})}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for i := 0; i <= 255; i++ { + b := byte(i) + b |= 0b000_01000 // turn on map flag + b &= 0b111_01111 // turn off storable flag + + mapDataFlag := b & 0b111_11000 + tc.h[1] = mapDataFlag + require.Equal(t, slabMapData, tc.h.getSlabMapType()) + + mapMetaFlag := mapDataFlag | 0b000_00001 + tc.h[1] = mapMetaFlag + require.Equal(t, slabMapMeta, tc.h.getSlabMapType()) + + mapLargeImmutableArrayFlag := mapDataFlag | 0b000_00010 + tc.h[1] = mapLargeImmutableArrayFlag + require.Equal(t, slabMapLargeEntry, tc.h.getSlabMapType()) + + collisionGroupFlag := mapDataFlag | 0b000_00011 + tc.h[1] = collisionGroupFlag + require.Equal(t, slabMapCollisionGroup, tc.h.getSlabMapType()) + } + }) + } +} + +func TestVersion(t *testing.T) { + t.Run("v0", func(t *testing.T) { + const expectedVersion = byte(0) + + var h head + // Flags in the second byte + for i := 0; i <= 255; i++ { + h[1] = byte(i) + require.Equal(t, expectedVersion, h.version()) + } + }) + + t.Run("v1", func(t *testing.T) { + const expectedVersion = byte(1) + + var h head + h[0] = 0x10 + + // Flags in the first byte + for i := 0; i < 32; i++ { + h[0] |= byte(i) + + // Flags in the second byte + for j := 0; j <= 255; j++ { + h[1] = byte(j) + require.Equal(t, expectedVersion, h.version()) + } + } + }) +} diff --git a/map.go b/map.go index 789e1e7b..7d9eabf8 100644 --- a/map.go +++ b/map.go @@ -2048,28 +2048,31 @@ func newMapDataSlabFromData( return nil, NewDecodingErrorf("data is too short for map data slab") } - version, flag := data[0], data[1] + h, err := newHeadFromData(data[:versionAndFlagSize]) + if err != nil { + return nil, NewDecodingError(err) + } - mapType := getSlabMapType(flag) + mapType := h.getSlabMapType() if mapType != slabMapData && mapType != slabMapCollisionGroup { return nil, NewDecodingErrorf( - "data has invalid flag 0x%x, want 0x%x or 0x%x", - flag, - maskMapData, - maskCollisionGroup, + "data has invalid head 0x%x, want map data slab flag or map collision group flag", + h[:], ) } - switch version { + data = data[versionAndFlagSize:] + + switch h.version() { case 0: - return newMapDataSlabFromDataV0(id, data, decMode, decodeStorable, decodeTypeInfo) + return newMapDataSlabFromDataV0(id, h, data, decMode, decodeStorable, decodeTypeInfo) case 1: - return newMapDataSlabFromDataV1(id, data, decMode, decodeStorable, decodeTypeInfo) + return newMapDataSlabFromDataV1(id, h, data, decMode, decodeStorable, decodeTypeInfo) default: - return nil, NewDecodingErrorf("unexpected version %d for map data slab", version) + return nil, NewDecodingErrorf("unexpected version %d for map data slab", h.version()) } } @@ -2095,6 +2098,7 @@ func newMapDataSlabFromData( // See hkeyElements.Encode() and singleElements.Encode() for elements section format. func newMapDataSlabFromDataV0( id SlabID, + h head, data []byte, decMode cbor.DecMode, decodeStorable StorableDecoder, @@ -2103,60 +2107,46 @@ func newMapDataSlabFromDataV0( *MapDataSlab, error, ) { - // Check minimum data length - if len(data) < versionAndFlagSize { - return nil, NewDecodingErrorf("data is too short for map data slab") - } - - flag := data[1] - mapType := getSlabMapType(flag) - isRootSlab := isRoot(flag) - var err error var extraData *MapExtraData - if isRootSlab { + if h.isRoot() { // Decode extra data - extraData, data, err = newMapExtraDataFromData(data[versionAndFlagSize:], decMode, decodeTypeInfo) + extraData, data, err = newMapExtraDataFromData(data, decMode, decodeTypeInfo) if err != nil { // Don't need to wrap error as external error because err is already categorized by newMapExtraDataFromData(). return nil, err } - } - minDataLength := mapDataSlabPrefixSize - if isRootSlab { - minDataLength = mapRootDataSlabPrefixSize - } + // Skip second head (version + flag) here because it is only present in root slab in version 0. + if len(data) < versionAndFlagSize { + return nil, NewDecodingErrorf("data is too short for map data slab") + } - // Check data length (after decoding extra data if present) - if len(data) < minDataLength { - return nil, NewDecodingErrorf("data is too short for map data slab") + data = data[versionAndFlagSize:] } var next SlabID - var contentOffset int - - if !isRootSlab { + if !h.isRoot() { + // Check data length for next slab ID + if len(data) < slabIDSize { + return nil, NewDecodingErrorf("data is too short for map data slab") + } // Decode next slab ID - const nextSlabIDOffset = versionAndFlagSize var err error - next, err = NewSlabIDFromRawBytes(data[nextSlabIDOffset:]) + next, err = NewSlabIDFromRawBytes(data) if err != nil { // Don't need to wrap error as external error because err is already categorized by NewSlabIDFromRawBytes(). return nil, err } - contentOffset = nextSlabIDOffset + slabIDSize - - } else { - contentOffset = versionAndFlagSize + data = data[slabIDSize:] } // Decode elements - cborDec := decMode.NewByteStreamDecoder(data[contentOffset:]) + cborDec := decMode.NewByteStreamDecoder(data) elements, err := newElementsFromData(cborDec, decodeStorable) if err != nil { // Don't need to wrap error as external error because err is already categorized by newElementsFromDataV0(). @@ -2165,7 +2155,7 @@ func newMapDataSlabFromDataV0( // Compute slab size for version 1. slabSize := versionAndFlagSize + elements.Size() - if !isRootSlab { + if !h.isRoot() { slabSize += slabIDSize } @@ -2180,8 +2170,8 @@ func newMapDataSlabFromDataV0( header: header, elements: elements, extraData: extraData, - anySize: !hasSizeLimit(flag), - collisionGroup: mapType == slabMapCollisionGroup, + anySize: !h.hasSizeLimit(), + collisionGroup: h.getSlabMapType() == slabMapCollisionGroup, }, nil } @@ -2207,6 +2197,7 @@ func newMapDataSlabFromDataV0( // See hkeyElements.Encode() and singleElements.Encode() for elements section format. func newMapDataSlabFromDataV1( id SlabID, + h head, data []byte, decMode cbor.DecMode, decodeStorable StorableDecoder, @@ -2215,35 +2206,25 @@ func newMapDataSlabFromDataV1( *MapDataSlab, error, ) { - // Check minimum data length - if len(data) < versionAndFlagSize { - return nil, NewDecodingErrorf("data is too short for map data slab") - } - - flag := data[1] - mapType := getSlabMapType(flag) - isRootSlab := isRoot(flag) - - data = data[versionAndFlagSize:] - var err error var extraData *MapExtraData var next SlabID - // Decode header - if isRootSlab { - // Decode extra data + // Decode extra data + if h.isRoot() { extraData, data, err = newMapExtraDataFromData(data, decMode, decodeTypeInfo) if err != nil { // Don't need to wrap error as external error because err is already categorized by newMapExtraDataFromData(). return nil, err } - } else { + } + + // Decode next slab ID + if h.hasNextSlabID() { if len(data) < slabIDSize { return nil, NewDecodingErrorf("data is too short for map data slab") } - // Decode next slab ID next, err = NewSlabIDFromRawBytes(data) if err != nil { // Don't need to wrap error as external error because err is already categorized by NewSlabIDFromRawBytes(). @@ -2263,7 +2244,7 @@ func newMapDataSlabFromDataV1( // Compute slab size. slabSize := versionAndFlagSize + elements.Size() - if !isRootSlab { + if !h.isRoot() { slabSize += slabIDSize } @@ -2278,8 +2259,8 @@ func newMapDataSlabFromDataV1( header: header, elements: elements, extraData: extraData, - anySize: !hasSizeLimit(flag), - collisionGroup: mapType == slabMapCollisionGroup, + anySize: !h.hasSizeLimit(), + collisionGroup: h.getSlabMapType() == slabMapCollisionGroup, }, nil } @@ -2307,45 +2288,49 @@ func (m *MapDataSlab) Encode(enc *Encoder) error { const version = 1 - flag := maskMapData - + slabType := slabMapData if m.collisionGroup { - flag = maskCollisionGroup + slabType = slabMapCollisionGroup + } + + h, err := newMapSlabHead(version, slabType) + if err != nil { + return NewEncodingError(err) } if m.hasPointer() { - flag = setHasPointers(flag) + h.setHasPointers() + } + + if m.next != SlabIDUndefined { + h.setHasNextSlabID() } if m.anySize { - flag = setNoSizeLimit(flag) + h.setNoSizeLimit() } if m.extraData != nil { - flag = setRoot(flag) + h.setRoot() } - // Encode version - enc.Scratch[0] = version - - // Encode flag - enc.Scratch[1] = flag - - _, err := enc.Write(enc.Scratch[:versionAndFlagSize]) + // Write head (version + flag) + _, err = enc.Write(h[:]) if err != nil { return NewEncodingError(err) } - // Encode header + // Encode extra data if m.extraData != nil { - // Encode extra data err = m.extraData.Encode(enc) if err != nil { // Don't need to wrap error as external error because err is already categorized by MapExtraData.Encode(). return err } - } else { - // Encode next slab ID to scratch + } + + // Encode next slab ID + if m.next != SlabIDUndefined { n, err := m.next.ToRawBytes(enc.Scratch[:]) if err != nil { // Don't need to wrap error as external error because err is already categorized by SlabID.ToRawBytes(). @@ -2705,25 +2690,29 @@ func newMapMetaDataSlabFromData( return nil, NewDecodingErrorf("data is too short for map metadata slab") } - version, flag := data[0], data[1] + h, err := newHeadFromData(data[:versionAndFlagSize]) + if err != nil { + return nil, NewDecodingError(err) + } - if getSlabMapType(flag) != slabMapMeta { + if h.getSlabMapType() != slabMapMeta { return nil, NewDecodingErrorf( - "data has invalid flag 0x%x, want 0x%x", - flag, - maskMapMeta, + "data has invalid head 0x%x, want map metadata slab flag", + h[:], ) } - switch version { + data = data[versionAndFlagSize:] + + switch h.version() { case 0: - return newMapMetaDataSlabFromDataV0(id, data, decMode, decodeTypeInfo) + return newMapMetaDataSlabFromDataV0(id, h, data, decMode, decodeTypeInfo) case 1: - return newMapMetaDataSlabFromDataV1(id, data, decMode, decodeTypeInfo) + return newMapMetaDataSlabFromDataV1(id, h, data, decMode, decodeTypeInfo) default: - return nil, NewDecodingErrorf("unexpected version %d for map metadata slab", version) + return nil, NewDecodingErrorf("unexpected version %d for map metadata slab", h.version()) } } @@ -2748,43 +2737,46 @@ func newMapMetaDataSlabFromData( // See MapExtraData.Encode() for extra data section format. func newMapMetaDataSlabFromDataV0( id SlabID, + h head, data []byte, decMode cbor.DecMode, decodeTypeInfo TypeInfoDecoder, ) (*MapMetaDataSlab, error) { const ( - mapMetaDataSlabPrefixSizeV0 = versionAndFlagSize + 2 - mapSlabHeaderSizeV0 = slabIDSize + 4 + digestSize + mapMetaDataArrayHeadSizeV0 = 2 + mapSlabHeaderSizeV0 = slabIDSize + 4 + digestSize ) - // Check minimum data length - if len(data) < versionAndFlagSize { - return nil, NewDecodingErrorf("data is too short for map metadata slab") - } - + var err error var extraData *MapExtraData // Check flag for extra data - if isRoot(data[1]) { + if h.isRoot() { // Decode extra data - var err error - extraData, data, err = newMapExtraDataFromData(data[versionAndFlagSize:], decMode, decodeTypeInfo) + extraData, data, err = newMapExtraDataFromData(data, decMode, decodeTypeInfo) if err != nil { // Don't need to wrap error as external error because err is already categorized by newMapExtraDataFromData(). return nil, err } + + // Skip second head (version + flag) here because it is only present in root slab in version 0. + if len(data) < versionAndFlagSize { + return nil, NewDecodingErrorf("data is too short for array data slab") + } + + data = data[versionAndFlagSize:] } // Check data length (after decoding extra data if present) - if len(data) < mapMetaDataSlabPrefixSizeV0 { + if len(data) < mapMetaDataArrayHeadSizeV0 { return nil, NewDecodingErrorf("data is too short for map metadata slab") } // Decode number of child headers - const childHeaderCountOffset = versionAndFlagSize - childHeaderCount := binary.BigEndian.Uint16(data[childHeaderCountOffset:]) + childHeaderCount := binary.BigEndian.Uint16(data) + data = data[mapMetaDataArrayHeadSizeV0:] - expectedDataLength := mapMetaDataSlabPrefixSizeV0 + mapSlabHeaderSizeV0*int(childHeaderCount) + expectedDataLength := mapSlabHeaderSizeV0 * int(childHeaderCount) if len(data) != expectedDataLength { return nil, NewDecodingErrorf( "data has unexpected length %d, want %d", @@ -2795,7 +2787,7 @@ func newMapMetaDataSlabFromDataV0( // Decode child headers childrenHeaders := make([]MapSlabHeader, childHeaderCount) - offset := childHeaderCountOffset + 2 + offset := 0 for i := 0; i < int(childHeaderCount); i++ { slabID, err := NewSlabIDFromRawBytes(data[offset:]) @@ -2861,23 +2853,16 @@ func newMapMetaDataSlabFromDataV0( // See MapExtraData.Encode() for extra data section format. func newMapMetaDataSlabFromDataV1( id SlabID, + h head, data []byte, decMode cbor.DecMode, decodeTypeInfo TypeInfoDecoder, ) (*MapMetaDataSlab, error) { - // Check minimum data length - if len(data) < versionAndFlagSize { - return nil, NewDecodingErrorf("data is too short for map metadata slab") - } - - isRoot := isRoot(data[1]) - - data = data[versionAndFlagSize:] var err error var extraData *MapExtraData - if isRoot { + if h.isRoot() { // Decode extra data extraData, data, err = newMapExtraDataFromData(data, decMode, decodeTypeInfo) if err != nil { @@ -2980,20 +2965,17 @@ func (m *MapMetaDataSlab) Encode(enc *Encoder) error { const version = 1 - flag := maskMapMeta + h, err := newMapSlabHead(version, slabMapMeta) + if err != nil { + return NewEncodingError(err) + } if m.extraData != nil { - flag = setRoot(flag) + h.setRoot() } - // Encode version - enc.Scratch[0] = version - - // Encode flag - enc.Scratch[1] = flag - - // Write version and flag - _, err := enc.Write(enc.Scratch[:versionAndFlagSize]) + // Write head (version and flag) + _, err = enc.Write(h[:]) if err != nil { return NewEncodingError(err) } diff --git a/map_debug.go b/map_debug.go index 7954b403..051b7acb 100644 --- a/map_debug.go +++ b/map_debug.go @@ -849,18 +849,16 @@ func validMapSlabSerialization( } // Extra check: encoded data size == header.size - encodedExtraDataSize, err := getEncodedMapExtraDataSize(slab.ExtraData(), cborEncMode) + encodedSlabSize, err := computeSlabSize(data) if err != nil { - // Don't need to wrap error as external error because err is already categorized by getEncodedMapExtraDataSize(). + // Don't need to wrap error as external error because err is already categorized by computeSlabSize(). return err } - // Need to exclude extra data size from encoded data size. - encodedSlabSize := uint32(len(data) - encodedExtraDataSize) - if slab.Header().size != encodedSlabSize { + if slab.Header().size != uint32(encodedSlabSize) { return NewFatalError( - fmt.Errorf("slab %d encoded size %d != header.size %d (encoded extra data size %d)", - id, encodedSlabSize, slab.Header().size, encodedExtraDataSize)) + fmt.Errorf("slab %d encoded size %d != header.size %d", + id, encodedSlabSize, slab.Header().size)) } // Compare encoded data of original slab with encoded data of decoded slab @@ -1357,20 +1355,3 @@ func mapExtraDataEqual(expected, actual *MapExtraData) error { return nil } - -func getEncodedMapExtraDataSize(extraData *MapExtraData, cborEncMode cbor.EncMode) (int, error) { - if extraData == nil { - return 0, nil - } - - var buf bytes.Buffer - enc := NewEncoder(&buf, cborEncMode) - - err := extraData.Encode(enc) - if err != nil { - // Don't need to wrap error as external error because err is already categorized by MapExtraData.Encode(). - return 0, err - } - - return len(buf.Bytes()), nil -} diff --git a/map_test.go b/map_test.go index b7b94c47..6ccf380b 100644 --- a/map_test.go +++ b/map_test.go @@ -1836,9 +1836,8 @@ func TestMapDecodeV0(t *testing.T) { // array data slab nestedSlabID: { - // extra data // version - 0x01, + 0x00, // flag: root + array data 0x80, // extra data (CBOR encoded array of 1 elements) @@ -1846,6 +1845,10 @@ func TestMapDecodeV0(t *testing.T) { // type info 0x18, 0x2b, + // version + 0x00, + // flag: root + array data + 0x80, // CBOR encoded array head (fixed size 3 byte) 0x99, 0x00, 0x01, // CBOR encoded array elements @@ -2500,7 +2503,7 @@ func TestMapEncodeDecode(t *testing.T) { expected := map[SlabID][]byte{ id1: { // version - 0x01, + 0x10, // flag: root + map data 0x88, @@ -2584,7 +2587,7 @@ func TestMapEncodeDecode(t *testing.T) { id1: { // version - 0x01, + 0x10, // flag: root + map data 0x88, @@ -2702,7 +2705,7 @@ func TestMapEncodeDecode(t *testing.T) { // metadata slab id1: { // version - 0x01, + 0x10, // flag: root + map meta 0x89, @@ -2734,7 +2737,7 @@ func TestMapEncodeDecode(t *testing.T) { // data slab id2: { // version - 0x01, + 0x12, // flag: map data 0x08, // next slab id @@ -2783,11 +2786,9 @@ func TestMapEncodeDecode(t *testing.T) { // data slab id3: { // version - 0x01, + 0x10, // flag: has pointer + map data 0x48, - // next slab id - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // the following encoded data is valid CBOR @@ -2831,7 +2832,7 @@ func TestMapEncodeDecode(t *testing.T) { // array data slab id4: { // version - 0x01, + 0x10, // flag: root + array data 0x80, // extra data (CBOR encoded array of 1 elements) @@ -2861,7 +2862,8 @@ func TestMapEncodeDecode(t *testing.T) { require.True(t, ok) require.Equal(t, 2, len(meta.childrenHeaders)) require.Equal(t, uint32(len(stored[id2])), meta.childrenHeaders[0].size) - require.Equal(t, uint32(len(stored[id3])), meta.childrenHeaders[1].size) + // Need to add slabIDSize to encoded data slab here because empty slab ID is omitted during encoding. + require.Equal(t, uint32(len(stored[id3])+slabIDSize), meta.childrenHeaders[1].size) // Decode data to new storage storage2 := newTestPersistentStorageWithData(t, stored) @@ -2913,7 +2915,7 @@ func TestMapEncodeDecode(t *testing.T) { // map metadata slab id1: { // version - 0x01, + 0x10, // flag: root + map data 0x88, // extra data (CBOR encoded array of 3 elements) @@ -3100,7 +3102,7 @@ func TestMapEncodeDecode(t *testing.T) { // map data slab id1: { // version - 0x01, + 0x10, // flag: root + map data 0x88, // extra data (CBOR encoded array of 3 elements) @@ -3339,7 +3341,7 @@ func TestMapEncodeDecode(t *testing.T) { // map data slab id1: { // version - 0x01, + 0x10, // flag: root + has pointer + map data 0xc8, // extra data (CBOR encoded array of 3 elements) @@ -3386,11 +3388,9 @@ func TestMapEncodeDecode(t *testing.T) { // external collision group id2: { // version - 0x01, + 0x10, // flag: any size + collision group 0x2b, - // next slab id - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // the following encoded data is valid CBOR @@ -3451,11 +3451,9 @@ func TestMapEncodeDecode(t *testing.T) { // external collision group id3: { // version - 0x01, + 0x10, // flag: any size + collision group 0x2b, - // next slab id - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // the following encoded data is valid CBOR @@ -3558,7 +3556,7 @@ func TestMapEncodeDecode(t *testing.T) { expectedNoPointer := []byte{ // version - 0x01, + 0x10, // flag: root + map data 0x88, // extra data (CBOR encoded array of 3 elements) @@ -3608,7 +3606,7 @@ func TestMapEncodeDecode(t *testing.T) { expectedHasPointer := []byte{ // version - 0x01, + 0x10, // flag: root + pointer + map data 0xc8, // extra data (CBOR encoded array of 3 elements) diff --git a/slab.go b/slab.go index fdfc17ad..2f0a6a9b 100644 --- a/slab.go +++ b/slab.go @@ -38,9 +38,12 @@ func IsRootOfAnObject(slabData []byte) (bool, error) { return false, NewDecodingErrorf("data is too short") } - flag := slabData[1] + h, err := newHeadFromData(slabData[:versionAndFlagSize]) + if err != nil { + return false, NewDecodingError(err) + } - return isRoot(flag), nil + return h.isRoot(), nil } func HasPointers(slabData []byte) (bool, error) { @@ -48,9 +51,12 @@ func HasPointers(slabData []byte) (bool, error) { return false, NewDecodingErrorf("data is too short") } - flag := slabData[1] + h, err := newHeadFromData(slabData[:versionAndFlagSize]) + if err != nil { + return false, NewDecodingError(err) + } - return hasPointers(flag), nil + return h.hasPointers(), nil } func HasSizeLimit(slabData []byte) (bool, error) { @@ -58,7 +64,10 @@ func HasSizeLimit(slabData []byte) (bool, error) { return false, NewDecodingErrorf("data is too short") } - flag := slabData[1] + h, err := newHeadFromData(slabData[:versionAndFlagSize]) + if err != nil { + return false, NewDecodingError(err) + } - return hasSizeLimit(flag), nil + return h.hasSizeLimit(), nil } diff --git a/storable_slab.go b/storable_slab.go index 3a24618a..9cc6d7bd 100644 --- a/storable_slab.go +++ b/storable_slab.go @@ -68,20 +68,21 @@ func (s *StorableSlab) ChildStorables() []Storable { } func (s *StorableSlab) Encode(enc *Encoder) error { - // Encode version - enc.Scratch[0] = 0 - // Encode flag - flag := maskStorable - flag = setNoSizeLimit(flag) + const version = 1 - if _, ok := s.storable.(SlabIDStorable); ok { - flag = setHasPointers(flag) + h, err := newStorableSlabHead(version) + if err != nil { + return NewEncodingError(err) } - enc.Scratch[1] = flag + h.setNoSizeLimit() + + if hasPointer(s.storable) { + h.setHasPointers() + } - _, err := enc.Write(enc.Scratch[:versionAndFlagSize]) + _, err = enc.Write(h[:]) if err != nil { return NewEncodingError(err) } diff --git a/storage_test.go b/storage_test.go index 60b2fb7b..40a4e6c8 100644 --- a/storage_test.go +++ b/storage_test.go @@ -902,7 +902,7 @@ func TestPersistentStorageSlabIterator(t *testing.T) { id1: { // extra data // version - 0x01, + 0x10, // extra data flag 0x81, // array of extra data @@ -927,7 +927,7 @@ func TestPersistentStorageSlabIterator(t *testing.T) { // (data slab) next: 3, data: [aaaaaaaaaaaaaaaaaaaaaa ... aaaaaaaaaaaaaaaaaaaaaa] id2: { // version - 0x01, + 0x12, // array data slab flag 0x00, // next slab id @@ -949,11 +949,9 @@ func TestPersistentStorageSlabIterator(t *testing.T) { // (data slab) next: 0, data: [aaaaaaaaaaaaaaaaaaaaaa ... SlabID(...)] id3: { // version - 0x01, + 0x10, // array data slab flag 0x40, - // next slab id - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // CBOR encoded array head (fixed size 3 byte) 0x99, 0x00, 0x0b, // CBOR encoded array elements @@ -974,7 +972,7 @@ func TestPersistentStorageSlabIterator(t *testing.T) { id4: { // extra data // version - 0x01, + 0x10, // extra data flag 0x80, // array of extra data