Skip to content

Commit ae18e03

Browse files
authored
Fix var names and docs to mention numpy instead of jax in vector NumpyToTorch
Seems that NumpyToTorch was made from JaxToTorch, so some artifacts of referring jax instead of numpy in docs and variable names exist.
1 parent 9ff8bf4 commit ae18e03

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

gymnasium/wrappers/vector/numpy_to_torch.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(self, env: VectorEnv, device: Device | None = None):
4242
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
4343
4444
Args:
45-
env: The Jax-based vector environment to wrap
45+
env: The NumPy-based vector environment to wrap
4646
device: The device the torch Tensors should be moved to
4747
"""
4848
super().__init__(env)
@@ -60,8 +60,8 @@ def step(
6060
Returns:
6161
The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
6262
"""
63-
jax_action = torch_to_numpy(actions)
64-
obs, reward, terminated, truncated, info = self.env.step(jax_action)
63+
numpy_action = torch_to_numpy(actions)
64+
obs, reward, terminated, truncated, info = self.env.step(numpy_action)
6565

6666
return (
6767
numpy_to_torch(obs, self.device),
@@ -81,7 +81,7 @@ def reset(
8181
8282
Args:
8383
seed: The seed for resetting the environment
84-
options: The options for resetting the environment, these are converted to jax arrays.
84+
options: The options for resetting the environment, these are converted to NumPy arrays.
8585
8686
Returns:
8787
PyTorch-based observations and info

0 commit comments

Comments
 (0)