Skip to content

Commit

Permalink
Merge pull request #30 from YerevaNN/minor_changes_for_run
Browse files Browse the repository at this point in the history
minor changes for run
  • Loading branch information
MenuaB authored May 20, 2024
2 parents 79bbd5d + e671bc8 commit 6e6e81a
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ train_config:
bf16_full_eval: true
fp16: false
tf32: true
evaluation_strategy: "steps"
evaluation_strategy: "no"
save_total_limit: 8
grad_accumulation_scheduler: false
dynamic_grad_accumulation: false
Expand All @@ -27,4 +27,5 @@ model_config:
block_size: 2048
vocab_size: 50000
separator_token: </s>
separator_token_id: 2
tokenizer_path: "./chemlactica/tokenizer/ChemLacticaTokenizer66"
13 changes: 7 additions & 6 deletions chemlactica/utils/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,15 @@ def process_str(str, random_number_generator, model_config):
# it's wierd workaround but works for now
try:
compound = load_jsonl_line(str["text"])
compound = delete_empty_tags(compound)
str["text"] = generate_formatted_string(
compound, random_number_generator, model_config
)
string = str
except Exception as e:
print(e)
return ""
compound = delete_empty_tags(compound)
str["text"] = generate_formatted_string(
compound, random_number_generator, model_config
)
return str
string = ""
return string


def group_texts(examples, model_config, eos_token_id):
Expand Down
63 changes: 29 additions & 34 deletions chemlactica/utils/text_format_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,41 +105,36 @@ def generate_formatted_string(compound_json, rng, model_config):


def format_key_value(key, value, rng):
try:
if key == "CID":
return ""
formatted_string = ""
if key == "related":
if len(value) > 10:
# value = random.sample(value, 5)
value = rng.choice(value, size=10, replace=False, shuffle=False)
for pair in value:
rounded_sim = "{:.2f}".format(float(pair["similarity"]))
formatted_string += f"{SPECIAL_TAGS['similarity']['start']}{pair['SMILES']} {rounded_sim}{SPECIAL_TAGS['similarity']['end']}" # noqa
elif key == "experimental":
for pair in value:
formatted_string += f"[PROPERTY]{pair['PROPERTY_NAME']} {pair['PROPERTY_VALUE']}[/PROPERTY]" # noqa
elif key == "synonyms":
for val in value:
formatted_string += f"{SPECIAL_TAGS['synonym']['start']}{val['name']}{SPECIAL_TAGS['synonym']['end']}" # noqa
else:
try:
if SPECIAL_TAGS[key].get("type") is float:
value = "{:.2f}".format(float(value))
assert len(value.split(".")[-1]) == 2
start = SPECIAL_TAGS[key]["start"]
end = SPECIAL_TAGS[key]["end"]
except Exception as e:
print(e)
print("Failed to parse: ", key, value)
start = value = end = ""
return f"{start}{value}{end}"

return formatted_string
except Exception as e:
print(e)
print("Failed to parse: ", key, value)
if key == "CID":
return ""
formatted_string = ""
if key == "related":
if len(value) > 10:
# value = random.sample(value, 5)
value = rng.choice(value, size=10, replace=False, shuffle=False)
for pair in value:
rounded_sim = "{:.2f}".format(float(pair["similarity"]))
formatted_string += f"{SPECIAL_TAGS['similarity']['start']}{pair['SMILES']} {rounded_sim}{SPECIAL_TAGS['similarity']['end']}" # noqa
elif key == "experimental":
for pair in value:
formatted_string += f"[PROPERTY]{pair['PROPERTY_NAME']} {pair['PROPERTY_VALUE']}[/PROPERTY]" # noqa
elif key == "synonyms":
for val in value:
formatted_string += f"{SPECIAL_TAGS['synonym']['start']}{val['name']}{SPECIAL_TAGS['synonym']['end']}" # noqa
else:
try:
if SPECIAL_TAGS[key].get("type") is float:
value = "{:.2f}".format(float(value))
assert len(value.split(".")[-1]) == 2
start = SPECIAL_TAGS[key]["start"]
end = SPECIAL_TAGS[key]["end"]
except Exception as e:
print(e)
print("Failed to parse: ", key, value)
start = value = end = ""
return f"{start}{value}{end}"

return formatted_string


def main():
Expand Down
10 changes: 5 additions & 5 deletions submit_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
model_size = "125m"
train_type = "pretrain"
train_name = "_".join([model_name, model_size, train_type])
job_name = "gal_relform2"
job_name = "gal_relform3"

slurm_params = {
"slurm_job_name": job_name,
"timeout_min": 60 * 24 * 2,
"nodes": 1,
"tasks_per_node": 1,
"gpus_per_node": num_gpus,
"cpus_per_task": num_gpus * 20,
"cpus_per_task": num_gpus * 10,
"mem_gb": num_gpus * 40.0 + 20.0,
"stderr_to_stdout": True,
}
Expand All @@ -41,10 +41,10 @@
"dir_data_types": "computed",
"training_data_dirs": "/nfs/ap/mnt/sxtn/rdkit_computed_rel+form/train_rdkit_computed_rel+form",
"valid_data_dir": "/nfs/ap/mnt/sxtn/rdkit_computed_rel+form/valid_rdkit_computed_rel+form",
"max_steps": 20000,
"max_steps": 19000,
# "num_train_epochs": 24,
"eval_steps": 2000,
"save_steps": 2000,
"eval_steps": 0,
"save_steps": 6300,
"train_batch_size": 16,
"valid_batch_size": 16,
"dataloader_num_workers": 1,
Expand Down
2 changes: 1 addition & 1 deletion test_status.yaml
Original file line number Diff line number Diff line change
@@ -1 +1 @@
33f0567f336bd041b6b687cb3258a855d948b6b8: PASS
210ba73adff95aacc5f6c215dc6218923d85f767: FAIL

0 comments on commit 6e6e81a

Please sign in to comment.