Merge pull request #10 from YerevaNN/rejection_sampling
Rejection sampling
philippguevorguian authored Mar 11, 2024
2 parents f4724dd + b098ab3 commit 248d4ea
Showing 14 changed files with 1,074 additions and 8 deletions.
16 changes: 16 additions & 0 deletions chemlactica/config/create_finetine_config.py
@@ -0,0 +1,16 @@
import yaml
import os

absolute_path = os.path.dirname(os.path.abspath(__file__))
relative_path = "models_train_config.yaml"
full_path = os.path.join(absolute_path, relative_path)


# read the static fine-tune config
with open(os.path.join(absolute_path, "models_fine-tune_config.yaml"), "r") as f_:
model_fine_tune_configs = yaml.full_load(f_)

model_fine_tune_configs["125m"]["max_learning_rate"] = 1e-5
model_fine_tune_configs["125m"]["adam_beta1"] = 0.9
model_fine_tune_configs["125m"]["adam_beta2"] = 0.95
model_fine_tune_configs["125m"]["warmup_steps"] = 0
2 changes: 1 addition & 1 deletion chemlactica/config/create_train_config.py
@@ -42,4 +42,4 @@

model_train_configs["1.3b"]["warmup_steps"] = 500
model_train_configs["1.3b"]["max_learning_rate"] = 0.0014
model_train_configs["1.3b"]["global_gradient_norm"] = 1.0
model_train_configs["1.3b"]["global_gradient_norm"] = 1.0
118 changes: 118 additions & 0 deletions chemlactica/config/models_fine-tune_config.yaml
@@ -0,0 +1,118 @@
1.3b:
adam_beta1: 0.9
adam_beta2: 0.95
batch_size: 1
block_size: 2048
d_heads: 64
d_model: 2048
dropout_prob: 0.1
global_gradient_norm: 1.0
learning_rate_decay: 0.1
max_learning_rate: 2.0e-04
n_heads: 32
n_layers: 24
vocab_size: 50000
warmup_steps: 500
weight_decay: 0.1
125m:
adam_beta1: 0.9
adam_beta2: 0.95
batch_size: 500000
block_size: 2048
d_heads: 64
d_model: 768
dropout_prob: 0.1
eval_step: 256
global_gradient_norm: 1.0
learning_rate_decay: 0.1
max_learning_rate: 6.0e-4
n_heads: 12
n_layers: 12
vocab_size: 50000
warmup_steps: 500
weight_decay: 0.1
6.7b:
adam_beta1: 0.9
adam_beta2: 0.95
batch_size: 2000000
block_size: 2048
d_heads: 128
d_model: 4096
dropout_prob: 0.1
global_gradient_norm: 1.0
learning_rate_decay: 0.1
max_learning_rate: 1.2e-04
n_heads: 32
n_layers: 32
vocab_size: 50000
warmup_steps: 500
weight_decay: 0.1
small_opt:
adam_beta1: 0.9
adam_beta2: 0.95
batch_size: 2
block_size: 2048
dropout_prob: 0.1
ffn_dim: 16
global_gradient_norm: 1.0
hidden_size: 16
learning_rate_decay: 0.1
max_learning_rate: 2.0e-04
max_position_embeddings: 2048
num_attention_heads: 1
num_hidden_layers: 1
vocab_size: 50000
warmup_steps: 500
weight_decay: 0.1
word_embed_proj_dim: 16
mistral7b:
vocab_size: 32000
block_size: 2048
hidden_size: 4096
intermediate_size: 14336
num_hidden_layers: 32
num_attention_heads: 32
num_key_value_heads: 8
hidden_act: 'silu'
max_position_embeddings: 131072
initializer_range: 0.02
rms_norm_eps: 1.0e-6
use_cache: True
pad_token_id: null
bos_token_id: 1
eos_token_id: 2
tie_word_embeddings: False
rope_theta: 10000.0
sliding_windows: 1024
global_gradient_norm: 1.0
weight_decay: 0.1
adam_beta1: 0.9
adam_beta2: 0.95
llama2:
"_name_or_path": "meta-llama/Llama-2-7b-hf"
architectures: ["LlamaForCausalLM"]
bos_token_id: 1
eos_token_id: 2
hidden_act: "silu"
hidden_size: 4096
initializer_range: 0.02
intermediate_size: 11008
max_position_embeddings: 4096
model_type: "llama"
num_attention_heads: 32
num_hidden_layers: 32
num_key_value_heads: 32
pretraining_tp: 1
rms_norm_eps: 1e-05
rope_scaling: null
tie_word_embeddings: false
torch_dtype: "float16"
transformers_version: "4.31.0.dev0"
use_cache: true
vocab_size: 32000
dim: 4096
block_size: 2048
adam_beta1: 0.9
adam_beta2: 0.95
weight_decay: 0.1
warmup_steps: 2000
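As an aside, a hedged sketch of how the mistral7b entry above could be loaded and mapped onto transformers.MistralConfig; restricting to the keys MistralConfig accepts is an assumption here, since the PR's own consumption of this YAML is not shown in this diff.

import yaml
from transformers import MistralConfig

# Path of the file added by this PR.
with open("chemlactica/config/models_fine-tune_config.yaml", "r") as f:
    model_configs = yaml.full_load(f)

mistral_entry = model_configs["mistral7b"]
# Keep only keys MistralConfig understands; training-only fields such as
# weight_decay or global_gradient_norm are left for the optimizer setup.
accepted_keys = {
    "vocab_size", "hidden_size", "intermediate_size", "num_hidden_layers",
    "num_attention_heads", "num_key_value_heads", "hidden_act",
    "max_position_embeddings", "initializer_range", "rms_norm_eps",
    "use_cache", "bos_token_id", "eos_token_id", "tie_word_embeddings",
    "rope_theta",
}
hf_config = MistralConfig(**{k: v for k, v in mistral_entry.items() if k in accepted_keys})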
2 changes: 1 addition & 1 deletion chemlactica/config/test_configs/ddp_config.yaml
@@ -4,7 +4,7 @@ downcast_bf16: 'no'
gpu_ids: 0,1
machine_rank: 0
main_training_function: main
mixed_precision: 'no' # bf16
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
23 changes: 23 additions & 0 deletions chemlactica/custom_trainer.py
@@ -2,14 +2,21 @@
import submitit
from typing import Any, Dict
import os
import torch
from torch._tensor import Tensor
from torch.nn.modules import Module
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullyShardedDataParallel as FSDP,
)
from transformers import Trainer, TrainingArguments
from chemlactica.utils.utils import get_tokenizer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import is_torch_tpu_available
from trl import IterativeSFTTrainer
from chemlactica.utils.utils import get_tokenizer
from dataclasses import dataclass, field


if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm

@@ -127,3 +134,19 @@ def _maybe_log_save_evaluate(
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
self.lr_scheduler.step(metrics[metric_to_check])


class CustomIterativeSFTTrainer(IterativeSFTTrainer):

def __init__(self, *args, **kwargs):
# the number of samples to print when the training begins, for debugging purposes
self.num_samples_to_print = 5
super().__init__(*args, **kwargs)

def training_step(self, model: Module, inputs: Dict[str, Tensor | Any]) -> Tensor:
if self.num_samples_to_print:
tokenizer = get_tokenizer()
for i in range(min(inputs["input_ids"].size(0), self.num_samples_to_print)):
print(f"Sample {i + 1}:", tokenizer.decode(inputs["input_ids"][i]))
self.num_samples_to_print = None
return super().training_step(model, inputs)
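A hedged usage sketch of CustomIterativeSFTTrainer follows; it assumes trl's IterativeSFTTrainer exposes a step() method that accepts a list of texts, and the checkpoint path, output directory, and sample text are illustrative placeholders rather than code from this PR.

from transformers import AutoModelForCausalLM, TrainingArguments
from chemlactica.custom_trainer import CustomIterativeSFTTrainer  # assumed import path
from chemlactica.utils.utils import get_tokenizer

model = AutoModelForCausalLM.from_pretrained("/path/to/checkpoint")  # assumed checkpoint
training_args = TrainingArguments(output_dir="rejection_sampling_runs")  # assumed output dir

trainer = CustomIterativeSFTTrainer(
    model=model,
    tokenizer=get_tokenizer(),
    args=training_args,
)

# One iterative SFT update per rejection-sampling round; the accepted
# completions would come from the generation and filtering code elsewhere.
accepted_texts = ["[START_SMILES]CCO[END_SMILES]"]  # illustrative placeholder sample
trainer.step(texts=accepted_texts)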
120 changes: 120 additions & 0 deletions chemlactica/generation/generation.py
@@ -0,0 +1,120 @@
from argparse import ArgumentParser
from typing import List
import torch

from model_utils import load_model
from utils import get_tokenizer


def generate(prompts: List[str], model, **gen_kwargs):
if isinstance(prompts, str):
prompts = [prompts]
tokenizer = get_tokenizer()

generation_dict = {}
for prompt in prompts:
data = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)
outputs = model.generate(
input_ids=data.input_ids,
**gen_kwargs
)
if not generation_dict.get(prompt):
generation_dict[prompt] = []
for out in outputs:
# decode only the newly generated tokens, skipping the prompt tokens
generation_dict[prompt].append(tokenizer.decode(out[data.input_ids.shape[1]:]))
return generation_dict


if __name__ == "__main__":
parser = ArgumentParser()

parser.add_argument(
"--prompts",
type=str,
nargs="*",
required=True,
)
parser.add_argument(
"--checkpoint_path",
type=str,
required=True,
)
parser.add_argument(
"--flash_attn",
action="store_true",
dest="use_flash_attn",
help="whether or not to use flash attn)",
)
parser.add_argument(
"--no_flash_attn",
action="store_false",
dest="use_flash_attn",
help="whether or not to use flash attn",
)
parser.set_defaults(use_flash_attn=False)
parser.add_argument(
"--device",
type=str,
required=True
)
parser.add_argument(
"--max_new_tokens",
type=int,
required=False,
)
parser.add_argument(
"--do_sample",
action="store_true",
dest="do_sample",
)
parser.add_argument(
"--no_do_sample",
action="store_false",
dest="do_sample",
)
parser.set_defaults(do_sample=False)
parser.add_argument(
"--num_return_sequences",
type=int,
required=False,
default=None
)
parser.add_argument(
"--diversity_penalty",
type=float,
required=False,
default=None
)
parser.add_argument(
"--repetition_penalty",
type=float,
required=False,
default=None
)
parser.add_argument(
"--length_penalty",
type=float,
required=False,
default=None
)
parser.add_argument(
"--num_beams",
type=int,
required=False,
default=None,
)
parser.add_argument(
"--num_beams_groups",
type=int,
required=False,
default=None,
)

args = parser.parse_args()
args = {key: value for key, value in args.__dict__.items() if value is not None}

# generate(
# args.pop("prompts"), args.pop("checkpoint_path"),
# args.pop("use_flash_attn"), device=args.pop("device"),
# **args
# )
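A hedged usage example for generate() above: because load_model's signature is not part of this diff, the model is loaded through transformers directly, and the checkpoint path and prompts are illustrative only.

from transformers import AutoModelForCausalLM

from chemlactica.generation.generation import generate  # assumed import path

model = AutoModelForCausalLM.from_pretrained("/path/to/checkpoint").to("cuda:0")  # assumed
completions = generate(
    ["[START_SMILES]", "[QED 0.8][START_SMILES]"],  # illustrative prompts
    model,
    max_new_tokens=50,
    do_sample=True,
)
for prompt, generations in completions.items():
    print(prompt, generations)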
15 changes: 15 additions & 0 deletions chemlactica/generation/rejection_sampling_configs.py
@@ -0,0 +1,15 @@
sample_gen_args = {
"max_new_tokens": 50,
"temperature": 1.0,
"repetition_penalty": 1.0,
"do_sample": True,
"eos_token_id": 2
}
rej_sample_args = {
"max_new_tokens": 300,
"temperature": 1.0,
"repetition_penalty": 1.0,
"do_sample": True,
"num_return_sequences": 20,
"eos_token_id": 20,
}
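To show where these dictionaries could plug in, here is a minimal sketch of one rejection-sampling round; the scoring oracle and the top-k acceptance rule are assumptions for illustration, not code taken from this PR.

from chemlactica.generation.generation import generate  # assumed import path


def rejection_sample(prompt, model, score_fn, keep_top_k=4):
    # Draw many candidates with the sampling settings above, score each one
    # with a user-supplied oracle, and keep the highest-scoring completions
    # for the next iterative fine-tuning step.
    completions = generate(prompt, model, **rej_sample_args)[prompt]
    ranked = sorted(completions, key=score_fn, reverse=True)
    return ranked[:keep_top_k]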