Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dataset loading, and other minor fixes #21

Open
wants to merge 3 commits into
base: new-api
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions design_bench/datasets/continuous/ant_morphology_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def register_x_shards():

return [DiskResource(
file, is_absolute=False,
download_target=f"{SERVER_URL}/{file}",
download_target=file,
download_method="direct") for file in ANT_MORPHOLOGY_FILES]

@staticmethod
Expand All @@ -213,7 +213,7 @@ def register_y_shards():

return [DiskResource(
file.replace("-x-", "-y-"), is_absolute=False,
download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}",
download_target=file.replace('-x-', '-y-'),
download_method="direct") for file in ANT_MORPHOLOGY_FILES]

def __init__(self, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions design_bench/datasets/continuous/dkitty_morphology_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def register_x_shards():

return [DiskResource(
file, is_absolute=False,
download_target=f"{SERVER_URL}/{file}",
download_target=file,
download_method="direct") for file in DKITTY_MORPHOLOGY_FILES]

@staticmethod
Expand All @@ -213,7 +213,7 @@ def register_y_shards():

return [DiskResource(
file.replace("-x-", "-y-"), is_absolute=False,
download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}",
download_target=file.replace('-x-', '-y-'),
download_method="direct") for file in DKITTY_MORPHOLOGY_FILES]

def __init__(self, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions design_bench/datasets/continuous/hopper_controller_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def register_x_shards():

return [DiskResource(
file, is_absolute=False,
download_target=f"{SERVER_URL}/{file}",
download_target=file,
download_method="direct") for file in HOPPER_CONTROLLER_FILES]

@staticmethod
Expand All @@ -213,7 +213,7 @@ def register_y_shards():

return [DiskResource(
file.replace("-x-", "-y-"), is_absolute=False,
download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}",
download_target=file.replace('-x-', '-y-'),
download_method="direct") for file in HOPPER_CONTROLLER_FILES]

def __init__(self, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions design_bench/datasets/continuous/superconductor_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def register_x_shards():

return [DiskResource(
file, is_absolute=False,
download_target=f"{SERVER_URL}/{file}",
download_target=file,
download_method="direct") for file in SUPERCONDUCTOR_FILES]

@staticmethod
Expand All @@ -217,7 +217,7 @@ def register_y_shards():

return [DiskResource(
file.replace("-x-", "-y-"), is_absolute=False,
download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}",
download_target=file.replace('-x-', '-y-'),
download_method="direct") for file in SUPERCONDUCTOR_FILES]

def __init__(self, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions design_bench/datasets/continuous/toy_continuous_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def register_x_shards():

return [DiskResource(
file, is_absolute=False,
download_target=f"{SERVER_URL}/{file}",
download_target=file,
download_method="direct") for file in TOY_CONTINUOUS_FILES]

@staticmethod
Expand All @@ -226,7 +226,7 @@ def register_y_shards():

return [DiskResource(
file.replace("-x-", "-y-"), is_absolute=False,
download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}",
download_target=file.replace('-x-', '-y-'),
download_method="direct") for file in TOY_CONTINUOUS_FILES]

def __init__(self, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions design_bench/datasets/dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ def __init__(self, x_shards, y_shards, internal_batch_size=32,
self.map_normalize_x()
if is_normalized_y:
self.map_normalize_y()

self.subsample(max_samples=max_samples,
distribution=distribution,
min_percentile=min_percentile,
Expand Down
4 changes: 2 additions & 2 deletions design_bench/datasets/discrete/chembl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ def register_x_shards(assay_chembl_id="CHEMBL1794345",

return [DiskResource(
file, is_absolute=False,
download_target=f"{SERVER_URL}/{file}",
download_target=file,
download_method="direct") for file in CHEMBL_FILES
if f"{standard_type}-{assay_chembl_id}" in file]

Expand Down Expand Up @@ -660,7 +660,7 @@ def register_y_shards(assay_chembl_id="CHEMBL1794345",

return [DiskResource(
file.replace("-x-", "-y-"), is_absolute=False,
download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}",
download_target=file.replace('-x-', '-y-'),
download_method="direct") for file in CHEMBL_FILES
if f"{standard_type}-{assay_chembl_id}" in file]

Expand Down
4 changes: 2 additions & 2 deletions design_bench/datasets/discrete/cifar_nas_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def register_x_shards():

return [DiskResource(
file, is_absolute=False,
download_target=f"{SERVER_URL}/{file}",
download_target=file,
download_method="direct") for file in NAS_FILES]

@staticmethod
Expand All @@ -238,7 +238,7 @@ def register_y_shards():

return [DiskResource(
file.replace("-x-", "-y-"), is_absolute=False,
download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}",
download_target=file.replace('-x-', '-y-'),
download_method="direct") for file in NAS_FILES]

def __init__(self, soft_interpolation=0.6, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions design_bench/datasets/discrete/gfp_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def register_x_shards():

return [DiskResource(
file, is_absolute=False,
download_target=f"{SERVER_URL}/{file}",
download_target=file,
download_method="direct") for file in GFP_FILES]

@staticmethod
Expand All @@ -249,7 +249,7 @@ def register_y_shards():

return [DiskResource(
file.replace("-x-", "-y-"), is_absolute=False,
download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}",
download_target=file.replace('-x-', '-y-'),
download_method="direct") for file in GFP_FILES]

def __init__(self, soft_interpolation=0.6, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions design_bench/datasets/discrete/nas_bench_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def register_x_shards():

return [DiskResource(
file, is_absolute=False,
download_target=f"{SERVER_URL}/{file}",
download_target=file,
download_method="direct") for file in NAS_BENCH_FILES]

@staticmethod
Expand All @@ -263,7 +263,7 @@ def register_y_shards():

return [DiskResource(
file.replace("-x-", "-y-"), is_absolute=False,
download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}",
download_target=file.replace('-x-', '-y-'),
download_method="direct") for file in NAS_BENCH_FILES]

def __init__(self, soft_interpolation=0.6, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions design_bench/datasets/discrete/tf_bind_10_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def register_x_shards(transcription_factor='pho4'):

return [DiskResource(
file, is_absolute=False,
download_target=f"{SERVER_URL}/{file}",
download_target=file,
download_method="direct") for file in TF_BIND_10_FILES
if transcription_factor in file]

Expand Down Expand Up @@ -253,7 +253,7 @@ def register_y_shards(transcription_factor='pho4'):

return [DiskResource(
file.replace("-x-", "-y-"), is_absolute=False,
download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}",
download_target=file.replace('-x-', '-y-'),
download_method="direct") for file in TF_BIND_10_FILES
if transcription_factor in file]

Expand Down
4 changes: 2 additions & 2 deletions design_bench/datasets/discrete/tf_bind_8_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def register_x_shards(transcription_factor='SIX6_REF_R1'):

return [DiskResource(
file, is_absolute=False,
download_target=f"{SERVER_URL}/{file}",
download_target=file,
download_method="direct") for file in TF_BIND_8_FILES
if transcription_factor in file]

Expand Down Expand Up @@ -253,7 +253,7 @@ def register_y_shards(transcription_factor='SIX6_REF_R1'):

return [DiskResource(
file.replace("-x-", "-y-"), is_absolute=False,
download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}",
download_target=file.replace('-x-', '-y-'),
download_method="direct") for file in TF_BIND_8_FILES
if transcription_factor in file]

Expand Down
4 changes: 2 additions & 2 deletions design_bench/datasets/discrete/toy_discrete_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def register_x_shards():

return [DiskResource(
file, is_absolute=False,
download_target=f"{SERVER_URL}/{file}",
download_target=file,
download_method="direct") for file in TOY_DISCRETE_FILES]

@staticmethod
Expand All @@ -251,7 +251,7 @@ def register_y_shards():

return [DiskResource(
file.replace("-x-", "-y-"), is_absolute=False,
download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}",
download_target=file.replace('-x-', '-y-'),
download_method="direct") for file in TOY_DISCRETE_FILES]

def __init__(self, soft_interpolation=0.6, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions design_bench/datasets/discrete/utr_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def register_x_shards():

return [DiskResource(
file, is_absolute=False,
download_target=f"{SERVER_URL}/{file}",
download_target=file,
download_method="direct") for file in UTR_FILES]

@staticmethod
Expand All @@ -238,7 +238,7 @@ def register_y_shards():

return [DiskResource(
file.replace("-x-", "-y-"), is_absolute=False,
download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}",
download_target=file.replace('-x-', '-y-'),
download_method="direct") for file in UTR_FILES]

def __init__(self, soft_interpolation=0.6, **kwargs):
Expand Down
Loading