Skip to content

Commit

Permalink
Convert .pyre_configuration.local to fast by default architecture] [batch:32/764] [shard:13/N]
Browse files Browse the repository at this point in the history

Reviewed By: MaggieMoss

Differential Revision: D63206193

fbshipit-source-id: 16839dc455defeb6748c3983eb32a79bf3328aff
  • Loading branch information
generatedunixname89002005307016 authored and facebook-github-bot committed Sep 22, 2024
1 parent 9989f40 commit c061b43
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 3 deletions.
14 changes: 12 additions & 2 deletions pearl/user_envs/envs/bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
print("gymnasium module is not found")


# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
class MeanVarBanditEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
"""environment to test if safe RL algorithms
prefer a policy that achieves lower variance return"""
Expand All @@ -33,20 +34,29 @@ def __init__(
self.observation_space = gym.spaces.Box(-high, high, dtype=np.float32)
self.idx: Optional[int] = None

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def get_observation(self) -> np.ndarray:
    """Build the observation for the current index.

    Returns:
        A ``float32`` array created with ``np.zeros(self._size)`` in which
        the entry selected by ``self.idx`` is set to ``1.0``.
    """
    obs = np.zeros(self._size, dtype=np.float32)
    # NOTE(review): `self.idx` is declared Optional[int] and starts as None.
    # NumPy treats a None index as a new axis, so assigning through it would
    # set every entry to 1.0 rather than fail — confirm callers always
    # invoke `reset()` before asking for an observation.
    obs[self.idx] = 1.0
    return obs

def reset(
self, *, seed: Optional[int] = None, options: Optional[Dict[str, float]] = None
self,
*,
seed: Optional[int] = None,
options: Optional[Dict[str, float]] = None
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
) -> Tuple[np.ndarray, Dict[str, float]]:
super().reset(seed=seed)
self.idx = 0
return self.get_observation(), {}

# pyre-fixme[3]: Return annotation cannot contain `Any`.
def step(
self, action: Union[int, np.ndarray]
self,
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
action: Union[int, np.ndarray],
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
reward = 0.0
if action == 0:
Expand Down
6 changes: 5 additions & 1 deletion pearl/utils/instantiations/environments/gym_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def single_element_tensor_to_int(x: Tensor) -> int:
return int(x)


# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def tensor_to_numpy(x: Tensor) -> np.ndarray:
    """Convert a tensor to a NumPy array.

    Args:
        x: The tensor to convert.

    Returns:
        The result of ``x.numpy(force=True)``.
    """
    # `force=True` lets the conversion succeed for tensors that a plain
    # `.numpy()` call rejects (e.g. ones that require grad or live on a
    # non-CPU device); in those cases the result is a copy, not a view.
    return x.numpy(force=True)

Expand All @@ -51,6 +52,7 @@ def tensor_to_numpy(x: Tensor) -> np.ndarray:
"Box": BoxSpace,
# Add more here as needed
}
# pyre-fixme[5]: Global expression must be annotated.
PEARL_TO_GYM_ACTION = {
"Discrete": single_element_tensor_to_int,
"Box": tensor_to_numpy,
Expand Down Expand Up @@ -179,7 +181,9 @@ def __str__(self) -> str:


def _get_gym_action(
pearl_action: Action, gym_space: gym.Space
pearl_action: Action,
gym_space: gym.Space,
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
) -> Union[int, np.ndarray]:
"""A helper function to convert a Pearl `Action` to an action compatible with
the Gym action space `gym_space`."""
Expand Down
2 changes: 2 additions & 0 deletions pearl/utils/instantiations/spaces/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(self, elements: List[Tensor], seed: Optional[int] = None) -> None:
if len(elements) == 0:
raise ValueError("`DiscreteSpace` requires at least one element.")
self._set_validated_elements(elements=elements) # sets self.elements
# pyre-fixme[28]: Unexpected keyword argument `start`.
self._gym_space = Discrete(n=len(elements), seed=seed, start=0)

def _set_validated_elements(self, elements: List[Tensor]) -> None:
Expand Down Expand Up @@ -94,6 +95,7 @@ def sample(self, mask: Optional[Tensor] = None) -> Tensor:
A randomly sampled (available) element.
"""
mask_np = mask.numpy().astype(int) if mask is not None else None
# pyre-fixme[28]: Unexpected keyword argument `mask`.
idx = self._gym_space.sample(mask=mask_np)
return self.elements[idx]

Expand Down
2 changes: 2 additions & 0 deletions test/unit/with_pytorch/test_cnn_based_q_value_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def setUp(self) -> None:
self.num_data_points = 5000
self.mnist_train_dataset = Subset(
mnist_dataset,
# pyre-fixme[6]: For 2nd argument expected `Sequence[int]` but got
# `ndarray[typing.Any, dtype[typing.Any]]`.
np.arange(1, self.num_data_points),
)
self.learning_rate = 0.001
Expand Down
2 changes: 2 additions & 0 deletions test/unit/with_pytorch/test_vanilla_cnns.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def setUp(self) -> None:
self.num_data_points = 5000
self.mnist_train_dataset = Subset(
mnist_dataset,
# pyre-fixme[6]: For 2nd argument expected `Sequence[int]` but got
# `ndarray[typing.Any, dtype[typing.Any]]`.
np.arange(1, self.num_data_points),
)
self.learning_rate = 0.001
Expand Down

0 comments on commit c061b43

Please sign in to comment.