Skip to content

Commit

Permalink
fix(nyz): fix env check multi-discrete bug (#852)
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Jan 23, 2025
1 parent f5157c7 commit 4e92de5
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 19 deletions.
37 changes: 18 additions & 19 deletions ding/envs/env/env_implementation_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,38 +24,37 @@ def check_space_dtype(env: 'BaseEnv') -> None:


# Util function
def check_array_space(ndarray: Union[np.ndarray, Sequence, Dict], space: Union['Space', Dict], name: str) -> None:
if isinstance(ndarray, np.ndarray):
def check_array_space(data: Union[np.ndarray, Sequence, Dict], space: Union['Space', Dict], name: str) -> None:
if isinstance(data, np.ndarray):
# print("{}'s type should be np.ndarray".format(name))
assert ndarray.dtype == space.dtype, "{}'s dtype is {}, but requires {}".format(
name, ndarray.dtype, space.dtype
)
assert ndarray.shape == space.shape, "{}'s shape is {}, but requires {}".format(
name, ndarray.shape, space.shape
)
assert data.dtype == space.dtype, "{}'s dtype is {}, but requires {}".format(name, data.dtype, space.dtype)
assert data.shape == space.shape, "{}'s shape is {}, but requires {}".format(name, data.shape, space.shape)
if isinstance(space, Box):
assert (space.low <= ndarray).all() and (ndarray <= space.high).all(
), "{}'s value is {}, but requires in range ({},{})".format(name, ndarray, space.low, space.high)
assert (space.low <= data).all() and (data <= space.high).all(
), "{}'s value is {}, but requires in range ({},{})".format(name, data, space.low, space.high)
elif isinstance(space, (Discrete, MultiDiscrete, MultiBinary)):
print(space.start, space.n)
assert (ndarray >= space.start) and (ndarray <= space.n)
elif isinstance(ndarray, Sequence):
for i in range(len(ndarray)):
if isinstance(space, Discrete):
assert (data >= space.start) and (data <= space.n)
else:
assert (data >= 0).all()
assert all([d < n for d, n in zip(data, space.nvec)])
elif isinstance(data, Sequence):
for i in range(len(data)):
try:
check_array_space(ndarray[i], space[i], name)
check_array_space(data[i], space[i], name)
except AssertionError as e:
print("The following error happens at {}-th index".format(i))
raise e
elif isinstance(ndarray, dict):
for k in ndarray.keys():
elif isinstance(data, dict):
for k in data.keys():
try:
check_array_space(ndarray[k], space[k], name)
check_array_space(data[k], space[k], name)
except AssertionError as e:
print("The following error happens at key {}".format(k))
raise e
else:
raise TypeError(
"Input array should be np.ndarray or sequence/dict of np.ndarray, but found {}".format(type(ndarray))
"Input array should be np.ndarray or sequence/dict of np.ndarray, but found {}".format(type(data))
)


Expand Down
5 changes: 5 additions & 0 deletions ding/envs/env/tests/test_env_implementation_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ def test_check_array_space():
discrete_array = np.array(11, dtype=np.int64)
with pytest.raises(AssertionError):
check_array_space(discrete_array, discrete_space, 'test_discrete')

multi_discrete_space = gym.spaces.MultiDiscrete([2, 3])
multi_discrete_array = np.array([1, 2], dtype=np.int64)
check_array_space(multi_discrete_array, multi_discrete_space, 'test_multi_discrete')

seq_array = (np.array([1, 2, 3], dtype=np.int64), np.array([4., 5., 6.], dtype=np.float32))
seq_space = [gym.spaces.Box(low=0, high=10, shape=(3, ), dtype=np.int64) for _ in range(2)]
with pytest.raises(AssertionError):
Expand Down

0 comments on commit 4e92de5

Please sign in to comment.