Rejection sampling #10

Merged · 30 commits · Mar 11, 2024

Changes from all commits

Commits
cf7c304
add rejection sampling fine tuning
tigranfah Feb 1, 2024
d082b4f
fixes
tigranfah Feb 2, 2024
184ecbd
change lr to 1e-4
tigranfah Feb 3, 2024
617c791
change lr anneal from 1e-4 to 1e-5
tigranfah Feb 4, 2024
b4286c3
add length constraint to inputs, to fit 64 batch size
tigranfah Feb 4, 2024
ddd9c92
initial iteration of rejection sampling for single molecule
tigranfah Feb 7, 2024
cae4c29
polish initial iteration
tigranfah Feb 7, 2024
db23d1e
run for single molecule
tigranfah Feb 7, 2024
7ffde13
run for single molecule (pick the best)
tigranfah Feb 7, 2024
b5dbacd
minor changes
tigranfah Feb 9, 2024
e4effad
initial version of single molecule optim + databank
tigranfah Feb 11, 2024
7a308b6
small fix
tigranfah Feb 11, 2024
1bf33cf
small fix
tigranfah Feb 11, 2024
bf34c80
small fix
tigranfah Feb 12, 2024
a863984
add learning rate to cl args
tigranfah Feb 12, 2024
1b1cac0
remove 'dbank' training
tigranfah Feb 15, 2024
32ec80e
clean up
tigranfah Feb 15, 2024
aee016f
polish
tigranfah Feb 16, 2024
0a284d1
import refac
tigranfah Feb 16, 2024
1dffd04
rej_sampling merge
tigranfah Feb 26, 2024
13d85a3
resolve merge conflicts
tigranfah Feb 26, 2024
f26d89a
remove accelerate parameter
tigranfah Feb 26, 2024
c1f73ee
remove accelerate parameter
tigranfah Feb 26, 2024
8351ec8
merge main into rej_sample branch
tigranfah Feb 28, 2024
09844cc
reduce the num of train steps in CI test
tigranfah Feb 28, 2024
3f98893
tests passed
tigranfah Feb 28, 2024
c518809
resolve issues
tigranfah Mar 8, 2024
93acfa2
add tests
tigranfah Mar 8, 2024
4201dde
Merge branch 'main' into rejection_sampling
philippguevorguian Mar 11, 2024
b098ab3
Update test_status.yaml
philippguevorguian Mar 11, 2024
16 changes: 16 additions & 0 deletions chemlactica/config/create_finetine_config.py
@@ -0,0 +1,16 @@
import yaml
import os

absolute_path = os.path.dirname(os.path.abspath(__file__))
relative_path = "models_train_config.yaml"
full_path = os.path.join(absolute_path, relative_path)


# read the static fine-tune config
with open(os.path.join(absolute_path, "models_fine-tune_config.yaml"), "r") as f_:
model_fine_tune_configs = yaml.full_load(f_)

model_fine_tune_configs["125m"]["max_learning_rate"] = 1e-5
model_fine_tune_configs["125m"]["adam_beta1"] = 0.9
model_fine_tune_configs["125m"]["adam_beta2"] = 0.95
model_fine_tune_configs["125m"]["warmup_steps"] = 0
2 changes: 1 addition & 1 deletion chemlactica/config/create_train_config.py
@@ -42,4 +42,4 @@

model_train_configs["1.3b"]["warmup_steps"] = 500
model_train_configs["1.3b"]["max_learning_rate"] = 0.0014
model_train_configs["1.3b"]["global_gradient_norm"] = 1.0
model_train_configs["1.3b"]["global_gradient_norm"] = 1.0
118 changes: 118 additions & 0 deletions chemlactica/config/models_fine-tune_config.yaml
@@ -0,0 +1,118 @@
1.3b:
adam_beta1: 0.9
adam_beta2: 0.95
batch_size: 1
block_size: 2048
d_heads: 64
d_model: 2048
dropout_prob: 0.1
global_gradient_norm: 1.0
learning_rate_decay: 0.1
max_learning_rate: 2.0e-04
n_heads: 32
n_layers: 24
vocab_size: 50000
warmup_steps: 500
weight_decay: 0.1
125m:
adam_beta1: 0.9
adam_beta2: 0.95
batch_size: 500000
block_size: 2048
d_heads: 64
d_model: 768
dropout_prob: 0.1
eval_step: 256
global_gradient_norm: 1.0
learning_rate_decay: 0.1
max_learning_rate: 6.0e-4
n_heads: 12
n_layers: 12
vocab_size: 50000
warmup_steps: 500
weight_decay: 0.1
6.7b:
adam_beta1: 0.9
adam_beta2: 0.95
batch_size: 2000000
block_size: 2048
d_heads: 128
d_model: 4096
dropout_prob: 0.1
global_gradient_norm: 1.0
learning_rate_decay: 0.1
max_learning_rate: 1.2e-04
n_heads: 32
n_layers: 32
vocab_size: 50000
warmup_steps: 500
weight_decay: 0.1
small_opt:
adam_beta1: 0.9
adam_beta2: 0.95
batch_size: 2
block_size: 2048
dropout_prob: 0.1
ffn_dim: 16
global_gradient_norm: 1.0
hidden_size: 16
learning_rate_decay: 0.1
max_learning_rate: 2.0e-04
max_position_embeddings: 2048
num_attention_heads: 1
num_hidden_layers: 1
vocab_size: 50000
warmup_steps: 500
weight_decay: 0.1
word_embed_proj_dim: 16
mistral7b:
vocab_size: 32000
block_size: 2048
hidden_size: 4096
intermediate_size: 14336
num_hidden_layers: 32
num_attention_heads: 32
num_key_value_heads: 8
hidden_act: 'silu'
max_position_embeddings: 131072
initializer_range: 0.02
rms_norm_eps: 1.0e-6
use_cache: True
  pad_token_id: null
bos_token_id: 1
eos_token_id: 2
tie_word_embeddings: False
rope_theta: 10000.0
sliding_windows: 1024
global_gradient_norm: 1.0
weight_decay: 0.1
adam_beta1: 0.9
adam_beta2: 0.95
llama2:
"_name_or_path": "meta-llama/Llama-2-7b-hf"
architectures: ["LlamaForCausalLM"]
bos_token_id: 1
eos_token_id: 2
hidden_act: "silu"
hidden_size: 4096
initializer_range: 0.02
intermediate_size: 11008
max_position_embeddings: 4096
model_type: "llama"
num_attention_heads: 32
num_hidden_layers: 32
num_key_value_heads: 32
pretraining_tp: 1
rms_norm_eps: 1e-05
rope_scaling: null
tie_word_embeddings: false
torch_dtype: "float16"
transformers_version: "4.31.0.dev0"
use_cache: true
vocab_size: 32000
dim: 4096
block_size: 2048
adam_beta1: 0.9
adam_beta2: 0.95
weight_decay: 0.1
warmup_steps: 2000
2 changes: 1 addition & 1 deletion chemlactica/config/test_configs/ddp_config.yaml
@@ -4,7 +4,7 @@ downcast_bf16: 'no'
gpu_ids: 0,1
machine_rank: 0
main_training_function: main
mixed_precision: 'no' # bf16
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
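For orientation, the bf16 value set above is what accelerate's launcher forwards to the runtime; a rough programmatic equivalent is sketched below. This is purely illustrative and not code from this PR, and it assumes bf16-capable hardware.

```python
# Illustrative only: a direct Accelerator instantiation that mirrors the
# mixed_precision: bf16 setting in the DDP config; not part of the PR.
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="bf16")
print(accelerator.mixed_precision)  # "bf16"
```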
23 changes: 23 additions & 0 deletions chemlactica/custom_trainer.py
@@ -2,14 +2,21 @@
import submitit
from typing import Any, Dict
import os
import torch
from torch._tensor import Tensor
from torch.nn.modules import Module
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullyShardedDataParallel as FSDP,
)
from transformers import Trainer, TrainingArguments
from chemlactica.utils.utils import get_tokenizer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import is_torch_tpu_available
from trl import IterativeSFTTrainer
from chemlactica.utils.utils import get_tokenizer
from dataclasses import dataclass, field


if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm

@@ -127,3 +134,19 @@ def _maybe_log_save_evaluate(
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
self.lr_scheduler.step(metrics[metric_to_check])


class CustomIterativeSFTTrainer(IterativeSFTTrainer):

    def __init__(self, *args, **kwargs):
        # the number of samples to print when training begins, for debugging purposes
        self.num_samples_to_print = 5
        super().__init__(*args, **kwargs)

    def training_step(self, model: Module, inputs: Dict[str, Tensor | Any]) -> Tensor:
        # On the first batch only, decode and print a few input samples, then disable printing.
        if self.num_samples_to_print:
            tokenizer = get_tokenizer()
            for i in range(min(inputs["input_ids"].size(0), self.num_samples_to_print)):
                print(f"Sample {i + 1}:", tokenizer.decode(inputs["input_ids"][i]))
            self.num_samples_to_print = None
        return super().training_step(model, inputs)
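For orientation, a minimal usage sketch of the new trainer. The model name, TrainingArguments values, and the `step(texts=...)` call follow trl's `IterativeSFTTrainer` interface and are assumptions; only the trainer class itself comes from this PR.

```python
# Hypothetical sketch: wiring up CustomIterativeSFTTrainer for an iterative
# fine-tuning step. Model/tokenizer ids and TrainingArguments values are
# illustrative assumptions; only CustomIterativeSFTTrainer is from this PR.
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

model = AutoModelForCausalLM.from_pretrained("facebook/galactica-125m")
tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-125m")

trainer = CustomIterativeSFTTrainer(
    model=model,
    args=TrainingArguments(output_dir="rej_sample_ckpts", per_device_train_batch_size=8),
    tokenizer=tokenizer,
)

# IterativeSFTTrainer accepts raw texts; the overridden training_step prints
# the first few decoded samples of the first batch for debugging.
trainer.step(texts=["[START_SMILES]CC(=O)O[END_SMILES]"])
```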
120 changes: 120 additions & 0 deletions chemlactica/generation/generation.py
@@ -0,0 +1,120 @@
from argparse import ArgumentParser
from typing import List
import torch

from model_utils import load_model
from utils import get_tokenizer


def generate(prompts: List[str], model, **gen_kwargs):
    if isinstance(prompts, str):
        prompts = [prompts]
    tokenizer = get_tokenizer()

    generation_dict = {}
    for prompt in prompts:
        data = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)
        outputs = model.generate(
            input_ids=data.input_ids,
            **gen_kwargs
        )
        if not generation_dict.get(prompt):
            generation_dict[prompt] = []
        for out in outputs:
            # strip the prompt tokens and decode only the newly generated part
            generation_dict[prompt].append(
                tokenizer.decode(out[data.input_ids.shape[1]:])
            )
    return generation_dict


if __name__ == "__main__":
parser = ArgumentParser()

parser.add_argument(
"--prompts",
type=str,
nargs="*",
required=True,
)
parser.add_argument(
"--checkpoint_path",
type=str,
required=True,
)
parser.add_argument(
"--flash_attn",
action="store_true",
dest="use_flash_attn",
help="whether or not to use flash attn)",
)
parser.add_argument(
"--no_flash_attn",
action="store_false",
dest="use_flash_attn",
help="whether or not to use flash attn",
)
parser.set_defaults(use_flash_attn=False)
parser.add_argument(
"--device",
type=str,
required=True
)
parser.add_argument(
"--max_new_tokens",
type=int,
required=False,
)
parser.add_argument(
"--do_sample",
action="store_true",
dest="do_sample",
)
parser.add_argument(
"--no_do_sample",
action="store_false",
dest="do_sample",
)
parser.set_defaults(do_sample=False)
parser.add_argument(
"--num_return_sequences",
type=int,
required=False,
default=None
)
    parser.add_argument(
        "--diversity_penalty",
        type=float,
        required=False,
        default=None
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        required=False,
        default=None
    )
    parser.add_argument(
        "--length_penalty",
        type=float,
        required=False,
        default=None
    )
parser.add_argument(
"--num_beams",
type=int,
required=False,
default=None,
)
parser.add_argument(
"--num_beams_groups",
type=int,
required=False,
default=None,
)

args = parser.parse_args()
    args = {key: value for key, value in args.__dict__.items() if value is not None}

# generate(
# args.pop("prompts"), args.pop("checkpoint_path"),
# args.pop("use_flash_attn"), device=args.pop("device"),
# **args
# )
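For reference, a minimal sketch of driving `generate` outside the (commented-out) CLI path. The import paths, checkpoint path, device, and the `load_model` call signature are assumptions rather than code from this PR; only `generate` itself comes from the diff above.

```python
# Hypothetical driver for generate(); only generate() comes from this diff.
# The checkpoint path and load_model(...) signature are assumptions.
from model_utils import load_model
from generation import generate

model = load_model("/path/to/checkpoint")  # hypothetical checkpoint location
model.to("cuda:0")
model.eval()

outputs = generate(
    "[START_SMILES]",  # a single prompt is also accepted and wrapped in a list
    model,
    max_new_tokens=50,
    do_sample=True,
    num_return_sequences=4,
    eos_token_id=2,
)
for prompt, completions in outputs.items():
    for completion in completions:
        print(prompt, "->", completion)
```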
15 changes: 15 additions & 0 deletions chemlactica/generation/rejection_sampling_configs.py
@@ -0,0 +1,15 @@
sample_gen_args = {
"max_new_tokens": 50,
"temperature": 1.0,
"repetition_penalty": 1.0,
"do_sample": True,
"eos_token_id": 2
}
rej_sample_args = {
"max_new_tokens": 300,
"temperature": 1.0,
"repetition_penalty": 1.0,
"do_sample": True,
"num_return_sequences": 20,
"eos_token_id": 20,
}
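These dictionaries look intended to be splatted into the generation call; a hedged sketch follows. The import paths, checkpoint path, and prompt are assumptions, while the argument dicts themselves come from rejection_sampling_configs.py above.

```python
# Hypothetical sketch: feeding the sampling argument dicts into generate().
# Import paths, checkpoint path, and prompt are assumptions; the dicts are
# from rejection_sampling_configs.py above.
from model_utils import load_model
from generation import generate
from rejection_sampling_configs import rej_sample_args, sample_gen_args

model = load_model("/path/to/checkpoint").to("cuda:0")  # hypothetical checkpoint

# short exploratory samples vs. the larger rejection-sampling candidate pool
exploration = generate("[START_SMILES]", model, **sample_gen_args)
candidates = generate("[START_SMILES]", model, **rej_sample_args)
```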