diff --git a/lib/xdrgen/generators/go.rb b/lib/xdrgen/generators/go.rb index 16c9c6bb2..f2a6da035 100644 --- a/lib/xdrgen/generators/go.rb +++ b/lib/xdrgen/generators/go.rb @@ -536,7 +536,11 @@ def check_error(str) def render_struct_decode_from_interface(out, struct) name = name(struct) out.puts "// DecodeFrom decodes this value using the Decoder." - out.puts "func (s *#{name}) DecodeFrom(d *xdr.Decoder) (int, error) {" + out.puts "func (s *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {" + out.puts " if maxDepth == 0 {" + out.puts " return 0, fmt.Errorf(\"decoding #{name}: %w\", ErrMaxDecodingDepthReached)" + out.puts " }" + out.puts " maxDepth -= 1" out.puts " var err error" out.puts " var n, nTmp int" declared_variables = [] @@ -552,7 +556,11 @@ def render_struct_decode_from_interface(out, struct) def render_union_decode_from_interface(out, union) name = name(union) out.puts "// DecodeFrom decodes this value using the Decoder." - out.puts "func (u *#{name}) DecodeFrom(d *xdr.Decoder) (int, error) {" + out.puts "func (u *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {" + out.puts " if maxDepth == 0 {" + out.puts " return 0, fmt.Errorf(\"decoding #{name}: %w\", ErrMaxDecodingDepthReached)" + out.puts " }" + out.puts " maxDepth -= 1" out.puts " var err error" out.puts " var n, nTmp int" render_decode_from_body(out, "u.#{name(union.discriminant)}", union.discriminant.type, declared_variables: [], self_encode: false) @@ -581,10 +589,14 @@ def render_enum_decode_from_interface(out, typedef) type = typedef out.puts <<-EOS.strip_heredoc // DecodeFrom decodes this value using the Decoder. - func (e *#{name}) DecodeFrom(d *xdr.Decoder) (int, error) { + func (e *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding #{name}: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 v, n, err := d.DecodeInt() if err != nil { - return n, fmt.Errorf("decoding #{name}: %s", err) + return n, fmt.Errorf("decoding #{name}: %w", err) } if _, ok := #{private_name type}Map[v]; !ok { return n, fmt.Errorf("'%d' is not a valid #{name} enum value", v) @@ -599,7 +611,11 @@ def render_typedef_decode_from_interface(out, typedef) name = name(typedef) type = typedef.declaration.type out.puts "// DecodeFrom decodes this value using the Decoder." - out.puts "func (s *#{name}) DecodeFrom(d *xdr.Decoder) (int, error) {" + out.puts "func (s *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {" + out.puts " if maxDepth == 0 {" + out.puts " return 0, fmt.Errorf(\"decoding #{name}: %w\", ErrMaxDecodingDepthReached)" + out.puts " }" + out.puts " maxDepth -= 1" out.puts " var err error" out.puts " var n, nTmp int" var = "s" @@ -636,7 +652,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:) tail = <<-EOS.strip_heredoc n += nTmp if err != nil { - return n, fmt.Errorf("decoding #{name type}: %s", err) + return n, fmt.Errorf("decoding #{name type}: %w", err) } EOS optional = type.sub_type == :optional @@ -692,7 +708,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:) out.puts " #{var} = new(#{name type.resolved_type.declaration.type})" end var = "(*#{name type})(#{var})" if self_encode - out.puts " nTmp, err = #{var}.DecodeFrom(d)" + out.puts " nTmp, err = #{var}.DecodeFrom(d, maxDepth)" out.puts tail if optional_within out.puts " }" @@ -709,7 +725,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:) out.puts " if eb {" var = "(*#{element_var})" end - out.puts " nTmp, err = #{element_var}.DecodeFrom(d)" + out.puts " nTmp, err = #{element_var}.DecodeFrom(d, maxDepth)" out.puts tail if optional_within out.puts " }" @@ -739,7 +755,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:) out.puts " #{element_var} = new(#{name type.resolved_type.declaration.type})" var = "(*#{element_var})" end - out.puts " nTmp, err = #{element_var}.DecodeFrom(d)" + out.puts " nTmp, err = #{element_var}.DecodeFrom(d, maxDepth)" out.puts tail if optional_within out.puts " }" @@ -751,13 +767,13 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:) end when AST::Definitions::Base if self_encode - out.puts " nTmp, err = #{name type}(#{var}).DecodeFrom(d)" + out.puts " nTmp, err = #{name type}(#{var}).DecodeFrom(d, maxDepth)" else - out.puts " nTmp, err = #{var}.DecodeFrom(d)" + out.puts " nTmp, err = #{var}.DecodeFrom(d, maxDepth)" end out.puts tail else - out.puts " nTmp, err = d.Decode(&#{var})" + out.puts " nTmp, err = d.DecodeWithMaxDepth(&#{var}, maxDepth)" out.puts tail end if optional @@ -778,7 +794,7 @@ def render_binary_interface(out, name) out.puts "func (s *#{name}) UnmarshalBinary(inp []byte) error {" out.puts " r := bytes.NewReader(inp)" out.puts " d := xdr.NewDecoder(r)" - out.puts " _, err := s.DecodeFrom(d)" + out.puts " _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth)" out.puts " return err" out.puts "}" out.break @@ -817,6 +833,7 @@ def render_top_matter(out) import ( "bytes" "encoding" + "errors" "io" "fmt" @@ -832,19 +849,21 @@ def render_top_matter(out) EOS out.break out.puts <<-EOS.strip_heredoc + var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") + type xdrType interface { xdrType() } type decoderFrom interface { - DecodeFrom(d *xdr.Decoder) (int, error) + DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) } // Unmarshal reads an xdr element from `r` into `v`. func Unmarshal(r io.Reader, v interface{}) (int, error) { if decodable, ok := v.(decoderFrom); ok { d := xdr.NewDecoder(r) - return decodable.DecodeFrom(d) + return decodable.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) } // delegate to xdr package's Unmarshal return xdr.Unmarshal(r, v) @@ -979,7 +998,7 @@ def render_union_constructor(out, union) <<-EOS tv, ok := value.(#{reference arm.type}) if !ok { - err = fmt.Errorf("invalid value, must be #{reference arm.type}") + err = errors.New("invalid value, must be #{reference arm.type}") return } result.#{name arm} = &tv diff --git a/spec/output/generator_spec_go/block_comments.x/MyXDR_generated.go b/spec/output/generator_spec_go/block_comments.x/MyXDR_generated.go index b384c9a18..0eaa4ee67 100644 --- a/spec/output/generator_spec_go/block_comments.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/block_comments.x/MyXDR_generated.go @@ -11,6 +11,7 @@ package MyXDR import ( "bytes" "encoding" + "errors" "io" "fmt" @@ -22,19 +23,21 @@ var XdrFilesSHA256 = map[string]string{ "spec/fixtures/generator/block_comments.x": "e13131bc4134f38da17b9d5e9f67d2695a69ef98e3ef272833f4c18d0cc88a30", } +var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") + type xdrType interface { xdrType() } type decoderFrom interface { - DecodeFrom(d *xdr.Decoder) (int, error) + DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) } // Unmarshal reads an xdr element from `r` into `v`. func Unmarshal(r io.Reader, v interface{}) (int, error) { if decodable, ok := v.(decoderFrom); ok { d := xdr.NewDecoder(r) - return decodable.DecodeFrom(d) + return decodable.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) } // delegate to xdr package's Unmarshal return xdr.Unmarshal(r, v) @@ -92,10 +95,14 @@ func (e AccountFlags) EncodeTo(enc *xdr.Encoder) error { } var _ decoderFrom = (*AccountFlags)(nil) // DecodeFrom decodes this value using the Decoder. -func (e *AccountFlags) DecodeFrom(d *xdr.Decoder) (int, error) { +func (e *AccountFlags) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding AccountFlags: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 v, n, err := d.DecodeInt() if err != nil { - return n, fmt.Errorf("decoding AccountFlags: %s", err) + return n, fmt.Errorf("decoding AccountFlags: %w", err) } if _, ok := accountFlagsMap[v]; !ok { return n, fmt.Errorf("'%d' is not a valid AccountFlags enum value", v) @@ -115,7 +122,7 @@ func (s AccountFlags) MarshalBinary() ([]byte, error) { func (s *AccountFlags) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } diff --git a/spec/output/generator_spec_go/const.x/MyXDR_generated.go b/spec/output/generator_spec_go/const.x/MyXDR_generated.go index c884470c4..46a8aecff 100644 --- a/spec/output/generator_spec_go/const.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/const.x/MyXDR_generated.go @@ -11,6 +11,7 @@ package MyXDR import ( "bytes" "encoding" + "errors" "io" "fmt" @@ -22,19 +23,21 @@ var XdrFilesSHA256 = map[string]string{ "spec/fixtures/generator/const.x": "0bff3b37592fcc16cad2fe10b9a72f5d39d033a114917c24e86a9ebd9cda9c37", } +var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") + type xdrType interface { xdrType() } type decoderFrom interface { - DecodeFrom(d *xdr.Decoder) (int, error) + DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) } // Unmarshal reads an xdr element from `r` into `v`. func Unmarshal(r io.Reader, v interface{}) (int, error) { if decodable, ok := v.(decoderFrom); ok { d := xdr.NewDecoder(r) - return decodable.DecodeFrom(d) + return decodable.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) } // delegate to xdr package's Unmarshal return xdr.Unmarshal(r, v) @@ -77,14 +80,18 @@ if _, err = e.EncodeInt(int32(s)); err != nil { var _ decoderFrom = (*TestArray)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *TestArray) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *TestArray) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding TestArray: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int var v [Foo]int32 v, nTmp, err = d.DecodeInt() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Int: %s", err) + return n, fmt.Errorf("decoding Int: %w", err) } *s = TestArray(v) return n, nil @@ -102,7 +109,7 @@ func (s TestArray) MarshalBinary() ([]byte, error) { func (s *TestArray) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -137,14 +144,18 @@ if _, err = e.EncodeInt(int32(s)); err != nil { var _ decoderFrom = (*TestArray2)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *TestArray2) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *TestArray2) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding TestArray2: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int var v []int32 v, nTmp, err = d.DecodeInt() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Int: %s", err) + return n, fmt.Errorf("decoding Int: %w", err) } *s = TestArray2(v) return n, nil @@ -162,7 +173,7 @@ func (s TestArray2) MarshalBinary() ([]byte, error) { func (s *TestArray2) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } diff --git a/spec/output/generator_spec_go/enum.x/MyXDR_generated.go b/spec/output/generator_spec_go/enum.x/MyXDR_generated.go index b11df486b..592762bc4 100644 --- a/spec/output/generator_spec_go/enum.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/enum.x/MyXDR_generated.go @@ -11,6 +11,7 @@ package MyXDR import ( "bytes" "encoding" + "errors" "io" "fmt" @@ -22,19 +23,21 @@ var XdrFilesSHA256 = map[string]string{ "spec/fixtures/generator/enum.x": "35cf5e97e2057039640ed260e8b38bb2733a3c3ca8529c93877bdec02a999d7f", } +var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") + type xdrType interface { xdrType() } type decoderFrom interface { - DecodeFrom(d *xdr.Decoder) (int, error) + DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) } // Unmarshal reads an xdr element from `r` into `v`. func Unmarshal(r io.Reader, v interface{}) (int, error) { if decodable, ok := v.(decoderFrom); ok { d := xdr.NewDecoder(r) - return decodable.DecodeFrom(d) + return decodable.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) } // delegate to xdr package's Unmarshal return xdr.Unmarshal(r, v) @@ -137,10 +140,14 @@ func (e MessageType) EncodeTo(enc *xdr.Encoder) error { } var _ decoderFrom = (*MessageType)(nil) // DecodeFrom decodes this value using the Decoder. -func (e *MessageType) DecodeFrom(d *xdr.Decoder) (int, error) { +func (e *MessageType) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding MessageType: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 v, n, err := d.DecodeInt() if err != nil { - return n, fmt.Errorf("decoding MessageType: %s", err) + return n, fmt.Errorf("decoding MessageType: %w", err) } if _, ok := messageTypeMap[v]; !ok { return n, fmt.Errorf("'%d' is not a valid MessageType enum value", v) @@ -160,7 +167,7 @@ func (s MessageType) MarshalBinary() ([]byte, error) { func (s *MessageType) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -217,10 +224,14 @@ func (e Color) EncodeTo(enc *xdr.Encoder) error { } var _ decoderFrom = (*Color)(nil) // DecodeFrom decodes this value using the Decoder. -func (e *Color) DecodeFrom(d *xdr.Decoder) (int, error) { +func (e *Color) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Color: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 v, n, err := d.DecodeInt() if err != nil { - return n, fmt.Errorf("decoding Color: %s", err) + return n, fmt.Errorf("decoding Color: %w", err) } if _, ok := colorMap[v]; !ok { return n, fmt.Errorf("'%d' is not a valid Color enum value", v) @@ -240,7 +251,7 @@ func (s Color) MarshalBinary() ([]byte, error) { func (s *Color) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -297,10 +308,14 @@ func (e Color2) EncodeTo(enc *xdr.Encoder) error { } var _ decoderFrom = (*Color2)(nil) // DecodeFrom decodes this value using the Decoder. -func (e *Color2) DecodeFrom(d *xdr.Decoder) (int, error) { +func (e *Color2) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Color2: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 v, n, err := d.DecodeInt() if err != nil { - return n, fmt.Errorf("decoding Color2: %s", err) + return n, fmt.Errorf("decoding Color2: %w", err) } if _, ok := color2Map[v]; !ok { return n, fmt.Errorf("'%d' is not a valid Color2 enum value", v) @@ -320,7 +335,7 @@ func (s Color2) MarshalBinary() ([]byte, error) { func (s *Color2) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } diff --git a/spec/output/generator_spec_go/nesting.x/MyXDR_generated.go b/spec/output/generator_spec_go/nesting.x/MyXDR_generated.go index d353ad983..7f8219c92 100644 --- a/spec/output/generator_spec_go/nesting.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/nesting.x/MyXDR_generated.go @@ -11,6 +11,7 @@ package MyXDR import ( "bytes" "encoding" + "errors" "io" "fmt" @@ -22,19 +23,21 @@ var XdrFilesSHA256 = map[string]string{ "spec/fixtures/generator/nesting.x": "5537949272c11f1bd09cf613a3751668b5018d686a1c2aaa3baa91183ca18f6a", } +var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") + type xdrType interface { xdrType() } type decoderFrom interface { - DecodeFrom(d *xdr.Decoder) (int, error) + DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) } // Unmarshal reads an xdr element from `r` into `v`. func Unmarshal(r io.Reader, v interface{}) (int, error) { if decodable, ok := v.(decoderFrom); ok { d := xdr.NewDecoder(r) - return decodable.DecodeFrom(d) + return decodable.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) } // delegate to xdr package's Unmarshal return xdr.Unmarshal(r, v) @@ -97,10 +100,14 @@ func (e UnionKey) EncodeTo(enc *xdr.Encoder) error { } var _ decoderFrom = (*UnionKey)(nil) // DecodeFrom decodes this value using the Decoder. -func (e *UnionKey) DecodeFrom(d *xdr.Decoder) (int, error) { +func (e *UnionKey) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding UnionKey: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 v, n, err := d.DecodeInt() if err != nil { - return n, fmt.Errorf("decoding UnionKey: %s", err) + return n, fmt.Errorf("decoding UnionKey: %w", err) } if _, ok := unionKeyMap[v]; !ok { return n, fmt.Errorf("'%d' is not a valid UnionKey enum value", v) @@ -120,7 +127,7 @@ func (s UnionKey) MarshalBinary() ([]byte, error) { func (s *UnionKey) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -152,14 +159,18 @@ if _, err = e.EncodeInt(int32(s)); err != nil { var _ decoderFrom = (*Foo)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *Foo) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *Foo) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Foo: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int var v int32 v, nTmp, err = d.DecodeInt() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Int: %s", err) + return n, fmt.Errorf("decoding Int: %w", err) } *s = Foo(v) return n, nil @@ -177,7 +188,7 @@ func (s Foo) MarshalBinary() ([]byte, error) { func (s *Foo) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -213,13 +224,17 @@ if _, err = e.EncodeInt(int32(s.SomeInt)); err != nil { var _ decoderFrom = (*MyUnionOne)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *MyUnionOne) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *MyUnionOne) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding MyUnionOne: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int s.SomeInt, nTmp, err = d.DecodeInt() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Int: %s", err) + return n, fmt.Errorf("decoding Int: %w", err) } return n, nil } @@ -236,7 +251,7 @@ func (s MyUnionOne) MarshalBinary() ([]byte, error) { func (s *MyUnionOne) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -277,18 +292,22 @@ if err = s.Foo.EncodeTo(e); err != nil { var _ decoderFrom = (*MyUnionTwo)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *MyUnionTwo) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *MyUnionTwo) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding MyUnionTwo: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int s.SomeInt, nTmp, err = d.DecodeInt() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Int: %s", err) + return n, fmt.Errorf("decoding Int: %w", err) } - nTmp, err = s.Foo.DecodeFrom(d) + nTmp, err = s.Foo.DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding Foo: %s", err) + return n, fmt.Errorf("decoding Foo: %w", err) } return n, nil } @@ -305,7 +324,7 @@ func (s MyUnionTwo) MarshalBinary() ([]byte, error) { func (s *MyUnionTwo) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -372,14 +391,14 @@ switch UnionKey(aType) { case UnionKeyOne: tv, ok := value.(MyUnionOne) if !ok { - err = fmt.Errorf("invalid value, must be MyUnionOne") + err = errors.New("invalid value, must be MyUnionOne") return } result.One = &tv case UnionKeyTwo: tv, ok := value.(MyUnionTwo) if !ok { - err = fmt.Errorf("invalid value, must be MyUnionTwo") + err = errors.New("invalid value, must be MyUnionTwo") return } result.Two = &tv @@ -463,29 +482,33 @@ return nil var _ decoderFrom = (*MyUnion)(nil) // DecodeFrom decodes this value using the Decoder. -func (u *MyUnion) DecodeFrom(d *xdr.Decoder) (int, error) { +func (u *MyUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding MyUnion: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int - nTmp, err = u.Type.DecodeFrom(d) + nTmp, err = u.Type.DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding UnionKey: %s", err) + return n, fmt.Errorf("decoding UnionKey: %w", err) } switch UnionKey(u.Type) { case UnionKeyOne: u.One = new(MyUnionOne) - nTmp, err = (*u.One).DecodeFrom(d) + nTmp, err = (*u.One).DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding MyUnionOne: %s", err) + return n, fmt.Errorf("decoding MyUnionOne: %w", err) } return n, nil case UnionKeyTwo: u.Two = new(MyUnionTwo) - nTmp, err = (*u.Two).DecodeFrom(d) + nTmp, err = (*u.Two).DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding MyUnionTwo: %s", err) + return n, fmt.Errorf("decoding MyUnionTwo: %w", err) } return n, nil case UnionKeyOffer: @@ -507,7 +530,7 @@ func (s MyUnion) MarshalBinary() ([]byte, error) { func (s *MyUnion) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } diff --git a/spec/output/generator_spec_go/optional.x/MyXDR_generated.go b/spec/output/generator_spec_go/optional.x/MyXDR_generated.go index 0f8db98de..cb030c26d 100644 --- a/spec/output/generator_spec_go/optional.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/optional.x/MyXDR_generated.go @@ -11,6 +11,7 @@ package MyXDR import ( "bytes" "encoding" + "errors" "io" "fmt" @@ -22,19 +23,21 @@ var XdrFilesSHA256 = map[string]string{ "spec/fixtures/generator/optional.x": "3241e832fcf00bca4315ecb6c259621dafb0e302a63a993f5504b0b5cebb6bd7", } +var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") + type xdrType interface { xdrType() } type decoderFrom interface { - DecodeFrom(d *xdr.Decoder) (int, error) + DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) } // Unmarshal reads an xdr element from `r` into `v`. func Unmarshal(r io.Reader, v interface{}) (int, error) { if decodable, ok := v.(decoderFrom); ok { d := xdr.NewDecoder(r) - return decodable.DecodeFrom(d) + return decodable.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) } // delegate to xdr package's Unmarshal return xdr.Unmarshal(r, v) @@ -71,14 +74,18 @@ if _, err = e.EncodeInt(int32(s)); err != nil { var _ decoderFrom = (*Arr)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *Arr) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *Arr) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Arr: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int var v [2]int32 v, nTmp, err = d.DecodeInt() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Int: %s", err) + return n, fmt.Errorf("decoding Int: %w", err) } *s = Arr(v) return n, nil @@ -96,7 +103,7 @@ func (s Arr) MarshalBinary() ([]byte, error) { func (s *Arr) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -158,14 +165,18 @@ if err = (*s.ThirdOption).EncodeTo(e); err != nil { var _ decoderFrom = (*HasOptions)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *HasOptions) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *HasOptions) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding HasOptions: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int var b bool b, nTmp, err = d.DecodeBool() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Int: %s", err) + return n, fmt.Errorf("decoding Int: %w", err) } s.FirstOption = nil if b { @@ -173,13 +184,13 @@ if err != nil { s.FirstOption, nTmp, err = d.DecodeInt() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Int: %s", err) + return n, fmt.Errorf("decoding Int: %w", err) } } b, nTmp, err = d.DecodeBool() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Int: %s", err) + return n, fmt.Errorf("decoding Int: %w", err) } s.SecondOption = nil if b { @@ -187,21 +198,21 @@ if err != nil { s.SecondOption, nTmp, err = d.DecodeInt() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Int: %s", err) + return n, fmt.Errorf("decoding Int: %w", err) } } b, nTmp, err = d.DecodeBool() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Arr: %s", err) + return n, fmt.Errorf("decoding Arr: %w", err) } s.ThirdOption = nil if b { s.ThirdOption = new(Arr) - nTmp, err = s.ThirdOption.DecodeFrom(d) + nTmp, err = s.ThirdOption.DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding Arr: %s", err) + return n, fmt.Errorf("decoding Arr: %w", err) } } return n, nil @@ -219,7 +230,7 @@ func (s HasOptions) MarshalBinary() ([]byte, error) { func (s *HasOptions) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } diff --git a/spec/output/generator_spec_go/struct.x/MyXDR_generated.go b/spec/output/generator_spec_go/struct.x/MyXDR_generated.go index 9f4c1cb73..b2f7d834b 100644 --- a/spec/output/generator_spec_go/struct.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/struct.x/MyXDR_generated.go @@ -11,6 +11,7 @@ package MyXDR import ( "bytes" "encoding" + "errors" "io" "fmt" @@ -22,19 +23,21 @@ var XdrFilesSHA256 = map[string]string{ "spec/fixtures/generator/struct.x": "c6911a83390e3b499c078fd0c579132eacce88a4a0538d3b8b5e57747a58db4a", } +var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") + type xdrType interface { xdrType() } type decoderFrom interface { - DecodeFrom(d *xdr.Decoder) (int, error) + DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) } // Unmarshal reads an xdr element from `r` into `v`. func Unmarshal(r io.Reader, v interface{}) (int, error) { if decodable, ok := v.(decoderFrom); ok { d := xdr.NewDecoder(r) - return decodable.DecodeFrom(d) + return decodable.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) } // delegate to xdr package's Unmarshal return xdr.Unmarshal(r, v) @@ -72,14 +75,18 @@ if _, err = e.EncodeHyper(int64(s)); err != nil { var _ decoderFrom = (*Int64)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *Int64) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *Int64) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Int64: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int var v int64 v, nTmp, err = d.DecodeHyper() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Hyper: %s", err) + return n, fmt.Errorf("decoding Hyper: %w", err) } *s = Int64(v) return n, nil @@ -97,7 +104,7 @@ func (s Int64) MarshalBinary() ([]byte, error) { func (s *Int64) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -154,33 +161,37 @@ if _, err = e.EncodeString(string(s.MaxString)); err != nil { var _ decoderFrom = (*MyStruct)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *MyStruct) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *MyStruct) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding MyStruct: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int s.SomeInt, nTmp, err = d.DecodeInt() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Int: %s", err) + return n, fmt.Errorf("decoding Int: %w", err) } - nTmp, err = s.ABigInt.DecodeFrom(d) + nTmp, err = s.ABigInt.DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding Int64: %s", err) + return n, fmt.Errorf("decoding Int64: %w", err) } nTmp, err = d.DecodeFixedOpaqueInplace(s.SomeOpaque[:]) n += nTmp if err != nil { - return n, fmt.Errorf("decoding SomeOpaque: %s", err) + return n, fmt.Errorf("decoding SomeOpaque: %w", err) } s.SomeString, nTmp, err = d.DecodeString(0) n += nTmp if err != nil { - return n, fmt.Errorf("decoding SomeString: %s", err) + return n, fmt.Errorf("decoding SomeString: %w", err) } s.MaxString, nTmp, err = d.DecodeString(100) n += nTmp if err != nil { - return n, fmt.Errorf("decoding MaxString: %s", err) + return n, fmt.Errorf("decoding MaxString: %w", err) } return n, nil } @@ -197,7 +208,7 @@ func (s MyStruct) MarshalBinary() ([]byte, error) { func (s *MyStruct) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } diff --git a/spec/output/generator_spec_go/test.x/MyXDR_generated.go b/spec/output/generator_spec_go/test.x/MyXDR_generated.go index 62ac4168c..0bd57e8fb 100644 --- a/spec/output/generator_spec_go/test.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/test.x/MyXDR_generated.go @@ -11,6 +11,7 @@ package MyXDR import ( "bytes" "encoding" + "errors" "io" "fmt" @@ -22,19 +23,21 @@ var XdrFilesSHA256 = map[string]string{ "spec/fixtures/generator/test.x": "d29a98a6a3b9bf533a3e6712d928e0bed655e0f462ac4dae810c65d52ca9af41", } +var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") + type xdrType interface { xdrType() } type decoderFrom interface { - DecodeFrom(d *xdr.Decoder) (int, error) + DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) } // Unmarshal reads an xdr element from `r` into `v`. func Unmarshal(r io.Reader, v interface{}) (int, error) { if decodable, ok := v.(decoderFrom); ok { d := xdr.NewDecoder(r) - return decodable.DecodeFrom(d) + return decodable.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) } // delegate to xdr package's Unmarshal return xdr.Unmarshal(r, v) @@ -76,13 +79,17 @@ if _, err = e.EncodeFixedOpaque(s[:]); err != nil { var _ decoderFrom = (*Uint512)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *Uint512) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *Uint512) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Uint512: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int nTmp, err = d.DecodeFixedOpaqueInplace(s[:]) n += nTmp if err != nil { - return n, fmt.Errorf("decoding Uint512: %s", err) + return n, fmt.Errorf("decoding Uint512: %w", err) } return n, nil } @@ -99,7 +106,7 @@ func (s Uint512) MarshalBinary() ([]byte, error) { func (s *Uint512) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -135,13 +142,17 @@ if _, err = e.EncodeOpaque(s[:]); err != nil { var _ decoderFrom = (*Uint513)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *Uint513) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *Uint513) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Uint513: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int (*s), nTmp, err = d.DecodeOpaque(64) n += nTmp if err != nil { - return n, fmt.Errorf("decoding Uint513: %s", err) + return n, fmt.Errorf("decoding Uint513: %w", err) } return n, nil } @@ -158,7 +169,7 @@ func (s Uint513) MarshalBinary() ([]byte, error) { func (s *Uint513) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -190,13 +201,17 @@ if _, err = e.EncodeOpaque(s[:]); err != nil { var _ decoderFrom = (*Uint514)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *Uint514) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *Uint514) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Uint514: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int (*s), nTmp, err = d.DecodeOpaque(0) n += nTmp if err != nil { - return n, fmt.Errorf("decoding Uint514: %s", err) + return n, fmt.Errorf("decoding Uint514: %w", err) } return n, nil } @@ -213,7 +228,7 @@ func (s Uint514) MarshalBinary() ([]byte, error) { func (s *Uint514) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -249,14 +264,18 @@ if _, err = e.EncodeString(string(s)); err != nil { var _ decoderFrom = (*Str)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *Str) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *Str) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Str: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int var v string v, nTmp, err = d.DecodeString(64) n += nTmp if err != nil { - return n, fmt.Errorf("decoding Str: %s", err) + return n, fmt.Errorf("decoding Str: %w", err) } *s = Str(v) return n, nil @@ -274,7 +293,7 @@ func (s Str) MarshalBinary() ([]byte, error) { func (s *Str) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -306,14 +325,18 @@ if _, err = e.EncodeString(string(s)); err != nil { var _ decoderFrom = (*Str2)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *Str2) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *Str2) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Str2: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int var v string v, nTmp, err = d.DecodeString(0) n += nTmp if err != nil { - return n, fmt.Errorf("decoding Str2: %s", err) + return n, fmt.Errorf("decoding Str2: %w", err) } *s = Str2(v) return n, nil @@ -331,7 +354,7 @@ func (s Str2) MarshalBinary() ([]byte, error) { func (s *Str2) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -367,13 +390,17 @@ if _, err = e.EncodeFixedOpaque(s[:]); err != nil { var _ decoderFrom = (*Hash)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *Hash) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *Hash) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Hash: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int nTmp, err = d.DecodeFixedOpaqueInplace(s[:]) n += nTmp if err != nil { - return n, fmt.Errorf("decoding Hash: %s", err) + return n, fmt.Errorf("decoding Hash: %w", err) } return n, nil } @@ -390,7 +417,7 @@ func (s Hash) MarshalBinary() ([]byte, error) { func (s *Hash) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -423,14 +450,18 @@ if err = s[i].EncodeTo(e); err != nil { var _ decoderFrom = (*Hashes1)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *Hashes1) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *Hashes1) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Hashes1: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int for i := 0; i < len(s); i++ { - nTmp, err = s[i].DecodeFrom(d) + nTmp, err = s[i].DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding Hash: %s", err) + return n, fmt.Errorf("decoding Hash: %w", err) } } return n, nil @@ -448,7 +479,7 @@ func (s Hashes1) MarshalBinary() ([]byte, error) { func (s *Hashes1) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -488,14 +519,18 @@ if err = s[i].EncodeTo(e); err != nil { var _ decoderFrom = (*Hashes2)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *Hashes2) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *Hashes2) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Hashes2: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int var l uint32 l, nTmp, err = d.DecodeUint() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Hash: %s", err) + return n, fmt.Errorf("decoding Hash: %w", err) } if l > 12 { return n, fmt.Errorf("decoding Hash: data size (%d) exceeds size limit (12)", l) @@ -504,10 +539,10 @@ if err != nil { if l > 0 { (*s) = make([]Hash, l) for i := uint32(0); i < l; i++ { - nTmp, err = (*s)[i].DecodeFrom(d) + nTmp, err = (*s)[i].DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding Hash: %s", err) + return n, fmt.Errorf("decoding Hash: %w", err) } } } @@ -526,7 +561,7 @@ func (s Hashes2) MarshalBinary() ([]byte, error) { func (s *Hashes2) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -562,23 +597,27 @@ if err = s[i].EncodeTo(e); err != nil { var _ decoderFrom = (*Hashes3)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *Hashes3) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *Hashes3) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Hashes3: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int var l uint32 l, nTmp, err = d.DecodeUint() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Hash: %s", err) + return n, fmt.Errorf("decoding Hash: %w", err) } (*s) = nil if l > 0 { (*s) = make([]Hash, l) for i := uint32(0); i < l; i++ { - nTmp, err = (*s)[i].DecodeFrom(d) + nTmp, err = (*s)[i].DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding Hash: %s", err) + return n, fmt.Errorf("decoding Hash: %w", err) } } } @@ -597,7 +636,7 @@ func (s Hashes3) MarshalBinary() ([]byte, error) { func (s *Hashes3) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -639,14 +678,18 @@ if _, err = e.EncodeInt(int32(s)); err != nil { var _ decoderFrom = (*Int1)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *Int1) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *Int1) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Int1: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int var v int32 v, nTmp, err = d.DecodeInt() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Int: %s", err) + return n, fmt.Errorf("decoding Int: %w", err) } *s = Int1(v) return n, nil @@ -664,7 +707,7 @@ func (s Int1) MarshalBinary() ([]byte, error) { func (s *Int1) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -696,14 +739,18 @@ if _, err = e.EncodeHyper(int64(s)); err != nil { var _ decoderFrom = (*Int2)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *Int2) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *Int2) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Int2: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int var v int64 v, nTmp, err = d.DecodeHyper() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Hyper: %s", err) + return n, fmt.Errorf("decoding Hyper: %w", err) } *s = Int2(v) return n, nil @@ -721,7 +768,7 @@ func (s Int2) MarshalBinary() ([]byte, error) { func (s *Int2) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -753,14 +800,18 @@ if _, err = e.EncodeUint(uint32(s)); err != nil { var _ decoderFrom = (*Int3)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *Int3) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *Int3) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Int3: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int var v uint32 v, nTmp, err = d.DecodeUint() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Unsigned int: %s", err) + return n, fmt.Errorf("decoding Unsigned int: %w", err) } *s = Int3(v) return n, nil @@ -778,7 +829,7 @@ func (s Int3) MarshalBinary() ([]byte, error) { func (s *Int3) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -810,14 +861,18 @@ if _, err = e.EncodeUhyper(uint64(s)); err != nil { var _ decoderFrom = (*Int4)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *Int4) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *Int4) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Int4: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int var v uint64 v, nTmp, err = d.DecodeUhyper() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Unsigned hyper: %s", err) + return n, fmt.Errorf("decoding Unsigned hyper: %w", err) } *s = Int4(v) return n, nil @@ -835,7 +890,7 @@ func (s Int4) MarshalBinary() ([]byte, error) { func (s *Int4) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -907,53 +962,57 @@ if _, err = e.EncodeBool(bool(s.Field7)); err != nil { var _ decoderFrom = (*MyStruct)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *MyStruct) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *MyStruct) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding MyStruct: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int - nTmp, err = s.Field1.DecodeFrom(d) + nTmp, err = s.Field1.DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding Uint512: %s", err) + return n, fmt.Errorf("decoding Uint512: %w", err) } var b bool b, nTmp, err = d.DecodeBool() n += nTmp if err != nil { - return n, fmt.Errorf("decoding OptHash1: %s", err) + return n, fmt.Errorf("decoding OptHash1: %w", err) } s.Field2 = nil if b { s.Field2 = new(Hash) - nTmp, err = s.Field2.DecodeFrom(d) + nTmp, err = s.Field2.DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding OptHash1: %s", err) + return n, fmt.Errorf("decoding OptHash1: %w", err) } } - nTmp, err = s.Field3.DecodeFrom(d) + nTmp, err = s.Field3.DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding Int1: %s", err) + return n, fmt.Errorf("decoding Int1: %w", err) } s.Field4, nTmp, err = d.DecodeUint() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Unsigned int: %s", err) + return n, fmt.Errorf("decoding Unsigned int: %w", err) } - nTmp, err = d.Decode(&s.Field5) + nTmp, err = d.DecodeWithMaxDepth(&s.Field5, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding Float: %s", err) + return n, fmt.Errorf("decoding Float: %w", err) } - nTmp, err = d.Decode(&s.Field6) + nTmp, err = d.DecodeWithMaxDepth(&s.Field6, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding Double: %s", err) + return n, fmt.Errorf("decoding Double: %w", err) } s.Field7, nTmp, err = d.DecodeBool() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Bool: %s", err) + return n, fmt.Errorf("decoding Bool: %w", err) } return n, nil } @@ -970,7 +1029,7 @@ func (s MyStruct) MarshalBinary() ([]byte, error) { func (s *MyStruct) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -1012,23 +1071,27 @@ if err = s.Members[i].EncodeTo(e); err != nil { var _ decoderFrom = (*LotsOfMyStructs)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *LotsOfMyStructs) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *LotsOfMyStructs) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding LotsOfMyStructs: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int var l uint32 l, nTmp, err = d.DecodeUint() n += nTmp if err != nil { - return n, fmt.Errorf("decoding MyStruct: %s", err) + return n, fmt.Errorf("decoding MyStruct: %w", err) } s.Members = nil if l > 0 { s.Members = make([]MyStruct, l) for i := uint32(0); i < l; i++ { - nTmp, err = s.Members[i].DecodeFrom(d) + nTmp, err = s.Members[i].DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding MyStruct: %s", err) + return n, fmt.Errorf("decoding MyStruct: %w", err) } } } @@ -1047,7 +1110,7 @@ func (s LotsOfMyStructs) MarshalBinary() ([]byte, error) { func (s *LotsOfMyStructs) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -1084,13 +1147,17 @@ if err = s.Data.EncodeTo(e); err != nil { var _ decoderFrom = (*HasStuff)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *HasStuff) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *HasStuff) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding HasStuff: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int - nTmp, err = s.Data.DecodeFrom(d) + nTmp, err = s.Data.DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding LotsOfMyStructs: %s", err) + return n, fmt.Errorf("decoding LotsOfMyStructs: %w", err) } return n, nil } @@ -1107,7 +1174,7 @@ func (s HasStuff) MarshalBinary() ([]byte, error) { func (s *HasStuff) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -1164,10 +1231,14 @@ func (e Color) EncodeTo(enc *xdr.Encoder) error { } var _ decoderFrom = (*Color)(nil) // DecodeFrom decodes this value using the Decoder. -func (e *Color) DecodeFrom(d *xdr.Decoder) (int, error) { +func (e *Color) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Color: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 v, n, err := d.DecodeInt() if err != nil { - return n, fmt.Errorf("decoding Color: %s", err) + return n, fmt.Errorf("decoding Color: %w", err) } if _, ok := colorMap[v]; !ok { return n, fmt.Errorf("'%d' is not a valid Color enum value", v) @@ -1187,7 +1258,7 @@ func (s Color) MarshalBinary() ([]byte, error) { func (s *Color) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -1253,10 +1324,14 @@ func (e NesterNestedEnum) EncodeTo(enc *xdr.Encoder) error { } var _ decoderFrom = (*NesterNestedEnum)(nil) // DecodeFrom decodes this value using the Decoder. -func (e *NesterNestedEnum) DecodeFrom(d *xdr.Decoder) (int, error) { +func (e *NesterNestedEnum) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding NesterNestedEnum: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 v, n, err := d.DecodeInt() if err != nil { - return n, fmt.Errorf("decoding NesterNestedEnum: %s", err) + return n, fmt.Errorf("decoding NesterNestedEnum: %w", err) } if _, ok := nestedEnumMap[v]; !ok { return n, fmt.Errorf("'%d' is not a valid NesterNestedEnum enum value", v) @@ -1276,7 +1351,7 @@ func (s NesterNestedEnum) MarshalBinary() ([]byte, error) { func (s *NesterNestedEnum) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -1312,13 +1387,17 @@ if _, err = e.EncodeInt(int32(s.Blah)); err != nil { var _ decoderFrom = (*NesterNestedStruct)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *NesterNestedStruct) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *NesterNestedStruct) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding NesterNestedStruct: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int s.Blah, nTmp, err = d.DecodeInt() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Int: %s", err) + return n, fmt.Errorf("decoding Int: %w", err) } return n, nil } @@ -1335,7 +1414,7 @@ func (s NesterNestedStruct) MarshalBinary() ([]byte, error) { func (s *NesterNestedStruct) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -1390,7 +1469,7 @@ switch Color(color) { default: tv, ok := value.(int32) if !ok { - err = fmt.Errorf("invalid value, must be int32") + err = errors.New("invalid value, must be int32") return } result.Blah2 = &tv @@ -1442,13 +1521,17 @@ return nil var _ decoderFrom = (*NesterNestedUnion)(nil) // DecodeFrom decodes this value using the Decoder. -func (u *NesterNestedUnion) DecodeFrom(d *xdr.Decoder) (int, error) { +func (u *NesterNestedUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding NesterNestedUnion: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int - nTmp, err = u.Color.DecodeFrom(d) + nTmp, err = u.Color.DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding Color: %s", err) + return n, fmt.Errorf("decoding Color: %w", err) } switch Color(u.Color) { case ColorRed: @@ -1459,7 +1542,7 @@ switch Color(u.Color) { (*u.Blah2), nTmp, err = d.DecodeInt() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Int: %s", err) + return n, fmt.Errorf("decoding Int: %w", err) } return n, nil } @@ -1477,7 +1560,7 @@ func (s NesterNestedUnion) MarshalBinary() ([]byte, error) { func (s *NesterNestedUnion) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -1538,23 +1621,27 @@ if err = s.NestedUnion.EncodeTo(e); err != nil { var _ decoderFrom = (*Nester)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *Nester) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *Nester) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Nester: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int - nTmp, err = s.NestedEnum.DecodeFrom(d) + nTmp, err = s.NestedEnum.DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding NesterNestedEnum: %s", err) + return n, fmt.Errorf("decoding NesterNestedEnum: %w", err) } - nTmp, err = s.NestedStruct.DecodeFrom(d) + nTmp, err = s.NestedStruct.DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding NesterNestedStruct: %s", err) + return n, fmt.Errorf("decoding NesterNestedStruct: %w", err) } - nTmp, err = s.NestedUnion.DecodeFrom(d) + nTmp, err = s.NestedUnion.DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding NesterNestedUnion: %s", err) + return n, fmt.Errorf("decoding NesterNestedUnion: %w", err) } return n, nil } @@ -1571,7 +1658,7 @@ func (s Nester) MarshalBinary() ([]byte, error) { func (s *Nester) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } diff --git a/spec/output/generator_spec_go/union.x/MyXDR_generated.go b/spec/output/generator_spec_go/union.x/MyXDR_generated.go index c15e8edc1..980ddd3f8 100644 --- a/spec/output/generator_spec_go/union.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/union.x/MyXDR_generated.go @@ -11,6 +11,7 @@ package MyXDR import ( "bytes" "encoding" + "errors" "io" "fmt" @@ -22,19 +23,21 @@ var XdrFilesSHA256 = map[string]string{ "spec/fixtures/generator/union.x": "c251258d967223b341ebcf2d5bb0718e9a039b46232cb743865d9acd0c4bbe41", } +var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") + type xdrType interface { xdrType() } type decoderFrom interface { - DecodeFrom(d *xdr.Decoder) (int, error) + DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) } // Unmarshal reads an xdr element from `r` into `v`. func Unmarshal(r io.Reader, v interface{}) (int, error) { if decodable, ok := v.(decoderFrom); ok { d := xdr.NewDecoder(r) - return decodable.DecodeFrom(d) + return decodable.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) } // delegate to xdr package's Unmarshal return xdr.Unmarshal(r, v) @@ -72,14 +75,18 @@ if _, err = e.EncodeInt(int32(s)); err != nil { var _ decoderFrom = (*Error)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *Error) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *Error) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Error: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int var v int32 v, nTmp, err = d.DecodeInt() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Int: %s", err) + return n, fmt.Errorf("decoding Int: %w", err) } *s = Error(v) return n, nil @@ -97,7 +104,7 @@ func (s Error) MarshalBinary() ([]byte, error) { func (s *Error) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -129,14 +136,18 @@ if _, err = e.EncodeInt(int32(s)); err != nil { var _ decoderFrom = (*Multi)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *Multi) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *Multi) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding Multi: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int var v int32 v, nTmp, err = d.DecodeInt() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Int: %s", err) + return n, fmt.Errorf("decoding Int: %w", err) } *s = Multi(v) return n, nil @@ -154,7 +165,7 @@ func (s Multi) MarshalBinary() ([]byte, error) { func (s *Multi) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -208,10 +219,14 @@ func (e UnionKey) EncodeTo(enc *xdr.Encoder) error { } var _ decoderFrom = (*UnionKey)(nil) // DecodeFrom decodes this value using the Decoder. -func (e *UnionKey) DecodeFrom(d *xdr.Decoder) (int, error) { +func (e *UnionKey) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding UnionKey: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 v, n, err := d.DecodeInt() if err != nil { - return n, fmt.Errorf("decoding UnionKey: %s", err) + return n, fmt.Errorf("decoding UnionKey: %w", err) } if _, ok := unionKeyMap[v]; !ok { return n, fmt.Errorf("'%d' is not a valid UnionKey enum value", v) @@ -231,7 +246,7 @@ func (s UnionKey) MarshalBinary() ([]byte, error) { func (s *UnionKey) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -289,14 +304,14 @@ switch UnionKey(aType) { case UnionKeyError: tv, ok := value.(Error) if !ok { - err = fmt.Errorf("invalid value, must be Error") + err = errors.New("invalid value, must be Error") return } result.Error = &tv case UnionKeyMulti: tv, ok := value.([]Multi) if !ok { - err = fmt.Errorf("invalid value, must be []Multi") + err = errors.New("invalid value, must be []Multi") return } result.Things = &tv @@ -380,21 +395,25 @@ return nil var _ decoderFrom = (*MyUnion)(nil) // DecodeFrom decodes this value using the Decoder. -func (u *MyUnion) DecodeFrom(d *xdr.Decoder) (int, error) { +func (u *MyUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding MyUnion: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int - nTmp, err = u.Type.DecodeFrom(d) + nTmp, err = u.Type.DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding UnionKey: %s", err) + return n, fmt.Errorf("decoding UnionKey: %w", err) } switch UnionKey(u.Type) { case UnionKeyError: u.Error = new(Error) - nTmp, err = (*u.Error).DecodeFrom(d) + nTmp, err = (*u.Error).DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding Error: %s", err) + return n, fmt.Errorf("decoding Error: %w", err) } return n, nil case UnionKeyMulti: @@ -403,16 +422,16 @@ if err != nil { l, nTmp, err = d.DecodeUint() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Multi: %s", err) + return n, fmt.Errorf("decoding Multi: %w", err) } (*u.Things) = nil if l > 0 { (*u.Things) = make([]Multi, l) for i := uint32(0); i < l; i++ { - nTmp, err = (*u.Things)[i].DecodeFrom(d) + nTmp, err = (*u.Things)[i].DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding Multi: %s", err) + return n, fmt.Errorf("decoding Multi: %w", err) } } } @@ -433,7 +452,7 @@ func (s MyUnion) MarshalBinary() ([]byte, error) { func (s *MyUnion) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -490,14 +509,14 @@ switch int32(aType) { case 0: tv, ok := value.(Error) if !ok { - err = fmt.Errorf("invalid value, must be Error") + err = errors.New("invalid value, must be Error") return } result.Error = &tv case 1: tv, ok := value.([]Multi) if !ok { - err = fmt.Errorf("invalid value, must be []Multi") + err = errors.New("invalid value, must be []Multi") return } result.Things = &tv @@ -581,21 +600,25 @@ return nil var _ decoderFrom = (*IntUnion)(nil) // DecodeFrom decodes this value using the Decoder. -func (u *IntUnion) DecodeFrom(d *xdr.Decoder) (int, error) { +func (u *IntUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding IntUnion: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int u.Type, nTmp, err = d.DecodeInt() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Int: %s", err) + return n, fmt.Errorf("decoding Int: %w", err) } switch int32(u.Type) { case 0: u.Error = new(Error) - nTmp, err = (*u.Error).DecodeFrom(d) + nTmp, err = (*u.Error).DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding Error: %s", err) + return n, fmt.Errorf("decoding Error: %w", err) } return n, nil case 1: @@ -604,16 +627,16 @@ if err != nil { l, nTmp, err = d.DecodeUint() n += nTmp if err != nil { - return n, fmt.Errorf("decoding Multi: %s", err) + return n, fmt.Errorf("decoding Multi: %w", err) } (*u.Things) = nil if l > 0 { (*u.Things) = make([]Multi, l) for i := uint32(0); i < l; i++ { - nTmp, err = (*u.Things)[i].DecodeFrom(d) + nTmp, err = (*u.Things)[i].DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding Multi: %s", err) + return n, fmt.Errorf("decoding Multi: %w", err) } } } @@ -634,7 +657,7 @@ func (s IntUnion) MarshalBinary() ([]byte, error) { func (s *IntUnion) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err } @@ -707,13 +730,17 @@ if err = IntUnion(s).EncodeTo(e); err != nil { var _ decoderFrom = (*IntUnion2)(nil) // DecodeFrom decodes this value using the Decoder. -func (s *IntUnion2) DecodeFrom(d *xdr.Decoder) (int, error) { +func (s *IntUnion2) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { + if maxDepth == 0 { + return 0, fmt.Errorf("decoding IntUnion2: %w", ErrMaxDecodingDepthReached) + } + maxDepth -= 1 var err error var n, nTmp int - nTmp, err = (*IntUnion)(s).DecodeFrom(d) + nTmp, err = (*IntUnion)(s).DecodeFrom(d, maxDepth) n += nTmp if err != nil { - return n, fmt.Errorf("decoding IntUnion: %s", err) + return n, fmt.Errorf("decoding IntUnion: %w", err) } return n, nil } @@ -730,7 +757,7 @@ func (s IntUnion2) MarshalBinary() ([]byte, error) { func (s *IntUnion2) UnmarshalBinary(inp []byte) error { r := bytes.NewReader(inp) d := xdr.NewDecoder(r) - _, err := s.DecodeFrom(d) + _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth) return err }