diff --git a/pkg/sr/serde.go b/pkg/sr/serde.go index 10ebad5d..36bfa98d 100644 --- a/pkg/sr/serde.go +++ b/pkg/sr/serde.go @@ -474,6 +474,16 @@ func (*ConfluentHeader) DecodeID(b []byte) (int, []byte, error) { return int(id), b[5:], nil } +// UpdateID replaces the schema ID in b. If the header does not contain the +// magic byte or b contains less than 5 bytes it returns ErrBadHeader. +func (*ConfluentHeader) UpdateID(b []byte, id int) error { + if len(b) < 5 || b[0] != 0 { + return ErrBadHeader + } + binary.BigEndian.PutUint32(b[1:5], uint32(id)) + return nil +} + // DecodeIndex strips and decodes indices from b. It returns the index slice // alongside the unread bytes. It expects b to be the output of DecodeID (schema // ID should already be stripped away). If maxLength is greater than 0 and the diff --git a/pkg/sr/serde_test.go b/pkg/sr/serde_test.go index 92c6ef44..16ecd25b 100644 --- a/pkg/sr/serde_test.go +++ b/pkg/sr/serde_test.go @@ -175,15 +175,17 @@ func TestConfluentHeader(t *testing.T) { var h ConfluentHeader for i, test := range []struct { - id int - index []int - expEnc []byte + id int + newID int + index []int + expEnc []byte + expEncUpd []byte }{ - {id: 1, index: nil, expEnc: []byte{0, 0, 0, 0, 1}}, - {id: 256, index: nil, expEnc: []byte{0, 0, 0, 1, 0}}, - {id: 2, index: []int{0}, expEnc: []byte{0, 0, 0, 0, 2, 0}}, - {id: 3, index: []int{1}, expEnc: []byte{0, 0, 0, 0, 3, 2, 2}}, - {id: 4, index: []int{1, 2, 3}, expEnc: []byte{0, 0, 0, 0, 4, 6, 2, 4, 6}}, + {id: 1, newID: 2, index: nil, expEnc: []byte{0, 0, 0, 0, 1}, expEncUpd: []byte{0, 0, 0, 0, 2}}, + {id: 256, newID: 65536, index: nil, expEnc: []byte{0, 0, 0, 1, 0}, expEncUpd: []byte{0, 0, 1, 0, 0}}, + {id: 2, newID: 3, index: []int{0}, expEnc: []byte{0, 0, 0, 0, 2, 0}, expEncUpd: []byte{0, 0, 0, 0, 3, 0}}, + {id: 3, newID: 4, index: []int{1}, expEnc: []byte{0, 0, 0, 0, 3, 2, 2}, expEncUpd: []byte{0, 0, 0, 0, 4, 2, 2}}, + {id: 4, newID: 5, index: []int{1, 2, 3}, expEnc: []byte{0, 0, 0, 0, 4, 6, 2, 4, 6}, expEncUpd: []byte{0, 0, 0, 0, 5, 6, 2, 4, 6}}, } { b, err := h.AppendEncode(nil, test.id, test.index) if err != nil { @@ -228,6 +230,16 @@ func TestConfluentHeader(t *testing.T) { continue } } + + if err := h.UpdateID(b, test.newID); err != nil { + t.Errorf("#%d UpdateID: got unexpected err %v", i, err) + continue + } + if !bytes.Equal(b, test.expEncUpd) { + t.Errorf("#%d: UpdateID(%v) != exp(%v)", i, b, test.expEncUpd) + continue + } + } if _, _, err := h.DecodeID([]byte{1, 0, 0, 0, 0, 1}); err != ErrBadHeader { @@ -236,6 +248,12 @@ func TestConfluentHeader(t *testing.T) { if _, _, err := h.DecodeID([]byte{0, 0, 0, 0}); err != ErrBadHeader { t.Errorf("got %v != exp ErrBadHeader", err) } + if err := h.UpdateID([]byte{1, 0, 0, 0, 0, 1}, 42); err != ErrBadHeader { + t.Errorf("got %v != exp ErrBadHeader", err) + } + if err := h.UpdateID([]byte{0, 0, 0, 0}, 42); err != ErrBadHeader { + t.Errorf("got %v != exp ErrBadHeader", err) + } if _, _, err := h.DecodeIndex([]byte{2}, 1); err != io.EOF { t.Errorf("got %v != exp io.EOF", err) }