Skip to content

Commit

Permalink
Merge branch 'develop' into toni/truncated_signal_not_treated
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Jan 16, 2025
2 parents f44465a + 7f9992b commit c6ece72
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
5 changes: 3 additions & 2 deletions skrl/memories/jax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ def create_tensor(
:rtype: bool
"""
# compute data size
size = compute_space_size(size, occupied_size=True)
if not keep_dimensions:
size = compute_space_size(size, occupied_size=True)
# check dtype and size if the tensor exists
if name in self.tensors:
tensor = self.tensors[name]
Expand Down Expand Up @@ -406,7 +407,7 @@ def sample_all(
# sequential order
if sequence_length > 1:
if mini_batches > 1:
batches = np.array_split(self.all_sequence_indexes, len(self.all_sequence_indexes) // mini_batches)
batches = np.array_split(self.all_sequence_indexes, mini_batches)
return [[self._get_tensors_view(name)[batch] for name in names] for batch in batches]
return [[self._get_tensors_view(name)[self.all_sequence_indexes] for name in names]]

Expand Down
5 changes: 3 additions & 2 deletions skrl/memories/torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ def create_tensor(
:rtype: bool
"""
# compute data size
size = compute_space_size(size, occupied_size=True)
if not keep_dimensions:
size = compute_space_size(size, occupied_size=True)
# check dtype and size if the tensor exists
if name in self.tensors:
tensor = self.tensors[name]
Expand Down Expand Up @@ -342,7 +343,7 @@ def sample_all(
# sequential order
if sequence_length > 1:
if mini_batches > 1:
batches = np.array_split(self.all_sequence_indexes, len(self.all_sequence_indexes) // mini_batches)
batches = np.array_split(self.all_sequence_indexes, mini_batches)
return [[self.tensors_view[name][batch] for name in names] for batch in batches]
return [[self.tensors_view[name][self.all_sequence_indexes] for name in names]]

Expand Down

0 comments on commit c6ece72

Please sign in to comment.