Skip to content

Commit

Permalink
pre merge
Browse files Browse the repository at this point in the history
  • Loading branch information
philippguevorguian committed May 22, 2024
1 parent d95e4d4 commit 1306842
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 125 deletions.
248 changes: 124 additions & 124 deletions chemlactica/mol_opt/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,132 +64,132 @@ def optimize(
oracle, config,
additional_properties={}
):
file = open(config["log_dir"], "w")
print("config", config)
# print("molecule generation arguments", config["generation_config"])
pool = Pool(config["pool_size"], validation_perc=config["validation_perc"])

max_score = 0
tol_level = 0
num_iter = 0
prev_train_iter = 0
while True:
model.eval()
iter_optim_entries: List[OptimEntry] = []
while len(iter_optim_entries) < config["num_gens_per_iter"]:
optim_entries = create_optimization_entries(
config["num_gens_per_iter"], pool,
config=config
)
for i in range(len(optim_entries)):
last_entry = MoleculeEntry(smiles="")
last_entry.similar_mol_entries = create_similar_mol_entries(
pool, last_entry, config["num_similars"]
with open(config["log_dir"], "w") as file:
print("config", config)
# print("molecule generation arguments", config["generation_config"])
pool = Pool(config["pool_size"], validation_perc=config["validation_perc"])

max_score = 0
tol_level = 0
num_iter = 0
prev_train_iter = 0
while True:
model.eval()
iter_optim_entries: List[OptimEntry] = []
while len(iter_optim_entries) < config["num_gens_per_iter"]:
optim_entries = create_optimization_entries(
config["num_gens_per_iter"], pool,
config=config
)
for prop_name, prop_spec in additional_properties.items():
last_entry.add_props[prop_name] = prop_spec
optim_entries[i].last_entry = last_entry

prompts = [
optim_entry.to_prompt(is_generation=True, include_oracle_score=prev_train_iter != 0, config=config)
for optim_entry in optim_entries
]
output_texts = []
for i in range(0, len(prompts), config["generation_batch_size"]):
prompt_batch = prompts[i: min(len(prompts), i + config["generation_batch_size"])]
data = tokenizer(prompt_batch, return_tensors="pt", padding=True).to(model.device)
if type(model) == OPTForCausalLM:
del data["token_type_ids"]
for key, value in data.items():
data[key] = value[:, -2048 + config["generation_config"]["max_new_tokens"]:]
output = model.generate(
**data,
**config["generation_config"]
)
gc.collect()
torch.cuda.empty_cache()
output_texts.extend(tokenizer.batch_decode(output))

current_mol_entries = []
current_optim_entries = []
with multiprocessing.Pool(processes=config["num_processes"]) as pol:
for i, entry in enumerate(pol.map(create_molecule_entry, output_texts)):
if entry and not optim_entries[i].contains_entry(entry):
current_mol_entries.append(entry)
current_optim_entries.append(optim_entries[i])

if getattr(oracle, "takes_entry", False):
oracle_scores = oracle(current_mol_entries)
else:
oracle_scores = oracle([e.smiles for e in current_mol_entries])
for i, oracle_score in enumerate(oracle_scores):
entry = current_mol_entries[i]
entry.score = oracle_score
entry.similar_mol_entries = current_optim_entries[i].last_entry.similar_mol_entries
for prop_name, prop_spec in additional_properties.items():
entry.add_props[prop_name] = prop_spec
entry.add_props[prop_name]["value"] = entry.add_props[prop_name]["calculate_value"](entry)
current_optim_entries[i].last_entry = entry
iter_optim_entries.append(current_optim_entries[i])
file.write(f"generated smiles: {entry.smiles}, score: {entry.score:.4f}\n")
if entry.score > max_score:
max_score = entry.score
tol_level = 0
if oracle.finish or len(iter_optim_entries) >= config["num_gens_per_iter"]:
for i in range(len(optim_entries)):
last_entry = MoleculeEntry(smiles="")
last_entry.similar_mol_entries = create_similar_mol_entries(
pool, last_entry, config["num_similars"]
)
for prop_name, prop_spec in additional_properties.items():
last_entry.add_props[prop_name] = prop_spec
optim_entries[i].last_entry = last_entry

prompts = [
optim_entry.to_prompt(is_generation=True, include_oracle_score=prev_train_iter != 0, config=config)
for optim_entry in optim_entries
]
output_texts = []
for i in range(0, len(prompts), config["generation_batch_size"]):
prompt_batch = prompts[i: min(len(prompts), i + config["generation_batch_size"])]
data = tokenizer(prompt_batch, return_tensors="pt", padding=True).to(model.device)
if type(model) == OPTForCausalLM:
del data["token_type_ids"]
for key, value in data.items():
data[key] = value[:, -2048 + config["generation_config"]["max_new_tokens"]:]
output = model.generate(
**data,
**config["generation_config"]
)
gc.collect()
torch.cuda.empty_cache()
output_texts.extend(tokenizer.batch_decode(output))

current_mol_entries = []
current_optim_entries = []
with multiprocessing.Pool(processes=config["num_processes"]) as pol:
for i, entry in enumerate(pol.map(create_molecule_entry, output_texts)):
if entry and not optim_entries[i].contains_entry(entry):
current_mol_entries.append(entry)
current_optim_entries.append(optim_entries[i])

if getattr(oracle, "takes_entry", False):
oracle_scores = oracle(current_mol_entries)
else:
oracle_scores = oracle([e.smiles for e in current_mol_entries])
for i, oracle_score in enumerate(oracle_scores):
entry = current_mol_entries[i]
entry.score = oracle_score
entry.similar_mol_entries = current_optim_entries[i].last_entry.similar_mol_entries
for prop_name, prop_spec in additional_properties.items():
entry.add_props[prop_name] = prop_spec
entry.add_props[prop_name]["value"] = entry.add_props[prop_name]["calculate_value"](entry)
current_optim_entries[i].last_entry = entry
iter_optim_entries.append(current_optim_entries[i])
file.write(f"generated smiles: {entry.smiles}, score: {entry.score:.4f}\n")
if entry.score > max_score:
max_score = entry.score
tol_level = 0
if oracle.finish or len(iter_optim_entries) >= config["num_gens_per_iter"]:
break

if oracle.finish:
break

if oracle.finish:
break

if oracle.finish:
break
initial_num_iter = num_iter
num_iter = len(oracle.mol_buffer) // config["num_gens_per_iter"]
if num_iter > initial_num_iter:
tol_level += 1
print(f"num_iter: {num_iter}, tol_level: {tol_level}, prev_train_iter: {prev_train_iter}")

# diversity_score = 1 / (1 + math.log(1 + repeated_max_score) / math.log(10))
pool.add(iter_optim_entries)
file.write("Pool\n")
for i, optim_entry in enumerate(pool.optim_entries):
file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n")

if "rej-sample-v2" in config["strategy"]:
# round_entries.extend(current_entries)
# round_entries = list(np.unique(round_entries))[::-1]
# top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"])
# if top_k >= config["rej_sample_config"]["num_samples_per_round"]:
if config["rej_sample_config"]["should_train"](num_iter, tol_level, prev_train_iter):
train_entries, validation_entries = pool.get_train_valid_entries()
print(f"Num of training examples: {len(train_entries)}, num of validation examples: {len(validation_entries)}.")
file.write("Training entries\n")
for i, optim_entry in enumerate(train_entries):
file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n")
file.write("Validation entries\n")
for i, optim_entry in enumerate(validation_entries):
file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n")

train_dataset = Dataset.from_dict({
"sample": [
optim_entry.to_prompt(is_generation=False, include_oracle_score=True, config=config)
for optim_entry in train_entries
]
})
validation_dataset = Dataset.from_dict({
"sample": [
optim_entry.to_prompt(is_generation=False, include_oracle_score=True, config=config)
for optim_entry in validation_entries
]
})
train_dataset.shuffle(seed=42)
validation_dataset.shuffle(seed=42)
config["rej_sample_config"]["formatting_func"] = lambda x: x["sample"]
supervised_fine_tune(
model, tokenizer,
train_dataset, validation_dataset,
config["rej_sample_config"]
)
gc.collect()
torch.cuda.empty_cache()
prev_train_iter = num_iter
initial_num_iter = num_iter
num_iter = len(oracle.mol_buffer) // config["num_gens_per_iter"]
if num_iter > initial_num_iter:
tol_level += 1
print(f"num_iter: {num_iter}, tol_level: {tol_level}, prev_train_iter: {prev_train_iter}")

# diversity_score = 1 / (1 + math.log(1 + repeated_max_score) / math.log(10))
pool.add(iter_optim_entries)
file.write("Pool\n")
for i, optim_entry in enumerate(pool.optim_entries):
file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n")

if "rej-sample-v2" in config["strategy"]:
# round_entries.extend(current_entries)
# round_entries = list(np.unique(round_entries))[::-1]
# top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"])
# if top_k >= config["rej_sample_config"]["num_samples_per_round"]:
if config["rej_sample_config"]["should_train"](num_iter, tol_level, prev_train_iter):
train_entries, validation_entries = pool.get_train_valid_entries()
print(f"Num of training examples: {len(train_entries)}, num of validation examples: {len(validation_entries)}.")
file.write("Training entries\n")
for i, optim_entry in enumerate(train_entries):
file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n")
file.write("Validation entries\n")
for i, optim_entry in enumerate(validation_entries):
file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n")

train_dataset = Dataset.from_dict({
"sample": [
optim_entry.to_prompt(is_generation=False, include_oracle_score=True, config=config)
for optim_entry in train_entries
]
})
validation_dataset = Dataset.from_dict({
"sample": [
optim_entry.to_prompt(is_generation=False, include_oracle_score=True, config=config)
for optim_entry in validation_entries
]
})
train_dataset.shuffle(seed=42)
validation_dataset.shuffle(seed=42)
config["rej_sample_config"]["formatting_func"] = lambda x: x["sample"]
supervised_fine_tune(
model, tokenizer,
train_dataset, validation_dataset,
config["rej_sample_config"]
)
gc.collect()
torch.cuda.empty_cache()
prev_train_iter = num_iter
4 changes: 3 additions & 1 deletion chemlactica/mol_opt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def __str__(self):

def __repr__(self):
    """Return the same text as str(self); repr and str render identically."""
    return str(self)
def __hash__(self):
    # Hash on the SMILES string so entries with the same SMILES collide in
    # sets/dicts. NOTE(review): this presumably pairs with an __eq__ that
    # compares by smiles — confirm, since __eq__ is not visible in this diff.
    return hash(self.smiles)


class Pool:
Expand All @@ -94,7 +96,7 @@ def __init__(self, size, validation_perc: float):
# self.molecule_entries.pop(rand_ind)
# print(f"Dump {num} random elements from pool, num pool mols {len(self)}")

def add(self, entries: List, diversity_score=1.0):
def add(self, entries: List, diversity_score=1.0):
assert type(entries) == list
self.optim_entries.extend(entries)
self.optim_entries.sort(key=lambda x: x.last_entry, reverse=True)
Expand Down

0 comments on commit 1306842

Please sign in to comment.