@@ -42,7 +42,7 @@ def __init__(self, env: VectorEnv, device: Device | None = None):
42
42
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
43
43
44
44
Args:
45
- env: The Jax -based vector environment to wrap
45
+ env: The NumPy -based vector environment to wrap
46
46
device: The device the torch Tensors should be moved to
47
47
"""
48
48
super ().__init__ (env )
@@ -60,8 +60,8 @@ def step(
60
60
Returns:
61
61
The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
62
62
"""
63
- jax_action = torch_to_numpy (actions )
64
- obs , reward , terminated , truncated , info = self .env .step (jax_action )
63
+ numpy_action = torch_to_numpy (actions )
64
+ obs , reward , terminated , truncated , info = self .env .step (numpy_action )
65
65
66
66
return (
67
67
numpy_to_torch (obs , self .device ),
@@ -81,7 +81,7 @@ def reset(
81
81
82
82
Args:
83
83
seed: The seed for resetting the environment
84
- options: The options for resetting the environment, these are converted to jax arrays.
84
+ options: The options for resetting the environment, these are converted to NumPy arrays.
85
85
86
86
Returns:
87
87
PyTorch-based observations and info
0 commit comments