-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathwrappers.py
363 lines (285 loc) · 15.6 KB
/
wrappers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
import random
import numpy as np
from matplotlib import pyplot as plt
from minigrid.wrappers import FullyObsWrapper
from minigrid.core.constants import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX
from sklearn.preprocessing import OneHotEncoder
from gymnasium import spaces
from customize_minigrid.custom_env import CustomEnv
class FullyObsSB3MLPWrapper(FullyObsWrapper):
    """Fully observable MiniGrid wrapper producing a flat one-hot vector for SB3 MLP policies.

    Layout of the returned vector (every group is one-hot):

    * for each grid cell (``width * height`` cells, in the order produced by
      ``image.reshape(-1, 3)``): object type, colour, state channel;
    * the carried object's type and colour (from ``obs['carrying']``);
    * the overlapped object's type and colour (from ``obs['overlap']``).

    The agent's direction is not encoded as a separate group: it already
    appears in the state channel of the agent's cell.
    """

    def __init__(self, env: CustomEnv, to_print=False):
        super().__init__(env)
        self.env = env
        self.to_print = to_print

        # Width of each one-hot group.
        self.num_object_features = len(OBJECT_TO_IDX)  # number of object types
        self.num_colour_features = len(COLOR_TO_IDX)  # number of colours
        # The state channel holds either an object state (e.g. a door's) or the
        # agent's direction (0-3), so it needs at least 4 categories.
        self.num_state_features = max(len(STATE_TO_IDX), 4)
        self.num_carrying_features = len(OBJECT_TO_IDX)
        self.num_carrying_colour_features = len(COLOR_TO_IDX)

        # One encoder per group, each fitted on its full category range so the
        # fitted categories always match the declared group width.
        self.object_encoder = self._make_encoder(self.num_object_features)
        self.colour_encoder = self._make_encoder(self.num_colour_features)
        self.state_encoder = self._make_encoder(self.num_state_features)

        # Total vector length: per-cell features for every cell, followed by the
        # carried (type, colour) and overlapped (type, colour) groups.
        self.num_cell_features = (
            self.num_object_features + self.num_colour_features + self.num_state_features
        )
        self.num_cells = self.env.width * self.env.height
        self.total_features = self.num_cells * self.num_cell_features
        self.total_features += self.num_carrying_features + self.num_carrying_colour_features  # carried
        self.total_features += self.num_carrying_features + self.num_carrying_colour_features  # overlapped

        self.observation_space = spaces.Box(
            low=0,
            high=1,
            shape=(self.total_features,),
            dtype=np.float32,
        )

    @staticmethod
    def _make_encoder(num_categories):
        """Return a dense OneHotEncoder fitted on the integer categories ``0..num_categories-1``."""
        encoder = OneHotEncoder(categories=[range(num_categories)], sparse_output=False)
        encoder.fit(np.arange(num_categories).reshape(-1, 1))
        return encoder

    @staticmethod
    def _encode_value(encoder, value):
        """One-hot encode a single categorical ``value`` with ``encoder``; return a 1-D array."""
        return encoder.transform(np.array([value]).reshape(-1, 1)).flatten()

    @staticmethod
    def _decode_value(encoder, vector, start, width):
        """Decode the one-hot group ``vector[start:start + width]``.

        Returns ``(value, next_start)``; ``value`` is the length-1 row returned
        by ``encoder.inverse_transform``.
        """
        value = encoder.inverse_transform(vector[start:start + width].reshape(1, -1))[0]
        return value, start + width

    @staticmethod
    def _print_debug(image, obs, final_obs):
        """Print each image channel (transposed for display), direction, mission and the encoding."""
        print(f"Image shape: {image.shape}")
        print("Object Types (Channel 0):")
        print(image[:, :, 0].transpose(1, 0))
        print("Colors (Channel 1):")
        print(image[:, :, 1].transpose(1, 0))
        print("Additional State (Channel 2):")
        print(image[:, :, 2].transpose(1, 0))
        direction = obs['direction']
        mission = obs['mission']
        print(f"Direction: {direction}")
        print(f"Mission: {mission}")
        print("final obs:")
        print(final_obs)

    def observation(self, obs):
        """Encode ``obs`` as a flat float32 one-hot vector of length ``total_features``.

        ``obs`` is first passed through ``FullyObsWrapper.observation`` to get the
        full-grid ``'image'``. It must also provide ``'carrying'`` (keys
        ``'carrying'``, ``'carrying_colour'``) and ``'overlap'`` (keys ``'obj'``,
        ``'colour'``).
        """
        obs = super().observation(obs)
        image = obs['image']
        # One row per grid cell: (object type, colour, state).
        cells = image.reshape(-1, image.shape[2])
        cell_onehot = np.concatenate(
            [
                self.object_encoder.transform(cells[:, 0].reshape(-1, 1)),
                self.colour_encoder.transform(cells[:, 1].reshape(-1, 1)),
                self.state_encoder.transform(cells[:, 2].reshape(-1, 1)),
            ],
            axis=1,
        )
        carrying = obs['carrying']
        overlap = obs['overlap']
        final_obs = np.concatenate(
            [
                cell_onehot.flatten(),
                self._encode_value(self.object_encoder, carrying['carrying']),
                self._encode_value(self.colour_encoder, carrying['carrying_colour']),
                self._encode_value(self.object_encoder, overlap['obj']),
                self._encode_value(self.colour_encoder, overlap['colour']),
            ]
        ).astype(np.float32)  # OneHotEncoder emits float64; match observation_space's dtype

        if self.to_print:
            self._print_debug(image, obs, final_obs)
        return final_obs

    def decode_to_original_obs(self, one_hot_vector: np.ndarray):
        """Invert :meth:`observation`, rebuilding the observation dict from a vector.

        Returns a dict with ``'image'`` of shape ``(width, height, 3)`` plus the
        ``'carrying'`` and ``'overlap'`` sub-dicts. ``'direction'`` and
        ``'mission'`` are not part of the encoding and are omitted.

        Raises:
            ValueError: if ``one_hot_vector`` does not have ``total_features`` entries.
        """
        if len(one_hot_vector) != self.total_features:
            raise ValueError("Encoded vector length does not match the expected total features.")

        # Grid section: one row of num_cell_features per cell.
        grid_end = self.num_cells * self.num_cell_features
        cells = one_hot_vector[:grid_end].reshape(self.num_cells, self.num_cell_features)
        obj_end = self.num_object_features
        colour_end = obj_end + self.num_colour_features
        grid_shape = (self.env.width, self.env.height)
        object_types = self.object_encoder.inverse_transform(cells[:, :obj_end]).reshape(grid_shape)
        colours = self.colour_encoder.inverse_transform(cells[:, obj_end:colour_end]).reshape(grid_shape)
        states = self.state_encoder.inverse_transform(cells[:, colour_end:self.num_cell_features]).reshape(grid_shape)
        image = np.stack([object_types, colours, states], axis=-1)

        # Trailing groups, in the order observation() appends them.
        start = grid_end
        carrying, start = self._decode_value(self.object_encoder, one_hot_vector, start, self.num_object_features)
        carrying_colour, start = self._decode_value(self.colour_encoder, one_hot_vector, start, self.num_colour_features)
        overlap, start = self._decode_value(self.object_encoder, one_hot_vector, start, self.num_object_features)
        overlap_colour, _ = self._decode_value(self.colour_encoder, one_hot_vector, start, self.num_colour_features)

        return {
            'image': image,
            'carrying': {
                'carrying': carrying,
                'carrying_colour': carrying_colour,
            },
            'overlap': {
                "obj": overlap,
                "colour": overlap_colour,
            },
        }

    def set_env_with_code(self, one_hot_vector: np.ndarray):
        """Decode ``one_hot_vector``, load it into the env via ``set_env_by_obs`` and return ``self.observation`` of the result."""
        decoded_obs = self.decode_to_original_obs(one_hot_vector)
        obs, _ = self.env.set_env_by_obs(decoded_obs)
        return self.observation(obs)

    def force_reset(self):
        """Clear the env's ``skip_reset`` flag, then call its ``reset()``."""
        self.env.skip_reset = False
        self.env.reset()
def test_encode_decode_consistency(env: FullyObsSB3MLPWrapper, num_epochs=10, num_steps=10):
    """Check that observation encoding round-trips over random rollouts.

    For each epoch the environment is reset and stepped with random actions.
    After every non-terminal step the encoded observation is
    (1) decoded and re-encoded with ``env.observation``, and
    (2) decoded and written back into the environment with ``env.set_env_with_code``.
    Both results must equal the original encoded vector. Testing stops at the
    first mismatch.

    Parameters:
    - env: The environment instance with the FullyObsSB3MLPWrapper.
    - num_epochs: Number of episodes (resets) to test.
    - num_steps: Maximum number of random steps per episode.

    Prints a line when an episode ends early, on the first failure, and a
    success message if every check passes.
    """
    for epoch in range(num_epochs):
        env.reset()
        for step in range(num_steps):
            action = env.action_space.sample()
            encoded_vector, _, done, truncated, _ = env.step(action)

            if done or truncated:
                # Presumably clears CustomEnv's skip-reset flag so the next
                # reset() is a real one (cf. force_reset) — confirm.
                env.env.skip_reset = False
                print(f"Episode ended at epoch {epoch+1}, step {step+1}: {'due to environment termination.' if done else 'due to truncation.'}")
                break

            # NOTE(review): env.observation() first calls minigrid's
            # FullyObsWrapper.observation, which appears to rebuild 'image' from
            # the live grid rather than from decoded_obs['image']. If so, this
            # check only round-trips the carrying/overlap groups. Confirm.
            decoded_obs = env.decode_to_original_obs(encoded_vector)
            encoded_decoded_vector = env.observation(decoded_obs)
            if not np.array_equal(encoded_vector, encoded_decoded_vector):
                print(f"Test failed at epoch {epoch+1}, step {step+1}: The decoded image does not match the original image.")
                print("Stopping tests due to failure.")
                return

            reconstructed_vector = env.set_env_with_code(encoded_vector)
            if not np.array_equal(encoded_vector, reconstructed_vector):
                print(f"Test failed at epoch {epoch+1}, step {step+1}: The reconstructed image does not match the original image.")
                print("Stopping tests due to failure.")
                return

    print("All tests passed successfully.")
class FullyObsImageWrapper(FullyObsSB3MLPWrapper):
    """Observation wrapper returning the rendered full grid as a CHW float32 image in [0, 1].

    The one-hot vector produced by FullyObsSB3MLPWrapper is still computed on
    every reset/step and exposed to callers as ``info['encoded_obs']``.
    """

    def __init__(self, env: CustomEnv, to_print=False):
        super().__init__(env, to_print)
        tile = self.env.get_wrapper_attr('tile_size')
        # Channels-first image. The height has one extra tile row beyond the
        # grid, presumably matching CustomEnv.get_frame's layout — TODO confirm.
        channels_first_shape = (3, (self.env.height + 1) * tile, self.env.width * tile)
        self.observation_space = spaces.Box(
            low=0.0,
            high=1.0,
            shape=channels_first_shape,
            dtype=np.float32,
        )

    def observation(self, obs):
        """Render the current frame and return it as a (channels, height, width) float32 array.

        ``obs`` is only consulted by the optional debug print (mission, direction);
        the pixels come from ``self.env.get_frame``.
        """
        tile = self.env.get_wrapper_attr('tile_size')
        frame = self.env.get_frame(highlight=False, tile_size=tile)
        # Scale to [0, 1], then (height, width, channels) -> (channels, height, width).
        chw = (frame.astype(np.float32) / 255.0).transpose(2, 0, 1)

        if self.to_print:
            print("2D Image array (normalized to 0-1):")
            print(chw)
            print(f"Mission: {obs['mission']}")
            print(f"Direction: {obs['direction']}")
        return chw

    def _render_and_encode(self, obs, info):
        """Return the image observation and store the one-hot encoding in ``info['encoded_obs']``."""
        wrapped = self.observation(obs)
        info['encoded_obs'] = super().observation(obs)
        return wrapped

    def reset(self, **kwargs):
        """Reset the inner env; return ``(image, info)`` with ``info['encoded_obs']`` set."""
        raw_obs, info = self.env.reset(**kwargs)
        wrapped = self._render_and_encode(raw_obs, info)
        return wrapped, info

    def step(self, action):
        """Step the inner env; return the 5-tuple with the image in place of the raw observation."""
        raw_obs, reward, done, truncated, info = self.env.step(action)
        wrapped = self._render_and_encode(raw_obs, info)
        return wrapped, reward, done, truncated, info
class RandomChannelSwapWrapper(FullyObsImageWrapper):
    """Image wrapper that permutes the colour channels of every observation.

    With ``seed=None`` a new permutation is drawn from NumPy's global RNG on
    every call. With an integer ``seed`` the same permutation is produced on
    every call. It is drawn from a private ``np.random.RandomState(seed)``, so
    the global ``random`` / ``np.random`` state is left untouched.
    """

    def __init__(self, env: CustomEnv, seed=None, to_print=False):
        super().__init__(env, to_print)
        self.seed = seed

    def observation(self, obs):
        """Return the parent's CHW image with its channel axis shuffled."""
        image = super().observation(obs)
        channels = np.arange(image.shape[0])
        if self.seed is None:
            np.random.shuffle(channels)
        else:
            # A freshly seeded RandomState yields the same permutation that
            # np.random.seed(seed) + np.random.shuffle would, without resetting
            # the process-wide RNGs on every observation.
            np.random.RandomState(self.seed).shuffle(channels)
        image = image[channels, :, :]

        if self.to_print:
            print("Randomly Swapped Image Array:")
            print(image)
            print(f"Mission: {obs['mission']}")
            print(f"Direction: {obs['direction']}")
        return image
if __name__ == '__main__':
    from custom_env import CustomEnv

    # Build the environment and wrap it to produce channel-swapped image observations.
    env = CustomEnv(
        txt_file_path="../maps/small_maze.txt",
        rand_gen_shape=None,
        display_size=5,
        display_mode='middle',
        random_rotate=False,
        random_flip=False,
        custom_mission="Explore and interact with objects.",
        max_steps=128,
    )
    env = RandomChannelSwapWrapper(env, to_print=False)
    obs, info = env.reset()

    # Step with random actions indefinitely, displaying every observation.
    while True:
        action = env.action_space.sample()
        obs, reward, done, truncated, info = env.step(action)

        # The wrapper returns (channels, height, width); matplotlib expects
        # (height, width, channels).
        image = np.transpose(obs, (1, 2, 0))
        plt.imshow(image)
        plt.title("Environment Observation")
        plt.axis("off")
        plt.show()

        # Start a new episode on termination or truncation (max_steps reached).
        if done or truncated:
            obs, info = env.reset()