diff --git a/pkg/appconsts/appconsts.go b/pkg/appconsts/appconsts.go index 1a104ff3b0..5fb7193dec 100644 --- a/pkg/appconsts/appconsts.go +++ b/pkg/appconsts/appconsts.go @@ -6,4 +6,7 @@ import ( "github.com/tendermint/tendermint/pkg/consts" ) +// MaxShareVersion is the maximum value a share version can be. +const MaxShareVersion = 127 + var NameSpacedPaddedShareBytes = bytes.Repeat([]byte{0}, consts.MsgShareSize) diff --git a/pkg/shares/info_reserved_byte.go b/pkg/shares/info_reserved_byte.go new file mode 100644 index 0000000000..5562f05c76 --- /dev/null +++ b/pkg/shares/info_reserved_byte.go @@ -0,0 +1,37 @@ +package shares + +import ( + "fmt" + + "github.com/celestiaorg/celestia-app/pkg/appconsts" +) + +// InfoReservedByte is a byte with the following structure: the first 7 bits are +// reserved for version information in big endian form (initially `0000000`). +// The last bit is a "message start indicator", that is `1` if the share is at +// the start of a message and `0` otherwise. +type InfoReservedByte byte + +func NewInfoReservedByte(version uint8, isMessageStart bool) (InfoReservedByte, error) { + if version > appconsts.MaxShareVersion { + return 0, fmt.Errorf("version %d must be less than or equal to %d", version, appconsts.MaxShareVersion) + } + + prefix := version << 1 + if isMessageStart { + return InfoReservedByte(prefix + 1), nil + } + return InfoReservedByte(prefix), nil +} + +// Version returns the version encoded in this InfoReservedByte. Version is +// expected to be between 0 and appconsts.MaxShareVersion (inclusive). +func (i InfoReservedByte) Version() uint8 { + version := uint8(i) >> 1 + return version +} + +// IsMessageStart returns whether this share is the start of a message. +func (i InfoReservedByte) IsMessageStart() bool { + return uint(i)%2 == 1 +} diff --git a/pkg/shares/info_reserved_byte_test.go b/pkg/shares/info_reserved_byte_test.go new file mode 100644 index 0000000000..05ea5a973f --- /dev/null +++ b/pkg/shares/info_reserved_byte_test.go @@ -0,0 +1,73 @@ +package shares + +import "testing" + +func TestInfoReservedByte(t *testing.T) { + messageStart := true + notMessageStart := false + + type testCase struct { + version uint8 + isMessageStart bool + } + tests := []testCase{ + {0, messageStart}, + {1, messageStart}, + {2, messageStart}, + {127, messageStart}, + + {0, notMessageStart}, + {1, notMessageStart}, + {2, notMessageStart}, + {127, notMessageStart}, + } + + for _, test := range tests { + irb, err := NewInfoReservedByte(test.version, test.isMessageStart) + if err != nil { + t.Errorf("got %v want no error", err) + } + if got := irb.Version(); got != test.version { + t.Errorf("got version %v want %v", got, test.version) + } + if got := irb.IsMessageStart(); got != test.isMessageStart { + t.Errorf("got isMessageStart %v want %v", got, test.isMessageStart) + } + } +} + +func TestInfoReservedByteErrors(t *testing.T) { + messageStart := true + notMessageStart := false + + type testCase struct { + version uint8 + isMessageStart bool + } + + tests := []testCase{ + {128, notMessageStart}, + {255, notMessageStart}, + {128, messageStart}, + {255, messageStart}, + } + + for _, test := range tests { + _, err := NewInfoReservedByte(test.version, false) + if err == nil { + t.Errorf("got nil but want error when version > 127") + } + } +} + +func FuzzNewInfoReservedByte(f *testing.F) { + f.Fuzz(func(t *testing.T, version uint8, isMessageStart bool) { + if version > 127 { + t.Skip() + } + _, err := NewInfoReservedByte(version, isMessageStart) + if err != nil { + t.Errorf("got nil but want error when version > 127") + } + }) +}