From e671bc85b7b66b2beaefe5affe973dd9db4fc65d Mon Sep 17 00:00:00 2001 From: MenuaB Date: Mon, 20 May 2024 16:35:06 +0400 Subject: [PATCH] minor changes for run --- .../galactica_125m_pretrain_config.yaml | 3 +- chemlactica/utils/dataset_utils.py | 13 ++-- chemlactica/utils/text_format_utils.py | 63 +++++++++---------- submit_run.py | 10 +-- test_status.yaml | 2 +- 5 files changed, 44 insertions(+), 47 deletions(-) diff --git a/chemlactica/config/config_yamls/galactica_125m_pretrain_config.yaml b/chemlactica/config/config_yamls/galactica_125m_pretrain_config.yaml index f3da4b7..546cdaf 100644 --- a/chemlactica/config/config_yamls/galactica_125m_pretrain_config.yaml +++ b/chemlactica/config/config_yamls/galactica_125m_pretrain_config.yaml @@ -13,7 +13,7 @@ train_config: bf16_full_eval: true fp16: false tf32: true - evaluation_strategy: "steps" + evaluation_strategy: "no" save_total_limit: 8 grad_accumulation_scheduler: false dynamic_grad_accumulation: false @@ -27,4 +27,5 @@ model_config: block_size: 2048 vocab_size: 50000 separator_token: + separator_token_id: 2 tokenizer_path: "./chemlactica/tokenizer/ChemLacticaTokenizer66" diff --git a/chemlactica/utils/dataset_utils.py b/chemlactica/utils/dataset_utils.py index af4e9bb..08e77fd 100644 --- a/chemlactica/utils/dataset_utils.py +++ b/chemlactica/utils/dataset_utils.py @@ -73,14 +73,15 @@ def process_str(str, random_number_generator, model_config): # it's wierd workaround but works for now try: compound = load_jsonl_line(str["text"]) + compound = delete_empty_tags(compound) + str["text"] = generate_formatted_string( + compound, random_number_generator, model_config + ) + string = str except Exception as e: print(e) - return "" - compound = delete_empty_tags(compound) - str["text"] = generate_formatted_string( - compound, random_number_generator, model_config - ) - return str + string = "" + return string def group_texts(examples, model_config, eos_token_id): diff --git a/chemlactica/utils/text_format_utils.py b/chemlactica/utils/text_format_utils.py index 3237d72..507e770 100644 --- a/chemlactica/utils/text_format_utils.py +++ b/chemlactica/utils/text_format_utils.py @@ -105,41 +105,36 @@ def generate_formatted_string(compound_json, rng, model_config): def format_key_value(key, value, rng): - try: - if key == "CID": - return "" - formatted_string = "" - if key == "related": - if len(value) > 10: - # value = random.sample(value, 5) - value = rng.choice(value, size=10, replace=False, shuffle=False) - for pair in value: - rounded_sim = "{:.2f}".format(float(pair["similarity"])) - formatted_string += f"{SPECIAL_TAGS['similarity']['start']}{pair['SMILES']} {rounded_sim}{SPECIAL_TAGS['similarity']['end']}" # noqa - elif key == "experimental": - for pair in value: - formatted_string += f"[PROPERTY]{pair['PROPERTY_NAME']} {pair['PROPERTY_VALUE']}[/PROPERTY]" # noqa - elif key == "synonyms": - for val in value: - formatted_string += f"{SPECIAL_TAGS['synonym']['start']}{val['name']}{SPECIAL_TAGS['synonym']['end']}" # noqa - else: - try: - if SPECIAL_TAGS[key].get("type") is float: - value = "{:.2f}".format(float(value)) - assert len(value.split(".")[-1]) == 2 - start = SPECIAL_TAGS[key]["start"] - end = SPECIAL_TAGS[key]["end"] - except Exception as e: - print(e) - print("Failed to parse: ", key, value) - start = value = end = "" - return f"{start}{value}{end}" - - return formatted_string - except Exception as e: - print(e) - print("Failed to parse: ", key, value) + if key == "CID": return "" + formatted_string = "" + if key == "related": + if len(value) > 10: + # value = random.sample(value, 5) + value = rng.choice(value, size=10, replace=False, shuffle=False) + for pair in value: + rounded_sim = "{:.2f}".format(float(pair["similarity"])) + formatted_string += f"{SPECIAL_TAGS['similarity']['start']}{pair['SMILES']} {rounded_sim}{SPECIAL_TAGS['similarity']['end']}" # noqa + elif key == "experimental": + for pair in value: + formatted_string += f"[PROPERTY]{pair['PROPERTY_NAME']} {pair['PROPERTY_VALUE']}[/PROPERTY]" # noqa + elif key == "synonyms": + for val in value: + formatted_string += f"{SPECIAL_TAGS['synonym']['start']}{val['name']}{SPECIAL_TAGS['synonym']['end']}" # noqa + else: + try: + if SPECIAL_TAGS[key].get("type") is float: + value = "{:.2f}".format(float(value)) + assert len(value.split(".")[-1]) == 2 + start = SPECIAL_TAGS[key]["start"] + end = SPECIAL_TAGS[key]["end"] + except Exception as e: + print(e) + print("Failed to parse: ", key, value) + start = value = end = "" + return f"{start}{value}{end}" + + return formatted_string def main(): diff --git a/submit_run.py b/submit_run.py index eee08ab..80070ea 100644 --- a/submit_run.py +++ b/submit_run.py @@ -13,7 +13,7 @@ model_size = "125m" train_type = "pretrain" train_name = "_".join([model_name, model_size, train_type]) -job_name = "gal_relform2" +job_name = "gal_relform3" slurm_params = { "slurm_job_name": job_name, @@ -21,7 +21,7 @@ "nodes": 1, "tasks_per_node": 1, "gpus_per_node": num_gpus, - "cpus_per_task": num_gpus * 20, + "cpus_per_task": num_gpus * 10, "mem_gb": num_gpus * 40.0 + 20.0, "stderr_to_stdout": True, } @@ -41,10 +41,10 @@ "dir_data_types": "computed", "training_data_dirs": "/nfs/ap/mnt/sxtn/rdkit_computed_rel+form/train_rdkit_computed_rel+form", "valid_data_dir": "/nfs/ap/mnt/sxtn/rdkit_computed_rel+form/valid_rdkit_computed_rel+form", - "max_steps": 20000, + "max_steps": 19000, # "num_train_epochs": 24, - "eval_steps": 2000, - "save_steps": 2000, + "eval_steps": 0, + "save_steps": 6300, "train_batch_size": 16, "valid_batch_size": 16, "dataloader_num_workers": 1, diff --git a/test_status.yaml b/test_status.yaml index 8ead52f..53f3ff3 100644 --- a/test_status.yaml +++ b/test_status.yaml @@ -1 +1 @@ -33f0567f336bd041b6b687cb3258a855d948b6b8: PASS +210ba73adff95aacc5f6c215dc6218923d85f767: FAIL