Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

golang: Enforce maximum decoding depth #169

Merged
merged 3 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 35 additions & 16 deletions lib/xdrgen/generators/go.rb
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,11 @@ def check_error(str)
def render_struct_decode_from_interface(out, struct)
name = name(struct)
out.puts "// DecodeFrom decodes this value using the Decoder."
out.puts "func (s *#{name}) DecodeFrom(d *xdr.Decoder) (int, error) {"
out.puts "func (s *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {"
out.puts " if maxDepth == 0 {"
out.puts " return 0, fmt.Errorf(\"decoding #{name}: %w\", ErrMaxDecodingDepthReached)"
out.puts " }"
out.puts " maxDepth -= 1"
out.puts " var err error"
out.puts " var n, nTmp int"
declared_variables = []
Expand All @@ -552,7 +556,11 @@ def render_struct_decode_from_interface(out, struct)
def render_union_decode_from_interface(out, union)
name = name(union)
out.puts "// DecodeFrom decodes this value using the Decoder."
out.puts "func (u *#{name}) DecodeFrom(d *xdr.Decoder) (int, error) {"
out.puts "func (u *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {"
out.puts " if maxDepth == 0 {"
out.puts " return 0, fmt.Errorf(\"decoding #{name}: %w\", ErrMaxDecodingDepthReached)"
out.puts " }"
out.puts " maxDepth -= 1"
out.puts " var err error"
out.puts " var n, nTmp int"
render_decode_from_body(out, "u.#{name(union.discriminant)}", union.discriminant.type, declared_variables: [], self_encode: false)
Expand Down Expand Up @@ -581,10 +589,14 @@ def render_enum_decode_from_interface(out, typedef)
type = typedef
out.puts <<-EOS.strip_heredoc
// DecodeFrom decodes this value using the Decoder.
func (e *#{name}) DecodeFrom(d *xdr.Decoder) (int, error) {
func (e *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
if maxDepth == 0 {
return 0, fmt.Errorf("decoding #{name}: %w", ErrMaxDecodingDepthReached)
}
maxDepth -= 1
v, n, err := d.DecodeInt()
if err != nil {
return n, fmt.Errorf("decoding #{name}: %s", err)
return n, fmt.Errorf("decoding #{name}: %w", err)
}
if _, ok := #{private_name type}Map[v]; !ok {
return n, fmt.Errorf("'%d' is not a valid #{name} enum value", v)
Expand All @@ -599,7 +611,11 @@ def render_typedef_decode_from_interface(out, typedef)
name = name(typedef)
type = typedef.declaration.type
out.puts "// DecodeFrom decodes this value using the Decoder."
out.puts "func (s *#{name}) DecodeFrom(d *xdr.Decoder) (int, error) {"
out.puts "func (s *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {"
out.puts " if maxDepth == 0 {"
out.puts " return 0, fmt.Errorf(\"decoding #{name}: %w\", ErrMaxDecodingDepthReached)"
out.puts " }"
out.puts " maxDepth -= 1"
out.puts " var err error"
out.puts " var n, nTmp int"
var = "s"
Expand Down Expand Up @@ -636,7 +652,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
tail = <<-EOS.strip_heredoc
n += nTmp
if err != nil {
return n, fmt.Errorf("decoding #{name type}: %s", err)
return n, fmt.Errorf("decoding #{name type}: %w", err)
}
EOS
optional = type.sub_type == :optional
Expand Down Expand Up @@ -692,7 +708,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
out.puts " #{var} = new(#{name type.resolved_type.declaration.type})"
end
var = "(*#{name type})(#{var})" if self_encode
out.puts " nTmp, err = #{var}.DecodeFrom(d)"
out.puts " nTmp, err = #{var}.DecodeFrom(d, maxDepth)"
out.puts tail
if optional_within
out.puts " }"
Expand All @@ -709,7 +725,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
out.puts " if eb {"
var = "(*#{element_var})"
end
out.puts " nTmp, err = #{element_var}.DecodeFrom(d)"
out.puts " nTmp, err = #{element_var}.DecodeFrom(d, maxDepth)"
out.puts tail
if optional_within
out.puts " }"
Expand Down Expand Up @@ -739,7 +755,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
out.puts " #{element_var} = new(#{name type.resolved_type.declaration.type})"
var = "(*#{element_var})"
end
out.puts " nTmp, err = #{element_var}.DecodeFrom(d)"
out.puts " nTmp, err = #{element_var}.DecodeFrom(d, maxDepth)"
out.puts tail
if optional_within
out.puts " }"
Expand All @@ -751,13 +767,13 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
end
when AST::Definitions::Base
if self_encode
out.puts " nTmp, err = #{name type}(#{var}).DecodeFrom(d)"
out.puts " nTmp, err = #{name type}(#{var}).DecodeFrom(d, maxDepth)"
else
out.puts " nTmp, err = #{var}.DecodeFrom(d)"
out.puts " nTmp, err = #{var}.DecodeFrom(d, maxDepth)"
end
out.puts tail
else
out.puts " nTmp, err = d.Decode(&#{var})"
out.puts " nTmp, err = d.DecodeWithMaxDepth(&#{var}, maxDepth)"
out.puts tail
end
if optional
Expand All @@ -778,7 +794,7 @@ def render_binary_interface(out, name)
out.puts "func (s *#{name}) UnmarshalBinary(inp []byte) error {"
out.puts " r := bytes.NewReader(inp)"
out.puts " d := xdr.NewDecoder(r)"
out.puts " _, err := s.DecodeFrom(d)"
out.puts " _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth)"
out.puts " return err"
out.puts "}"
out.break
Expand Down Expand Up @@ -817,6 +833,7 @@ def render_top_matter(out)
import (
"bytes"
"encoding"
"errors"
"io"
"fmt"

Expand All @@ -832,19 +849,21 @@ def render_top_matter(out)
EOS
out.break
out.puts <<-EOS.strip_heredoc
var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached")

type xdrType interface {
xdrType()
}

type decoderFrom interface {
DecodeFrom(d *xdr.Decoder) (int, error)
DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error)
}

// Unmarshal reads an xdr element from `r` into `v`.
func Unmarshal(r io.Reader, v interface{}) (int, error) {
if decodable, ok := v.(decoderFrom); ok {
d := xdr.NewDecoder(r)
return decodable.DecodeFrom(d)
return decodable.DecodeFrom(d, xdr.DecodeDefaultMaxDepth)
}
// delegate to xdr package's Unmarshal
return xdr.Unmarshal(r, v)
Expand Down Expand Up @@ -979,7 +998,7 @@ def render_union_constructor(out, union)
<<-EOS
tv, ok := value.(#{reference arm.type})
if !ok {
err = fmt.Errorf("invalid value, must be #{reference arm.type}")
err = errors.New("invalid value, must be #{reference arm.type}")
return
}
result.#{name arm} = &tv
Expand Down
17 changes: 12 additions & 5 deletions spec/output/generator_spec_go/block_comments.x/MyXDR_generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package MyXDR
import (
"bytes"
"encoding"
"errors"
"io"
"fmt"

Expand All @@ -22,19 +23,21 @@ var XdrFilesSHA256 = map[string]string{
"spec/fixtures/generator/block_comments.x": "e13131bc4134f38da17b9d5e9f67d2695a69ef98e3ef272833f4c18d0cc88a30",
}

var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached")

type xdrType interface {
xdrType()
}

type decoderFrom interface {
DecodeFrom(d *xdr.Decoder) (int, error)
DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error)
}

// Unmarshal reads an xdr element from `r` into `v`.
func Unmarshal(r io.Reader, v interface{}) (int, error) {
if decodable, ok := v.(decoderFrom); ok {
d := xdr.NewDecoder(r)
return decodable.DecodeFrom(d)
return decodable.DecodeFrom(d, xdr.DecodeDefaultMaxDepth)
}
// delegate to xdr package's Unmarshal
return xdr.Unmarshal(r, v)
Expand Down Expand Up @@ -92,10 +95,14 @@ func (e AccountFlags) EncodeTo(enc *xdr.Encoder) error {
}
var _ decoderFrom = (*AccountFlags)(nil)
// DecodeFrom decodes this value using the Decoder.
func (e *AccountFlags) DecodeFrom(d *xdr.Decoder) (int, error) {
func (e *AccountFlags) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
if maxDepth == 0 {
return 0, fmt.Errorf("decoding AccountFlags: %w", ErrMaxDecodingDepthReached)
}
maxDepth -= 1
v, n, err := d.DecodeInt()
if err != nil {
return n, fmt.Errorf("decoding AccountFlags: %s", err)
return n, fmt.Errorf("decoding AccountFlags: %w", err)
}
if _, ok := accountFlagsMap[v]; !ok {
return n, fmt.Errorf("'%d' is not a valid AccountFlags enum value", v)
Expand All @@ -115,7 +122,7 @@ func (s AccountFlags) MarshalBinary() ([]byte, error) {
func (s *AccountFlags) UnmarshalBinary(inp []byte) error {
r := bytes.NewReader(inp)
d := xdr.NewDecoder(r)
_, err := s.DecodeFrom(d)
_, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth)
return err
}

Expand Down
27 changes: 19 additions & 8 deletions spec/output/generator_spec_go/const.x/MyXDR_generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package MyXDR
import (
"bytes"
"encoding"
"errors"
"io"
"fmt"

Expand All @@ -22,19 +23,21 @@ var XdrFilesSHA256 = map[string]string{
"spec/fixtures/generator/const.x": "0bff3b37592fcc16cad2fe10b9a72f5d39d033a114917c24e86a9ebd9cda9c37",
}

var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached")

type xdrType interface {
xdrType()
}

type decoderFrom interface {
DecodeFrom(d *xdr.Decoder) (int, error)
DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error)
}

// Unmarshal reads an xdr element from `r` into `v`.
func Unmarshal(r io.Reader, v interface{}) (int, error) {
if decodable, ok := v.(decoderFrom); ok {
d := xdr.NewDecoder(r)
return decodable.DecodeFrom(d)
return decodable.DecodeFrom(d, xdr.DecodeDefaultMaxDepth)
}
// delegate to xdr package's Unmarshal
return xdr.Unmarshal(r, v)
Expand Down Expand Up @@ -77,14 +80,18 @@ if _, err = e.EncodeInt(int32(s)); err != nil {

var _ decoderFrom = (*TestArray)(nil)
// DecodeFrom decodes this value using the Decoder.
func (s *TestArray) DecodeFrom(d *xdr.Decoder) (int, error) {
func (s *TestArray) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
if maxDepth == 0 {
return 0, fmt.Errorf("decoding TestArray: %w", ErrMaxDecodingDepthReached)
}
maxDepth -= 1
var err error
var n, nTmp int
var v [Foo]int32
v, nTmp, err = d.DecodeInt()
n += nTmp
if err != nil {
return n, fmt.Errorf("decoding Int: %s", err)
return n, fmt.Errorf("decoding Int: %w", err)
}
*s = TestArray(v)
return n, nil
Expand All @@ -102,7 +109,7 @@ func (s TestArray) MarshalBinary() ([]byte, error) {
func (s *TestArray) UnmarshalBinary(inp []byte) error {
r := bytes.NewReader(inp)
d := xdr.NewDecoder(r)
_, err := s.DecodeFrom(d)
_, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth)
return err
}

Expand Down Expand Up @@ -137,14 +144,18 @@ if _, err = e.EncodeInt(int32(s)); err != nil {

var _ decoderFrom = (*TestArray2)(nil)
// DecodeFrom decodes this value using the Decoder.
func (s *TestArray2) DecodeFrom(d *xdr.Decoder) (int, error) {
func (s *TestArray2) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
if maxDepth == 0 {
return 0, fmt.Errorf("decoding TestArray2: %w", ErrMaxDecodingDepthReached)
}
maxDepth -= 1
var err error
var n, nTmp int
var v []int32
v, nTmp, err = d.DecodeInt()
n += nTmp
if err != nil {
return n, fmt.Errorf("decoding Int: %s", err)
return n, fmt.Errorf("decoding Int: %w", err)
}
*s = TestArray2(v)
return n, nil
Expand All @@ -162,7 +173,7 @@ func (s TestArray2) MarshalBinary() ([]byte, error) {
func (s *TestArray2) UnmarshalBinary(inp []byte) error {
r := bytes.NewReader(inp)
d := xdr.NewDecoder(r)
_, err := s.DecodeFrom(d)
_, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth)
return err
}

Expand Down
Loading