Skip to content

Commit

Permalink
Fix ArrowDecoder.decode to return instead of yield (#2976)
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella authored Aug 22, 2023
1 parent db14c0f commit c11f91f
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/gluonts/dataset/arrow/dec.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def from_schema(cls, schema):
)

def decode(self, batch, row_number: int):
yield from self.decode_batch(batch.slice(row_number, row_number + 1))
return next(self.decode_batch(batch.slice(row_number, row_number + 1)))

def decode_batch(self, batch):
for row in batch.to_pandas().to_dict("records"):
Expand Down
2 changes: 2 additions & 0 deletions test/dataset/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def test_arrow(writer, flatten_arrays):

assert_equal(data, dataset)

assert_equal(dataset[4], data[4])

assert len(dataset[:5]) == len(data[:5])
assert_equal(dataset[:5], data[:5])

Expand Down

0 comments on commit c11f91f

Please sign in to comment.