From 7f9992b6304fe105d799a31bd91b857d33d33b90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Thu, 16 Jan 2025 15:22:58 -0500 Subject: [PATCH] Fix memory sampling when sequence_length is specified --- skrl/memories/jax/base.py | 5 +++-- skrl/memories/torch/base.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/skrl/memories/jax/base.py b/skrl/memories/jax/base.py index 742ee8cc..7782a47a 100644 --- a/skrl/memories/jax/base.py +++ b/skrl/memories/jax/base.py @@ -188,7 +188,8 @@ def create_tensor( :rtype: bool """ # compute data size - size = compute_space_size(size, occupied_size=True) + if not keep_dimensions: + size = compute_space_size(size, occupied_size=True) # check dtype and size if the tensor exists if name in self.tensors: tensor = self.tensors[name] @@ -406,7 +407,7 @@ def sample_all( # sequential order if sequence_length > 1: if mini_batches > 1: - batches = np.array_split(self.all_sequence_indexes, len(self.all_sequence_indexes) // mini_batches) + batches = np.array_split(self.all_sequence_indexes, mini_batches) return [[self._get_tensors_view(name)[batch] for name in names] for batch in batches] return [[self._get_tensors_view(name)[self.all_sequence_indexes] for name in names]] diff --git a/skrl/memories/torch/base.py b/skrl/memories/torch/base.py index 9d6be2b4..670ce565 100644 --- a/skrl/memories/torch/base.py +++ b/skrl/memories/torch/base.py @@ -157,7 +157,8 @@ def create_tensor( :rtype: bool """ # compute data size - size = compute_space_size(size, occupied_size=True) + if not keep_dimensions: + size = compute_space_size(size, occupied_size=True) # check dtype and size if the tensor exists if name in self.tensors: tensor = self.tensors[name] @@ -342,7 +343,7 @@ def sample_all( # sequential order if sequence_length > 1: if mini_batches > 1: - batches = np.array_split(self.all_sequence_indexes, len(self.all_sequence_indexes) // mini_batches) + batches = np.array_split(self.all_sequence_indexes, mini_batches) return [[self.tensors_view[name][batch] for name in names] for batch in batches] return [[self.tensors_view[name][self.all_sequence_indexes] for name in names]]