From 92ece3cb226d89d113366051b7668187c58819b2 Mon Sep 17 00:00:00 2001 From: Christopher Beckham Date: Sun, 14 Apr 2024 17:11:33 -0400 Subject: [PATCH 1/3] replace np.loads with pkl.loads --- design_bench/oracles/approximate_oracle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/design_bench/oracles/approximate_oracle.py b/design_bench/oracles/approximate_oracle.py index 58118be..7952433 100644 --- a/design_bench/oracles/approximate_oracle.py +++ b/design_bench/oracles/approximate_oracle.py @@ -391,7 +391,7 @@ def load_params(self, file): # read the validation rank correlation from the zip file with zip_archive.open('rank_correlation.npy', "r") as file: - rank_correlation = np.loads(file.read()) + rank_correlation = pkl.loads(file.read()) # read the validation parameters from the zip file with zip_archive.open('split_kwargs.pkl', "r") as file: From 4182bed798154c26f30006f8c5e11a063c418fd4 Mon Sep 17 00:00:00 2001 From: Christopher Beckham Date: Fri, 19 Apr 2024 22:28:47 -0400 Subject: [PATCH 2/3] DiskResource is now based on HF datasets --- design_bench/disk_resource.py | 170 +++++++--------------------------- 1 file changed, 36 insertions(+), 134 deletions(-) diff --git a/design_bench/disk_resource.py b/design_bench/disk_resource.py index 9c6340b..57ed0fb 100644 --- a/design_bench/disk_resource.py +++ b/design_bench/disk_resource.py @@ -1,11 +1,12 @@ -import requests -import zipfile import os +from huggingface_hub import hf_hub_download +import zipfile +import warnings -# the public url to objects available for download -SERVER_URL = "https://storage.googleapis.com/design-bench" +import requests +SERVER_URL="" # the global path to a folder that stores all data files DATA_DIR = os.path.join( @@ -13,102 +14,7 @@ os.path.dirname( os.path.dirname(os.path.abspath(__file__)))), 'design_bench_data') - -def get_confirm_token(response): - """Get a confirmation token from the cookies associated with the - google drive file download response - - """ - - for key, value in response.cookies.items(): - if key.startswith('download_warning'): - return value - - -def save_response(response, destination): - """Save the response from google drive at a physical location in the disk - assuming the destination is in a folder that exists - - """ - - with open(destination, "wb") as f: - for chunk in response.iter_content(32768): - if chunk: - f.write(chunk) - - -def google_drive_download(download_target, disk_target): - """Downloads a file from google drive using GET and stores that file - at a specified location on the local disk - - Arguments: - - download_target: str - the file id specified by google which is the 'X' in the url: - https://drive.google.com/file/d/X/view?usp=sharing - disk_target: str - the destination for the file on this device, do not call this - function is the file is already downloaded, as it will be overwritten - - Returns: - - success: bool - a boolean that indicates whether the download was successful is True - or an error was encountered when False (such as a 404 error) - - """ - - # connect to google drive and request the file - session = requests.Session() - response = session.get("https://docs.google.com/uc?export=download", - params={'id': download_target}, stream=True) - valid_response = response.status_code < 400 - if not valid_response: - return valid_response - - # confirm that the download should start - token = get_confirm_token(response) - if token is not None: - response = session.get("https://docs.google.com/uc?export=download", - params={'id': download_target, - 'confirm': token}, stream=True) - valid_response = response.status_code < 400 - if not valid_response: - return valid_response - - # save the content of the file to a local destination - save_response(response, disk_target) - return True - - -def direct_download(download_target, disk_target): - """Downloads a file from a direct url using GET and stores that file - at a specified location on the local disk - - Arguments: - - download_target: str - the direct url where the file is located on a remote server - available for direct download using GET - disk_target: str - the destination for the file on this device, do not call this - function is the file is already downloaded, as it will be overwritten - - Returns: - - success: bool - a boolean that indicates whether the download was successful is True - or an error was encountered when False (such as a 404 error) - - """ - - response = requests.get(download_target, allow_redirects=True) - valid_response = response.status_code < 400 - if valid_response: - with open(disk_target, "wb") as file: - file.write(response.content) - return valid_response - +DATA_DIR_REMOTE = None class DiskResource(object): """A resource manager that downloads files from remote destinations @@ -189,11 +95,17 @@ def __init__(self, disk_target, is_absolute=True, """ + self.repo_id = "beckhamc/design_bench_data" + + print("get: {}".format(download_target)) + self.disk_target = os.path.abspath(disk_target) \ if is_absolute else DiskResource.get_data_path(disk_target) + self.download_target = download_target - self.download_method = download_method - os.makedirs(os.path.dirname(self.disk_target), exist_ok=True) + self.download_method = download_method + + #os.makedirs(os.path.dirname(self.disk_target), exist_ok=True) @property def is_downloaded(self): @@ -204,43 +116,33 @@ def is_downloaded(self): return os.path.exists(self.disk_target) def download(self, unzip=True): - """Download the remote file from either google drive or a direct - remote url and store that file at a certain disk location - - Arguments: - - unzip: bool - a boolean indicator that specifies whether the downloaded file - should be unzipped if the file extension is .zip - - Returns: - - success: bool - a boolean that indicates whether the download was successful is True - or an error was encountered when False (such as a 404 error) - - """ - - # check that a download method for this file exists - if (self.download_target is None - or self.download_method is None): - return False - success = False - # download using a direct method - if self.download_method == "direct": - success = direct_download( - self.download_target, self.disk_target) - - # download using the google drive api - elif self.download_method == "google_drive": - success = google_drive_download( - self.download_target, self.disk_target) + if self.download_target.startswith("/"): + download_target = self.download_target[1:] + else: + download_target = self.download_target + + try: + self.disk_target = hf_hub_download( + repo_id=self.repo_id, + filename=download_target, + local_dir=DATA_DIR, + repo_type="dataset" + ) + success = True + except Exception as err: + warnings.warn( + "Unable to download file from {}: {}. Exception: {}".format( + self.repo_id, self.disk_target_relative, + str(err) + ), + UserWarning + ) # unzip the file if it is zipped if success and unzip and self.disk_target.endswith('.zip'): with zipfile.ZipFile(self.disk_target, 'r') as zip_ref: - zip_ref.extractall(os.path.dirname(self.disk_target)) + zip_ref.extractall(os.path.dirname(self.disk_target)) return success From 34eba51aa680131fb2034ebb2fa3c106ea6fb042 Mon Sep 17 00:00:00 2001 From: Christopher Beckham Date: Sun, 21 Apr 2024 11:41:09 -0400 Subject: [PATCH 3/3] Complete migration to HF-style disk resource, suppress annoying warning in morgan fingerprint --- .../continuous/ant_morphology_dataset.py | 4 ++-- .../continuous/dkitty_morphology_dataset.py | 4 ++-- .../continuous/hopper_controller_dataset.py | 4 ++-- .../continuous/superconductor_dataset.py | 4 ++-- .../continuous/toy_continuous_dataset.py | 4 ++-- design_bench/datasets/dataset_builder.py | 1 + .../datasets/discrete/chembl_dataset.py | 4 ++-- .../datasets/discrete/cifar_nas_dataset.py | 4 ++-- design_bench/datasets/discrete/gfp_dataset.py | 4 ++-- .../datasets/discrete/nas_bench_dataset.py | 4 ++-- .../datasets/discrete/tf_bind_10_dataset.py | 4 ++-- .../datasets/discrete/tf_bind_8_dataset.py | 4 ++-- .../datasets/discrete/toy_discrete_dataset.py | 4 ++-- design_bench/datasets/discrete/utr_dataset.py | 4 ++-- design_bench/disk_resource.py | 23 +++++++++++-------- design_bench/oracles/approximate_oracle.py | 2 +- .../oracles/exact/ant_morphology_oracle.py | 2 +- .../oracles/exact/dkitty_morphology_oracle.py | 2 +- .../morgan_fingerprint_features.py | 6 ++++- 19 files changed, 48 insertions(+), 40 deletions(-) diff --git a/design_bench/datasets/continuous/ant_morphology_dataset.py b/design_bench/datasets/continuous/ant_morphology_dataset.py index 3eb3e7e..85fbf8d 100644 --- a/design_bench/datasets/continuous/ant_morphology_dataset.py +++ b/design_bench/datasets/continuous/ant_morphology_dataset.py @@ -193,7 +193,7 @@ def register_x_shards(): return [DiskResource( file, is_absolute=False, - download_target=f"{SERVER_URL}/{file}", + download_target=file, download_method="direct") for file in ANT_MORPHOLOGY_FILES] @staticmethod @@ -213,7 +213,7 @@ def register_y_shards(): return [DiskResource( file.replace("-x-", "-y-"), is_absolute=False, - download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}", + download_target=file.replace('-x-', '-y-'), download_method="direct") for file in ANT_MORPHOLOGY_FILES] def __init__(self, **kwargs): diff --git a/design_bench/datasets/continuous/dkitty_morphology_dataset.py b/design_bench/datasets/continuous/dkitty_morphology_dataset.py index 99b628b..9e2d28b 100644 --- a/design_bench/datasets/continuous/dkitty_morphology_dataset.py +++ b/design_bench/datasets/continuous/dkitty_morphology_dataset.py @@ -193,7 +193,7 @@ def register_x_shards(): return [DiskResource( file, is_absolute=False, - download_target=f"{SERVER_URL}/{file}", + download_target=file, download_method="direct") for file in DKITTY_MORPHOLOGY_FILES] @staticmethod @@ -213,7 +213,7 @@ def register_y_shards(): return [DiskResource( file.replace("-x-", "-y-"), is_absolute=False, - download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}", + download_target=file.replace('-x-', '-y-'), download_method="direct") for file in DKITTY_MORPHOLOGY_FILES] def __init__(self, **kwargs): diff --git a/design_bench/datasets/continuous/hopper_controller_dataset.py b/design_bench/datasets/continuous/hopper_controller_dataset.py index 2058bc0..acd12f4 100644 --- a/design_bench/datasets/continuous/hopper_controller_dataset.py +++ b/design_bench/datasets/continuous/hopper_controller_dataset.py @@ -193,7 +193,7 @@ def register_x_shards(): return [DiskResource( file, is_absolute=False, - download_target=f"{SERVER_URL}/{file}", + download_target=file, download_method="direct") for file in HOPPER_CONTROLLER_FILES] @staticmethod @@ -213,7 +213,7 @@ def register_y_shards(): return [DiskResource( file.replace("-x-", "-y-"), is_absolute=False, - download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}", + download_target=file.replace('-x-', '-y-'), download_method="direct") for file in HOPPER_CONTROLLER_FILES] def __init__(self, **kwargs): diff --git a/design_bench/datasets/continuous/superconductor_dataset.py b/design_bench/datasets/continuous/superconductor_dataset.py index 1fd5cf3..e66b2cb 100644 --- a/design_bench/datasets/continuous/superconductor_dataset.py +++ b/design_bench/datasets/continuous/superconductor_dataset.py @@ -197,7 +197,7 @@ def register_x_shards(): return [DiskResource( file, is_absolute=False, - download_target=f"{SERVER_URL}/{file}", + download_target=file, download_method="direct") for file in SUPERCONDUCTOR_FILES] @staticmethod @@ -217,7 +217,7 @@ def register_y_shards(): return [DiskResource( file.replace("-x-", "-y-"), is_absolute=False, - download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}", + download_target=file.replace('-x-', '-y-'), download_method="direct") for file in SUPERCONDUCTOR_FILES] def __init__(self, **kwargs): diff --git a/design_bench/datasets/continuous/toy_continuous_dataset.py b/design_bench/datasets/continuous/toy_continuous_dataset.py index 1c9fabb..7059e91 100644 --- a/design_bench/datasets/continuous/toy_continuous_dataset.py +++ b/design_bench/datasets/continuous/toy_continuous_dataset.py @@ -206,7 +206,7 @@ def register_x_shards(): return [DiskResource( file, is_absolute=False, - download_target=f"{SERVER_URL}/{file}", + download_target=file, download_method="direct") for file in TOY_CONTINUOUS_FILES] @staticmethod @@ -226,7 +226,7 @@ def register_y_shards(): return [DiskResource( file.replace("-x-", "-y-"), is_absolute=False, - download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}", + download_target=file.replace('-x-', '-y-'), download_method="direct") for file in TOY_CONTINUOUS_FILES] def __init__(self, **kwargs): diff --git a/design_bench/datasets/dataset_builder.py b/design_bench/datasets/dataset_builder.py index 4d5102e..c8a908c 100644 --- a/design_bench/datasets/dataset_builder.py +++ b/design_bench/datasets/dataset_builder.py @@ -460,6 +460,7 @@ def __init__(self, x_shards, y_shards, internal_batch_size=32, self.map_normalize_x() if is_normalized_y: self.map_normalize_y() + self.subsample(max_samples=max_samples, distribution=distribution, min_percentile=min_percentile, diff --git a/design_bench/datasets/discrete/chembl_dataset.py b/design_bench/datasets/discrete/chembl_dataset.py index 0627fcd..ab7a4f1 100644 --- a/design_bench/datasets/discrete/chembl_dataset.py +++ b/design_bench/datasets/discrete/chembl_dataset.py @@ -627,7 +627,7 @@ def register_x_shards(assay_chembl_id="CHEMBL1794345", return [DiskResource( file, is_absolute=False, - download_target=f"{SERVER_URL}/{file}", + download_target=file, download_method="direct") for file in CHEMBL_FILES if f"{standard_type}-{assay_chembl_id}" in file] @@ -660,7 +660,7 @@ def register_y_shards(assay_chembl_id="CHEMBL1794345", return [DiskResource( file.replace("-x-", "-y-"), is_absolute=False, - download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}", + download_target=file.replace('-x-', '-y-'), download_method="direct") for file in CHEMBL_FILES if f"{standard_type}-{assay_chembl_id}" in file] diff --git a/design_bench/datasets/discrete/cifar_nas_dataset.py b/design_bench/datasets/discrete/cifar_nas_dataset.py index fdeb5ba..1001857 100644 --- a/design_bench/datasets/discrete/cifar_nas_dataset.py +++ b/design_bench/datasets/discrete/cifar_nas_dataset.py @@ -218,7 +218,7 @@ def register_x_shards(): return [DiskResource( file, is_absolute=False, - download_target=f"{SERVER_URL}/{file}", + download_target=file, download_method="direct") for file in NAS_FILES] @staticmethod @@ -238,7 +238,7 @@ def register_y_shards(): return [DiskResource( file.replace("-x-", "-y-"), is_absolute=False, - download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}", + download_target=file.replace('-x-', '-y-'), download_method="direct") for file in NAS_FILES] def __init__(self, soft_interpolation=0.6, **kwargs): diff --git a/design_bench/datasets/discrete/gfp_dataset.py b/design_bench/datasets/discrete/gfp_dataset.py index f670931..4cdedee 100644 --- a/design_bench/datasets/discrete/gfp_dataset.py +++ b/design_bench/datasets/discrete/gfp_dataset.py @@ -229,7 +229,7 @@ def register_x_shards(): return [DiskResource( file, is_absolute=False, - download_target=f"{SERVER_URL}/{file}", + download_target=file, download_method="direct") for file in GFP_FILES] @staticmethod @@ -249,7 +249,7 @@ def register_y_shards(): return [DiskResource( file.replace("-x-", "-y-"), is_absolute=False, - download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}", + download_target=file.replace('-x-', '-y-'), download_method="direct") for file in GFP_FILES] def __init__(self, soft_interpolation=0.6, **kwargs): diff --git a/design_bench/datasets/discrete/nas_bench_dataset.py b/design_bench/datasets/discrete/nas_bench_dataset.py index d7d66c1..7cfca0a 100644 --- a/design_bench/datasets/discrete/nas_bench_dataset.py +++ b/design_bench/datasets/discrete/nas_bench_dataset.py @@ -243,7 +243,7 @@ def register_x_shards(): return [DiskResource( file, is_absolute=False, - download_target=f"{SERVER_URL}/{file}", + download_target=file, download_method="direct") for file in NAS_BENCH_FILES] @staticmethod @@ -263,7 +263,7 @@ def register_y_shards(): return [DiskResource( file.replace("-x-", "-y-"), is_absolute=False, - download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}", + download_target=file.replace('-x-', '-y-'), download_method="direct") for file in NAS_BENCH_FILES] def __init__(self, soft_interpolation=0.6, **kwargs): diff --git a/design_bench/datasets/discrete/tf_bind_10_dataset.py b/design_bench/datasets/discrete/tf_bind_10_dataset.py index aad69d1..49adf11 100644 --- a/design_bench/datasets/discrete/tf_bind_10_dataset.py +++ b/design_bench/datasets/discrete/tf_bind_10_dataset.py @@ -225,7 +225,7 @@ def register_x_shards(transcription_factor='pho4'): return [DiskResource( file, is_absolute=False, - download_target=f"{SERVER_URL}/{file}", + download_target=file, download_method="direct") for file in TF_BIND_10_FILES if transcription_factor in file] @@ -253,7 +253,7 @@ def register_y_shards(transcription_factor='pho4'): return [DiskResource( file.replace("-x-", "-y-"), is_absolute=False, - download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}", + download_target=file.replace('-x-', '-y-'), download_method="direct") for file in TF_BIND_10_FILES if transcription_factor in file] diff --git a/design_bench/datasets/discrete/tf_bind_8_dataset.py b/design_bench/datasets/discrete/tf_bind_8_dataset.py index 6111f03..685d94a 100644 --- a/design_bench/datasets/discrete/tf_bind_8_dataset.py +++ b/design_bench/datasets/discrete/tf_bind_8_dataset.py @@ -225,7 +225,7 @@ def register_x_shards(transcription_factor='SIX6_REF_R1'): return [DiskResource( file, is_absolute=False, - download_target=f"{SERVER_URL}/{file}", + download_target=file, download_method="direct") for file in TF_BIND_8_FILES if transcription_factor in file] @@ -253,7 +253,7 @@ def register_y_shards(transcription_factor='SIX6_REF_R1'): return [DiskResource( file.replace("-x-", "-y-"), is_absolute=False, - download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}", + download_target=file.replace('-x-', '-y-'), download_method="direct") for file in TF_BIND_8_FILES if transcription_factor in file] diff --git a/design_bench/datasets/discrete/toy_discrete_dataset.py b/design_bench/datasets/discrete/toy_discrete_dataset.py index 6bbb672..6cd46f2 100644 --- a/design_bench/datasets/discrete/toy_discrete_dataset.py +++ b/design_bench/datasets/discrete/toy_discrete_dataset.py @@ -231,7 +231,7 @@ def register_x_shards(): return [DiskResource( file, is_absolute=False, - download_target=f"{SERVER_URL}/{file}", + download_target=file, download_method="direct") for file in TOY_DISCRETE_FILES] @staticmethod @@ -251,7 +251,7 @@ def register_y_shards(): return [DiskResource( file.replace("-x-", "-y-"), is_absolute=False, - download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}", + download_target=file.replace('-x-', '-y-'), download_method="direct") for file in TOY_DISCRETE_FILES] def __init__(self, soft_interpolation=0.6, **kwargs): diff --git a/design_bench/datasets/discrete/utr_dataset.py b/design_bench/datasets/discrete/utr_dataset.py index 2bfe59d..5a79e14 100644 --- a/design_bench/datasets/discrete/utr_dataset.py +++ b/design_bench/datasets/discrete/utr_dataset.py @@ -218,7 +218,7 @@ def register_x_shards(): return [DiskResource( file, is_absolute=False, - download_target=f"{SERVER_URL}/{file}", + download_target=file, download_method="direct") for file in UTR_FILES] @staticmethod @@ -238,7 +238,7 @@ def register_y_shards(): return [DiskResource( file.replace("-x-", "-y-"), is_absolute=False, - download_target=f"{SERVER_URL}/{file.replace('-x-', '-y-')}", + download_target=file.replace('-x-', '-y-'), download_method="direct") for file in UTR_FILES] def __init__(self, soft_interpolation=0.6, **kwargs): diff --git a/design_bench/disk_resource.py b/design_bench/disk_resource.py index 57ed0fb..f2b19cd 100644 --- a/design_bench/disk_resource.py +++ b/design_bench/disk_resource.py @@ -4,9 +4,7 @@ import warnings -import requests - -SERVER_URL="" +SERVER_URL = os.environ.get("DB_HF_DATA", "beckhamc/design_bench_data") # the global path to a folder that stores all data files DATA_DIR = os.path.join( @@ -14,7 +12,6 @@ os.path.dirname( os.path.dirname(os.path.abspath(__file__)))), 'design_bench_data') -DATA_DIR_REMOTE = None class DiskResource(object): """A resource manager that downloads files from remote destinations @@ -72,8 +69,12 @@ def get_data_path(file_path): return os.path.join(DATA_DIR, file_path) - def __init__(self, disk_target, is_absolute=True, - download_target=None, download_method=None): + def __init__(self, + disk_target, + is_absolute=True, + download_target=None, + repo_id=None, + download_method=None): """A resource manager that downloads files from remote destinations and loads these files from the disk, used to manage remote datasets for offline model-based optimization problems @@ -95,9 +96,10 @@ def __init__(self, disk_target, is_absolute=True, """ - self.repo_id = "beckhamc/design_bench_data" - - print("get: {}".format(download_target)) + if repo_id is None: + self.repo_id = SERVER_URL + else: + self.repo_id = repo_id self.disk_target = os.path.abspath(disk_target) \ if is_absolute else DiskResource.get_data_path(disk_target) @@ -124,6 +126,7 @@ def download(self, unzip=True): download_target = self.download_target try: + print("repo_id={}, filename={}".format(self.repo_id,download_target)) self.disk_target = hf_hub_download( repo_id=self.repo_id, filename=download_target, @@ -134,7 +137,7 @@ def download(self, unzip=True): except Exception as err: warnings.warn( "Unable to download file from {}: {}. Exception: {}".format( - self.repo_id, self.disk_target_relative, + self.repo_id, download_target, str(err) ), UserWarning diff --git a/design_bench/oracles/approximate_oracle.py b/design_bench/oracles/approximate_oracle.py index 7952433..3fe8114 100644 --- a/design_bench/oracles/approximate_oracle.py +++ b/design_bench/oracles/approximate_oracle.py @@ -332,7 +332,7 @@ def get_disk_resource(self, dataset, is_absolute=is_absolute, download_method=None if disk_target else "direct", download_target=None if disk_target else - f"{SERVER_URL}/{default}") + default) def save_params(self, file, params): """a function that serializes a machine learning model and stores diff --git a/design_bench/oracles/exact/ant_morphology_oracle.py b/design_bench/oracles/exact/ant_morphology_oracle.py index e6a31bf..e6ef6e9 100644 --- a/design_bench/oracles/exact/ant_morphology_oracle.py +++ b/design_bench/oracles/exact/ant_morphology_oracle.py @@ -186,7 +186,7 @@ def __init__(self, dataset: ContinuousDataset, policy = "ant_morphology/ant_oracle.pkl" policy = DiskResource( policy, is_absolute=False, download_method="direct", - download_target=f"{SERVER_URL}/{policy}") + download_target=policy) if not policy.is_downloaded and not policy.download(): raise ValueError("unable to download trained policy for ant") diff --git a/design_bench/oracles/exact/dkitty_morphology_oracle.py b/design_bench/oracles/exact/dkitty_morphology_oracle.py index 317d509..173a576 100644 --- a/design_bench/oracles/exact/dkitty_morphology_oracle.py +++ b/design_bench/oracles/exact/dkitty_morphology_oracle.py @@ -186,7 +186,7 @@ def __init__(self, dataset: ContinuousDataset, policy = "dkitty_morphology/dkitty_oracle.pkl" policy = DiskResource( policy, is_absolute=False, download_method="direct", - download_target=f"{SERVER_URL}/{policy}") + download_target=policy) if not policy.is_downloaded and not policy.download(): raise ValueError("unable to download trained policy for ant") diff --git a/design_bench/oracles/feature_extractors/morgan_fingerprint_features.py b/design_bench/oracles/feature_extractors/morgan_fingerprint_features.py index c82b36c..3a57b48 100644 --- a/design_bench/oracles/feature_extractors/morgan_fingerprint_features.py +++ b/design_bench/oracles/feature_extractors/morgan_fingerprint_features.py @@ -3,10 +3,13 @@ from design_bench.disk_resource import DATA_DIR from design_bench.disk_resource import SERVER_URL from deepchem.feat.smiles_tokenizer import SmilesTokenizer +import transformers import deepchem.feat as feat import os import numpy as np +transformers.logging.set_verbosity_error() + class MorganFingerprintFeatures(FeatureExtractor): """An abstract class for managing transformations applied to model-based @@ -67,7 +70,8 @@ def __init__(self, size=2048, radius=4, dtype=np.int32): vocab_file = DiskResource( os.path.join(DATA_DIR, 'smiles_vocab.txt'), download_method="direct", - download_target=f'{SERVER_URL}/smiles_vocab.txt') + download_target="smiles_vocab.txt" + ) if not vocab_file.is_downloaded: vocab_file.download() self.tokenizer = SmilesTokenizer(