Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 19, 2024
2 parents 48e368f + bb09072 commit 6632069
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/unittest/linux_sota/scripts/test_sota.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def run_command(command):
if output == "" and process.poll() is not None:
break
if output:
print(output.strip())
print(output.strip()) # noqa: T201
return_code = process.wait()
if return_code != 0:
raise subprocess.CalledProcessError(return_code, command)
Expand Down
23 changes: 23 additions & 0 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3172,6 +3172,29 @@ def make_and_test_policy(
)


@pytest.mark.parametrize(
"ctype", [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector]
)
def test_no_stopiteration(ctype):
# Tests that there is no StopIteration raised and that the length of the collector is properly set
if ctype is SyncDataCollector:
envs = SerialEnv(16, CountingEnv)
else:
envs = [SerialEnv(8, CountingEnv), SerialEnv(8, CountingEnv)]

collector = ctype(create_env_fn=envs, frames_per_batch=173, total_frames=300)
try:
c_iter = iter(collector)
assert len(collector) == 2
for i in range(len(collector)): # noqa: B007
c = next(c_iter)
assert c is not None
assert i == 1
finally:
collector.shutdown()
del collector


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
3 changes: 2 additions & 1 deletion torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,8 +740,8 @@ def __init__(
f" ({-(-frames_per_batch // self.n_env) * self.n_env}). "
"To silence this message, set the environment variable RL_WARNINGS to False."
)
self.requested_frames_per_batch = int(frames_per_batch)
self.frames_per_batch = -(-frames_per_batch // self.n_env)
self.requested_frames_per_batch = self.frames_per_batch * self.n_env
self.exploration_type = (
exploration_type if exploration_type else DEFAULT_EXPLORATION_TYPE
)
Expand Down Expand Up @@ -1656,6 +1656,7 @@ def __init__(
self._get_weights_fn_dict[policy_device] = get_weights_fn
self.policy = policy

remainder = 0
if total_frames is None or total_frames < 0:
total_frames = float("inf")
else:
Expand Down

0 comments on commit 6632069

Please sign in to comment.