Skip to content

Commit

Permalink
golang: Enforce maximum decoding depth (#169)
Browse files Browse the repository at this point in the history
  • Loading branch information
2opremio authored Sep 19, 2023
1 parent f0c4145 commit a231a92
Show file tree
Hide file tree
Showing 9 changed files with 433 additions and 222 deletions.
51 changes: 35 additions & 16 deletions lib/xdrgen/generators/go.rb
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,11 @@ def check_error(str)
def render_struct_decode_from_interface(out, struct)
name = name(struct)
out.puts "// DecodeFrom decodes this value using the Decoder."
out.puts "func (s *#{name}) DecodeFrom(d *xdr.Decoder) (int, error) {"
out.puts "func (s *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {"
out.puts " if maxDepth == 0 {"
out.puts " return 0, fmt.Errorf(\"decoding #{name}: %w\", ErrMaxDecodingDepthReached)"
out.puts " }"
out.puts " maxDepth -= 1"
out.puts " var err error"
out.puts " var n, nTmp int"
declared_variables = []
Expand All @@ -552,7 +556,11 @@ def render_struct_decode_from_interface(out, struct)
def render_union_decode_from_interface(out, union)
name = name(union)
out.puts "// DecodeFrom decodes this value using the Decoder."
out.puts "func (u *#{name}) DecodeFrom(d *xdr.Decoder) (int, error) {"
out.puts "func (u *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {"
out.puts " if maxDepth == 0 {"
out.puts " return 0, fmt.Errorf(\"decoding #{name}: %w\", ErrMaxDecodingDepthReached)"
out.puts " }"
out.puts " maxDepth -= 1"
out.puts " var err error"
out.puts " var n, nTmp int"
render_decode_from_body(out, "u.#{name(union.discriminant)}", union.discriminant.type, declared_variables: [], self_encode: false)
Expand Down Expand Up @@ -581,10 +589,14 @@ def render_enum_decode_from_interface(out, typedef)
type = typedef
out.puts <<-EOS.strip_heredoc
// DecodeFrom decodes this value using the Decoder.
func (e *#{name}) DecodeFrom(d *xdr.Decoder) (int, error) {
func (e *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
if maxDepth == 0 {
return 0, fmt.Errorf("decoding #{name}: %w", ErrMaxDecodingDepthReached)
}
maxDepth -= 1
v, n, err := d.DecodeInt()
if err != nil {
return n, fmt.Errorf("decoding #{name}: %s", err)
return n, fmt.Errorf("decoding #{name}: %w", err)
}
if _, ok := #{private_name type}Map[v]; !ok {
return n, fmt.Errorf("'%d' is not a valid #{name} enum value", v)
Expand All @@ -599,7 +611,11 @@ def render_typedef_decode_from_interface(out, typedef)
name = name(typedef)
type = typedef.declaration.type
out.puts "// DecodeFrom decodes this value using the Decoder."
out.puts "func (s *#{name}) DecodeFrom(d *xdr.Decoder) (int, error) {"
out.puts "func (s *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {"
out.puts " if maxDepth == 0 {"
out.puts " return 0, fmt.Errorf(\"decoding #{name}: %w\", ErrMaxDecodingDepthReached)"
out.puts " }"
out.puts " maxDepth -= 1"
out.puts " var err error"
out.puts " var n, nTmp int"
var = "s"
Expand Down Expand Up @@ -636,7 +652,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
tail = <<-EOS.strip_heredoc
n += nTmp
if err != nil {
return n, fmt.Errorf("decoding #{name type}: %s", err)
return n, fmt.Errorf("decoding #{name type}: %w", err)
}
EOS
optional = type.sub_type == :optional
Expand Down Expand Up @@ -692,7 +708,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
out.puts " #{var} = new(#{name type.resolved_type.declaration.type})"
end
var = "(*#{name type})(#{var})" if self_encode
out.puts " nTmp, err = #{var}.DecodeFrom(d)"
out.puts " nTmp, err = #{var}.DecodeFrom(d, maxDepth)"
out.puts tail
if optional_within
out.puts " }"
Expand All @@ -709,7 +725,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
out.puts " if eb {"
var = "(*#{element_var})"
end
out.puts " nTmp, err = #{element_var}.DecodeFrom(d)"
out.puts " nTmp, err = #{element_var}.DecodeFrom(d, maxDepth)"
out.puts tail
if optional_within
out.puts " }"
Expand Down Expand Up @@ -739,7 +755,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
out.puts " #{element_var} = new(#{name type.resolved_type.declaration.type})"
var = "(*#{element_var})"
end
out.puts " nTmp, err = #{element_var}.DecodeFrom(d)"
out.puts " nTmp, err = #{element_var}.DecodeFrom(d, maxDepth)"
out.puts tail
if optional_within
out.puts " }"
Expand All @@ -751,13 +767,13 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
end
when AST::Definitions::Base
if self_encode
out.puts " nTmp, err = #{name type}(#{var}).DecodeFrom(d)"
out.puts " nTmp, err = #{name type}(#{var}).DecodeFrom(d, maxDepth)"
else
out.puts " nTmp, err = #{var}.DecodeFrom(d)"
out.puts " nTmp, err = #{var}.DecodeFrom(d, maxDepth)"
end
out.puts tail
else
out.puts " nTmp, err = d.Decode(&#{var})"
out.puts " nTmp, err = d.DecodeWithMaxDepth(&#{var}, maxDepth)"
out.puts tail
end
if optional
Expand All @@ -778,7 +794,7 @@ def render_binary_interface(out, name)
out.puts "func (s *#{name}) UnmarshalBinary(inp []byte) error {"
out.puts " r := bytes.NewReader(inp)"
out.puts " d := xdr.NewDecoder(r)"
out.puts " _, err := s.DecodeFrom(d)"
out.puts " _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth)"
out.puts " return err"
out.puts "}"
out.break
Expand Down Expand Up @@ -817,6 +833,7 @@ def render_top_matter(out)
import (
"bytes"
"encoding"
"errors"
"io"
"fmt"
Expand All @@ -832,19 +849,21 @@ def render_top_matter(out)
EOS
out.break
out.puts <<-EOS.strip_heredoc
var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached")
type xdrType interface {
xdrType()
}
type decoderFrom interface {
DecodeFrom(d *xdr.Decoder) (int, error)
DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error)
}
// Unmarshal reads an xdr element from `r` into `v`.
func Unmarshal(r io.Reader, v interface{}) (int, error) {
if decodable, ok := v.(decoderFrom); ok {
d := xdr.NewDecoder(r)
return decodable.DecodeFrom(d)
return decodable.DecodeFrom(d, xdr.DecodeDefaultMaxDepth)
}
// delegate to xdr package's Unmarshal
return xdr.Unmarshal(r, v)
Expand Down Expand Up @@ -979,7 +998,7 @@ def render_union_constructor(out, union)
<<-EOS
tv, ok := value.(#{reference arm.type})
if !ok {
err = fmt.Errorf("invalid value, must be #{reference arm.type}")
err = errors.New("invalid value, must be #{reference arm.type}")
return
}
result.#{name arm} = &tv
Expand Down
17 changes: 12 additions & 5 deletions spec/output/generator_spec_go/block_comments.x/MyXDR_generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package MyXDR
import (
"bytes"
"encoding"
"errors"
"io"
"fmt"

Expand All @@ -22,19 +23,21 @@ var XdrFilesSHA256 = map[string]string{
"spec/fixtures/generator/block_comments.x": "e13131bc4134f38da17b9d5e9f67d2695a69ef98e3ef272833f4c18d0cc88a30",
}

var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached")

type xdrType interface {
xdrType()
}

type decoderFrom interface {
DecodeFrom(d *xdr.Decoder) (int, error)
DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error)
}

// Unmarshal reads an xdr element from `r` into `v`.
func Unmarshal(r io.Reader, v interface{}) (int, error) {
if decodable, ok := v.(decoderFrom); ok {
d := xdr.NewDecoder(r)
return decodable.DecodeFrom(d)
return decodable.DecodeFrom(d, xdr.DecodeDefaultMaxDepth)
}
// delegate to xdr package's Unmarshal
return xdr.Unmarshal(r, v)
Expand Down Expand Up @@ -92,10 +95,14 @@ func (e AccountFlags) EncodeTo(enc *xdr.Encoder) error {
}
var _ decoderFrom = (*AccountFlags)(nil)
// DecodeFrom decodes this value using the Decoder.
func (e *AccountFlags) DecodeFrom(d *xdr.Decoder) (int, error) {
func (e *AccountFlags) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
if maxDepth == 0 {
return 0, fmt.Errorf("decoding AccountFlags: %w", ErrMaxDecodingDepthReached)
}
maxDepth -= 1
v, n, err := d.DecodeInt()
if err != nil {
return n, fmt.Errorf("decoding AccountFlags: %s", err)
return n, fmt.Errorf("decoding AccountFlags: %w", err)
}
if _, ok := accountFlagsMap[v]; !ok {
return n, fmt.Errorf("'%d' is not a valid AccountFlags enum value", v)
Expand All @@ -115,7 +122,7 @@ func (s AccountFlags) MarshalBinary() ([]byte, error) {
func (s *AccountFlags) UnmarshalBinary(inp []byte) error {
r := bytes.NewReader(inp)
d := xdr.NewDecoder(r)
_, err := s.DecodeFrom(d)
_, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth)
return err
}

Expand Down
27 changes: 19 additions & 8 deletions spec/output/generator_spec_go/const.x/MyXDR_generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package MyXDR
import (
"bytes"
"encoding"
"errors"
"io"
"fmt"

Expand All @@ -22,19 +23,21 @@ var XdrFilesSHA256 = map[string]string{
"spec/fixtures/generator/const.x": "0bff3b37592fcc16cad2fe10b9a72f5d39d033a114917c24e86a9ebd9cda9c37",
}

var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached")

type xdrType interface {
xdrType()
}

type decoderFrom interface {
DecodeFrom(d *xdr.Decoder) (int, error)
DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error)
}

// Unmarshal reads an xdr element from `r` into `v`.
func Unmarshal(r io.Reader, v interface{}) (int, error) {
if decodable, ok := v.(decoderFrom); ok {
d := xdr.NewDecoder(r)
return decodable.DecodeFrom(d)
return decodable.DecodeFrom(d, xdr.DecodeDefaultMaxDepth)
}
// delegate to xdr package's Unmarshal
return xdr.Unmarshal(r, v)
Expand Down Expand Up @@ -77,14 +80,18 @@ if _, err = e.EncodeInt(int32(s)); err != nil {

var _ decoderFrom = (*TestArray)(nil)
// DecodeFrom decodes this value using the Decoder.
func (s *TestArray) DecodeFrom(d *xdr.Decoder) (int, error) {
func (s *TestArray) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
if maxDepth == 0 {
return 0, fmt.Errorf("decoding TestArray: %w", ErrMaxDecodingDepthReached)
}
maxDepth -= 1
var err error
var n, nTmp int
var v [Foo]int32
v, nTmp, err = d.DecodeInt()
n += nTmp
if err != nil {
return n, fmt.Errorf("decoding Int: %s", err)
return n, fmt.Errorf("decoding Int: %w", err)
}
*s = TestArray(v)
return n, nil
Expand All @@ -102,7 +109,7 @@ func (s TestArray) MarshalBinary() ([]byte, error) {
func (s *TestArray) UnmarshalBinary(inp []byte) error {
r := bytes.NewReader(inp)
d := xdr.NewDecoder(r)
_, err := s.DecodeFrom(d)
_, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth)
return err
}

Expand Down Expand Up @@ -137,14 +144,18 @@ if _, err = e.EncodeInt(int32(s)); err != nil {

var _ decoderFrom = (*TestArray2)(nil)
// DecodeFrom decodes this value using the Decoder.
func (s *TestArray2) DecodeFrom(d *xdr.Decoder) (int, error) {
func (s *TestArray2) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
if maxDepth == 0 {
return 0, fmt.Errorf("decoding TestArray2: %w", ErrMaxDecodingDepthReached)
}
maxDepth -= 1
var err error
var n, nTmp int
var v []int32
v, nTmp, err = d.DecodeInt()
n += nTmp
if err != nil {
return n, fmt.Errorf("decoding Int: %s", err)
return n, fmt.Errorf("decoding Int: %w", err)
}
*s = TestArray2(v)
return n, nil
Expand All @@ -162,7 +173,7 @@ func (s TestArray2) MarshalBinary() ([]byte, error) {
func (s *TestArray2) UnmarshalBinary(inp []byte) error {
r := bytes.NewReader(inp)
d := xdr.NewDecoder(r)
_, err := s.DecodeFrom(d)
_, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth)
return err
}

Expand Down
Loading

0 comments on commit a231a92

Please sign in to comment.