Skip to content

Commit

Permalink
release v.0.2
Browse files Browse the repository at this point in the history
  • Loading branch information
mayalenE committed Dec 7, 2022
1 parent d94ae7e commit 8ff29c5
Show file tree
Hide file tree
Showing 18 changed files with 3,875 additions and 658 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
*.npy
.idea/
*__pycache__/
*.pyc
Expand Down
23 changes: 7 additions & 16 deletions autodiscjax/experiment_pipelines/imgep_evaluation_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
import os

def run_imgep_evaluation(jax_platform_name: str, seed: int, n_perturbations: int, save_folder: str,
intervention_params_library: adx.DictTree, intervention_fn: eqx.Module,
experiment_system_output_library: adx.DictTree, experiment_intervention_params_library: adx.DictTree, intervention_fn: eqx.Module,
perturbation_generator: eqx.Module, perturbation_fn: eqx.Module,
system_rollout: eqx.Module, rollout_statistics_encoder: eqx.Module,
goal_embedding_encoder: eqx.Module,
out_sanity_check=True):
out_sanity_check=True, save_modules=False):

# Set platform device
jax.config.update("jax_platform_name", jax_platform_name)
Expand All @@ -26,8 +25,6 @@ def run_imgep_evaluation(jax_platform_name: str, seed: int, n_perturbations: int
perturbation_generator.out_shape, perturbation_generator.out_dtype, is_leaf=lambda node: isinstance(node, tuple))
history.system_output_library = jtu.tree_map(lambda shape, dtype: jnp.empty(shape=(shape[0], 0, ) + shape[1:], dtype=dtype),
system_rollout.out_shape, system_rollout.out_dtype, is_leaf=lambda node: isinstance(node, tuple))
history.reached_goal_embedding_library = jtu.tree_map(lambda shape, dtype: jnp.empty(shape=(shape[0], 0, ) + shape[1:], dtype=dtype),
goal_embedding_encoder.out_shape, goal_embedding_encoder.out_dtype, is_leaf=lambda node: isinstance(node, tuple))
history.system_rollout_statistics_library = jtu.tree_map(lambda shape, dtype: jnp.empty(shape=(shape[0], 0, ) + shape[1:], dtype=dtype),
rollout_statistics_encoder.out_shape, rollout_statistics_encoder.out_dtype, is_leaf=lambda node: isinstance(node, tuple))

Expand All @@ -36,21 +33,16 @@ def run_imgep_evaluation(jax_platform_name: str, seed: int, n_perturbations: int

# generate perturbation
key, subkey = jrandom.split(key)
perturbations_params = perturbation_generator(subkey)
perturbations_params = perturbation_generator(subkey, experiment_system_output_library)
if out_sanity_check:
perturbation_generator.out_sanity_check(perturbations_params)

# rollout system
key, subkey = jrandom.split(key)
system_outputs = system_rollout(subkey, intervention_fn, intervention_params_library, perturbation_fn, perturbations_params)
system_outputs = system_rollout(subkey, intervention_fn, experiment_intervention_params_library, perturbation_fn, perturbations_params)
if out_sanity_check:
system_rollout.out_sanity_check(system_outputs)

# represent outputs -> goals
key, subkey = jrandom.split(key)
reached_goals_embeddings = goal_embedding_encoder(subkey, system_outputs)
if out_sanity_check:
goal_embedding_encoder.out_sanity_check(reached_goals_embeddings)

# represent outputs -> other statistics
key, subkey = jrandom.split(key)
Expand All @@ -62,15 +54,14 @@ def run_imgep_evaluation(jax_platform_name: str, seed: int, n_perturbations: int
# Append to history
perturbations_params = jtu.tree_map(lambda val: val[:, jnp.newaxis], perturbations_params)
system_outputs = jtu.tree_map(lambda val: val[:, jnp.newaxis], system_outputs)
reached_goals_embeddings = jtu.tree_map(lambda val: val[:, jnp.newaxis], reached_goals_embeddings)
system_rollouts_statistics = jtu.tree_map(lambda val: val[:, jnp.newaxis], system_rollouts_statistics)
history = history.update_node("perturbation_params_library", perturbations_params, merge_concatenate, axis=1)
history = history.update_node("system_output_library", system_outputs, merge_concatenate, axis=1)
history = history.update_node("reached_goal_embedding_library", reached_goals_embeddings, merge_concatenate, axis=1)
history = history.update_node("system_rollout_statistics_library", system_rollouts_statistics, merge_concatenate, axis=1)


# Save history and modules
history.save(os.path.join(save_folder, "evaluation_history.pickle"), overwrite=True)
eqx.tree_serialise_leaves(os.path.join(save_folder, "perturbation_generator.eqx"), perturbation_generator)
eqx.tree_serialise_leaves(os.path.join(save_folder, "perturbation_fn.eqx"), perturbation_fn)
if save_modules:
eqx.tree_serialise_leaves(os.path.join(save_folder, "perturbation_generator.eqx"), perturbation_generator)
eqx.tree_serialise_leaves(os.path.join(save_folder, "perturbation_fn.eqx"), perturbation_fn)
35 changes: 22 additions & 13 deletions autodiscjax/experiment_pipelines/imgep_experiment_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def run_imgep_experiment(jax_platform_name: str, seed: int, n_random_batches: in
gc_intervention_optimizer: eqx.Module,
goal_embedding_encoder: eqx.Module,
goal_achievement_loss: eqx.Module,
out_sanity_check=True):
out_sanity_check=True, save_modules=False):

# Set platform device
jax.config.update("jax_platform_name", jax_platform_name)
Expand Down Expand Up @@ -56,6 +56,7 @@ def run_imgep_experiment(jax_platform_name: str, seed: int, n_random_batches: in
for iteration_idx in range(n_random_batches + n_imgep_batches):

if iteration_idx < n_random_batches:
print("Generate random intervention")
# generate random intervention
key, subkey = jrandom.split(key)
interventions_params = random_intervention_generator(subkey)
Expand All @@ -70,43 +71,50 @@ def run_imgep_experiment(jax_platform_name: str, seed: int, n_random_batches: in

else:
# sample goal
print("Generate target goals")
key, subkey = jrandom.split(key)
target_goals_embeddings = goal_generator(subkey, history.reached_goal_embedding_library, history.system_rollout_statistics_library)
target_goals_embeddings = goal_generator(subkey, history.target_goal_embedding_library, history.reached_goal_embedding_library, history.system_rollout_statistics_library)
if out_sanity_check:
goal_generator.out_sanity_check(target_goals_embeddings)

# goal-conditioned selection of source intervention from history
print("Select closes intervention")
key, subkey = jrandom.split(key)
source_interventions_ids = gc_intervention_selector(subkey, target_goals_embeddings, history.reached_goal_embedding_library, history.system_rollout_statistics_library)
if out_sanity_check:
gc_intervention_selector.out_sanity_check(source_interventions_ids)
interventions_params = jtu.tree_map(lambda x: x[source_interventions_ids], history.intervention_params_library)

# goal-conditioned optimization of source intervention
print("Optimize the selected intervention")
key, subkey = jrandom.split(key)
interventions_params = gc_intervention_optimizer(subkey, intervention_fn, interventions_params, system_rollout, goal_embedding_encoder, goal_achievement_loss, target_goals_embeddings)
if out_sanity_check:
gc_intervention_optimizer.out_sanity_check(interventions_params)

# generate perturbation
print("Generate the perturbation")
key, subkey = jrandom.split(key)
perturbations_params = perturbation_generator(subkey)
if out_sanity_check:
perturbation_generator.out_sanity_check(perturbations_params)

# rollout system
print("Rollout the system")
key, subkey = jrandom.split(key)
system_outputs = system_rollout(subkey, intervention_fn, interventions_params, perturbation_fn, perturbations_params)
if out_sanity_check:
system_rollout.out_sanity_check(system_outputs)

# represent outputs -> goals
print("Encode the goal")
key, subkey = jrandom.split(key)
reached_goals_embeddings = goal_embedding_encoder(subkey, system_outputs)
if out_sanity_check:
goal_embedding_encoder.out_sanity_check(reached_goals_embeddings)

# represent outputs -> other statistics
print("Encode the rollout statistics")
key, subkey = jrandom.split(key)
system_rollouts_statistics = rollout_statistics_encoder(subkey, system_outputs)
if out_sanity_check:
Expand All @@ -125,14 +133,15 @@ def run_imgep_experiment(jax_platform_name: str, seed: int, n_random_batches: in

# Save history and modules
history.save(os.path.join(save_folder, "experiment_history.pickle"), overwrite=True)
eqx.tree_serialise_leaves(os.path.join(save_folder, "random_intervention_generator.eqx"), random_intervention_generator)
eqx.tree_serialise_leaves(os.path.join(save_folder, "intervention_fn.eqx"), intervention_fn)
eqx.tree_serialise_leaves(os.path.join(save_folder, "perturbation_generator.eqx"), perturbation_generator)
eqx.tree_serialise_leaves(os.path.join(save_folder, "perturbation_fn.eqx"), perturbation_fn)
eqx.tree_serialise_leaves(os.path.join(save_folder, "system_rollout.eqx"), system_rollout)
eqx.tree_serialise_leaves(os.path.join(save_folder, "rollout_statistics_encoder.eqx"), rollout_statistics_encoder)
eqx.tree_serialise_leaves(os.path.join(save_folder, "goal_generator.eqx"), goal_generator)
eqx.tree_serialise_leaves(os.path.join(save_folder, "gc_intervention_selector.eqx"), gc_intervention_selector)
eqx.tree_serialise_leaves(os.path.join(save_folder, "gc_intervention_optimizer.eqx"), gc_intervention_optimizer)
eqx.tree_serialise_leaves(os.path.join(save_folder, "goal_embedding_encoder.eqx"), goal_embedding_encoder)
eqx.tree_serialise_leaves(os.path.join(save_folder, "goal_achievement_loss.eqx"), goal_achievement_loss)
if save_modules:
eqx.tree_serialise_leaves(os.path.join(save_folder, "random_intervention_generator.eqx"), random_intervention_generator)
eqx.tree_serialise_leaves(os.path.join(save_folder, "intervention_fn.eqx"), intervention_fn)
eqx.tree_serialise_leaves(os.path.join(save_folder, "perturbation_generator.eqx"), perturbation_generator)
eqx.tree_serialise_leaves(os.path.join(save_folder, "perturbation_fn.eqx"), perturbation_fn)
eqx.tree_serialise_leaves(os.path.join(save_folder, "system_rollout.eqx"), system_rollout)
eqx.tree_serialise_leaves(os.path.join(save_folder, "rollout_statistics_encoder.eqx"), rollout_statistics_encoder)
eqx.tree_serialise_leaves(os.path.join(save_folder, "goal_generator.eqx"), goal_generator)
eqx.tree_serialise_leaves(os.path.join(save_folder, "gc_intervention_selector.eqx"), gc_intervention_selector)
eqx.tree_serialise_leaves(os.path.join(save_folder, "gc_intervention_optimizer.eqx"), gc_intervention_optimizer)
eqx.tree_serialise_leaves(os.path.join(save_folder, "goal_embedding_encoder.eqx"), goal_embedding_encoder)
eqx.tree_serialise_leaves(os.path.join(save_folder, "goal_achievement_loss.eqx"), goal_achievement_loss)
3 changes: 3 additions & 0 deletions autodiscjax/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import autodiscjax.modules.optimizers
import autodiscjax.modules.imgepwrappers
import autodiscjax.modules.grnwrappers
Loading

0 comments on commit 8ff29c5

Please sign in to comment.