Skip to content

Commit

Permalink
fix tests and layout_generator after merge
Browse files Browse the repository at this point in the history
  • Loading branch information
bmielnicki committed Aug 29, 2020
1 parent 2d130e9 commit af85350
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 39 deletions.
58 changes: 34 additions & 24 deletions src/overcooked_ai_py/mdp/layout_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,47 +128,49 @@ def generate_padded_mdp(self, outside_information={}):
Return a PADDED MDP with mdp params specified in self.mdp_params
"""
mdp_gen_params = self.mdp_params_generator.generate(outside_information)

outer_shape = self.outer_shape
if "layout_name" in mdp_gen_params.keys() and mdp_gen_params["layout_name"] is not None:
mdp = OvercookedGridworld.from_layout_name(**mdp_gen_params)
mdp_generator_fn = lambda: self.padded_mdp(mdp)
else:

required_keys = ["inner_shape", "prop_empty", "prop_feats", "display"]
# with generate_all_orders key start_all_orders will be generated inside make_new_layout method
if not mdp_gen_params.get("generate_all_orders"):
required_keys.append("start_all_orders")
missing_keys = [k for k in required_keys if k not in mdp_gen_params.keys()]
if len(missing_keys) != 0:
print("missing keys dict", mdp_gen_params)
assert len(missing_keys) == 0, "These keys were missing from the mdp_params: {}".format(missing_keys)
inner_shape = mdp_gen_params["inner_shape"]
assert inner_shape[0] <= outer_shape[0] and inner_shape[1] <= outer_shape[1], \
"inner_shape cannot fit into the outershap"
layout_generator = LayoutGenerator(self.mdp_params_generator, outer_shape=self.outer_shape)

if "start_all_orders" in mdp_gen_params:
recipe_params = {"start_all_orders": mdp_gen_params["start_all_orders"]}
if "recipe_values" in mdp_gen_params:
recipe_params["recipe_values"] = mdp_gen_params["recipe_values"]
if "recipe_times" in mdp_gen_params:
recipe_params["recipe_times"] = mdp_gen_params["recipe_times"]
else:
recipe_params = LayoutGenerator.add_generated_mdp_params_orders(self.mdp_params)

if "feature_types" not in mdp_gen_params:
mdp_gen_params["feature_types"] = DEFAULT_FEATURE_TYPES

mdp_generator_fn = lambda: layout_generator.make_disjoint_sets_layout(
inner_shape=mdp_gen_params["inner_shape"],
prop_empty=mdp_gen_params["prop_empty"],
prop_features=mdp_gen_params["prop_feats"],
base_param=recipe_params,
feature_types=mdp_gen_params["feature_types"],
display=mdp_gen_params["display"]
)

mdp_generator_fn = lambda: layout_generator.make_new_layout(mdp_gen_params)
return mdp_generator_fn()


@staticmethod
def create_base_params(mdp_gen_params):
assert mdp_gen_params.get("start_all_orders") or mdp_gen_params.get("generate_all_orders")
mdp_gen_params = LayoutGenerator.add_generated_mdp_params_orders(mdp_gen_params)
recipe_params = {"start_all_orders": mdp_gen_params["start_all_orders"]}
if mdp_gen_params.get("start_bonus_orders"):
recipe_params["start_bonus_orders"] = mdp_gen_params["start_bonus_orders"]
if "recipe_values" in mdp_gen_params:
recipe_params["recipe_values"] = mdp_gen_params["recipe_values"]
if "recipe_times" in mdp_gen_params:
recipe_params["recipe_times"] = mdp_gen_params["recipe_times"]
return recipe_params

@staticmethod
def add_generated_mdp_params_orders(mdp_params):
"""
adds generated parameters (i.e. generated orders) to mdp_params
adds generated parameters (i.e. generated orders) to mdp_params,
returns onchanged copy of mdp_params when there is no "generate_all_orders" and "generate_bonus_orders" keys inside mdp_params
"""
mdp_params = copy.deepcopy(mdp_params)
if mdp_params.get("generate_all_orders"):
Expand Down Expand Up @@ -199,10 +201,18 @@ def padded_mdp(self, mdp, display=False):

start_positions = self.get_random_starting_positions(padded_grid)
mdp_grid = self.padded_grid_to_layout_grid(padded_grid, start_positions, display=display)
return OvercookedGridworld.from_grid(mdp_grid)

def make_new_layout(self, mdp_gen_params):
return self.make_disjoint_sets_layout(
inner_shape=mdp_gen_params["inner_shape"],
prop_empty=mdp_gen_params["prop_empty"],
prop_features=mdp_gen_params["prop_feats"],
base_param=LayoutGenerator.create_base_params(mdp_gen_params),
feature_types=mdp_gen_params["feature_types"],
display=mdp_gen_params["display"]
)

mdp_params = LayoutGenerator.add_generated_mdp_params_orders(self.mdp_params)
return OvercookedGridworld.from_grid(mdp_grid, base_layout_params=mdp_params)

def make_disjoint_sets_layout(self, inner_shape, prop_empty, prop_features, base_param, feature_types=DEFAULT_FEATURE_TYPES, display=True):
grid = Grid(inner_shape)
self.dig_space_with_disjoint_sets(grid, prop_empty)
Expand Down
48 changes: 33 additions & 15 deletions testing/overcooked_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,9 +899,15 @@ def test_random_layout_feature_types(self):
for optional_features_combo in optional_features_combinations:
left_out_optional_features = optional_features - optional_features_combo
used_features = list(optional_features_combo | mandatory_features)
mdp_gen_params = {"prop_feats": (1, 1),
"feature_types": used_features}
mdp_fn = LayoutGenerator.mdp_gen_fn_from_dict(**mdp_gen_params)
mdp_gen_params = {"prop_feats": 0.9,
"feature_types": used_features,
"prop_empty": 0.1,
"inner_shape": (6, 5),
"display": False,
"start_all_orders" : [
{ "ingredients" : ["onion", "onion", "onion"]}
]}
mdp_fn = LayoutGenerator.mdp_gen_fn_from_dict(mdp_gen_params, outer_shape=(6, 5))
env = OvercookedEnv(mdp_fn, **DEFAULT_ENV_PARAMS)
for _ in range(10):
env.reset()
Expand All @@ -916,31 +922,43 @@ def test_random_layout_generated_recipes(self):
only_onions_dict_recipes = [r.to_dict() for r in only_onions_recipes]

# checking if recipes are generated from mdp_params
mdp_params = {"generate_all_orders": {"n":2, "ingredients": ["onion"], "min_size":2, "max_size":3}}
mdp_gen_params = {"mdp_params": mdp_params}
mdp_fn = LayoutGenerator.mdp_gen_fn_from_dict(**mdp_gen_params)
mdp_gen_params = {"generate_all_orders": {"n":2, "ingredients": ["onion"], "min_size":2, "max_size":3},
"prop_feats": 0.9,
"prop_empty": 0.1,
"inner_shape": (6, 5),
"display": False}
mdp_fn = LayoutGenerator.mdp_gen_fn_from_dict(mdp_gen_params, outer_shape=(6, 5))
env = OvercookedEnv(mdp_fn, **DEFAULT_ENV_PARAMS)
for _ in range(10):
env.reset()
self.assertCountEqual(env.mdp.start_all_orders, only_onions_dict_recipes)
self.assertTrue(len(env.mdp.start_bonus_orders) == 0)
self.assertEqual(len(env.mdp.start_bonus_orders), 0)

# checking if bonus_orders is subset of all_orders even if not specified
mdp_params = {"generate_all_orders": {"n":2, "ingredients": ["onion"], "min_size":2, "max_size":3},
"generate_bonus_orders": {"n":1, "min_size":2, "max_size":3}}
mdp_gen_params = {"mdp_params": mdp_params}
mdp_fn = LayoutGenerator.mdp_gen_fn_from_dict(**mdp_gen_params)

mdp_gen_params = {"generate_all_orders": {"n":2, "ingredients": ["onion"], "min_size":2, "max_size":3},
"generate_bonus_orders": {"n":1, "min_size":2, "max_size":3},
"prop_feats": 0.9,
"prop_empty": 0.1,
"inner_shape": (6, 5),
"display": False}
mdp_fn = LayoutGenerator.mdp_gen_fn_from_dict(mdp_gen_params, outer_shape=(6,5))
env = OvercookedEnv(mdp_fn, **DEFAULT_ENV_PARAMS)
for _ in range(10):
env.reset()
self.assertCountEqual(env.mdp.start_all_orders, only_onions_dict_recipes)
self.assertTrue(len(env.mdp.start_bonus_orders) == 1)
self.assertEqual(len(env.mdp.start_bonus_orders), 1)
self.assertTrue(env.mdp.start_bonus_orders[0] in only_onions_dict_recipes)

# checking if after reset there are new recipes generated
mdp_params = {"generate_all_orders": {"n":3, "min_size":2, "max_size":3}}
mdp_gen_params = {"mdp_params": mdp_params}
mdp_fn = LayoutGenerator.mdp_gen_fn_from_dict(**mdp_gen_params)
mdp_gen_params = {"generate_all_orders": {"n":3, "min_size":2, "max_size":3},
"prop_feats": 0.9,
"prop_empty": 0.1,
"inner_shape": (6, 5),
"display": False,
"feature_types": [POT, DISH_DISPENSER, SERVING_LOC, ONION_DISPENSER, TOMATO_DISPENSER]
}
mdp_fn = LayoutGenerator.mdp_gen_fn_from_dict(mdp_gen_params, outer_shape=(6,5))
env = OvercookedEnv(mdp_fn, **DEFAULT_ENV_PARAMS)
generated_recipes_strings = set()
for _ in range(20):
Expand Down

0 comments on commit af85350

Please sign in to comment.