Commit
Merge pull request #17 from YerevaNN/integrate_gemma
Integrate gemma, fix submitit issues
MenuaB authored Mar 26, 2024
2 parents 4bc6368 + e2d9a4a commit caa3fde
Showing 20 changed files with 838,005 additions and 137 deletions.
@@ -14,3 +14,4 @@ train_config:
model_config:
  block_size: 2048
  vocab_size: 50000
  separator_token: </s>
@@ -14,3 +14,4 @@ train_config:
model_config:
  block_size: 2048
  vocab_size: 50000
  separator_token: </s>
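The separator_token added in these configs is presumably the token used to delimit documents when they are packed into block_size-length training sequences. A minimal sketch of that packing idea, using a public tokenizer for illustration rather than this repository's dataset code:

from transformers import AutoTokenizer

# Illustrative only: the ChemLactica configs use "</s>", the Gemma config uses "<bos>".
tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-125m")
separator_token = "</s>"
block_size = 2048

documents = ["first training document", "second training document"]
packed_text = separator_token.join(documents) + separator_token

input_ids = tokenizer(packed_text, truncation=True, max_length=block_size)["input_ids"]
print(len(input_ids))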
18 changes: 18 additions & 0 deletions chemlactica/config/config_yamls/gemma_2b_pretrain_config.yaml
@@ -0,0 +1,18 @@
train_config:
  adam_beta1: 0.9
  adam_beta2: 0.95
  batch_size: 500000
  dropout_prob: 0.1
  eval_step: 256
  global_gradient_norm: 1.0
  learning_rate_decay: 0.1
  max_learning_rate: 6.0e-4
  n_heads: 12
  n_layers: 18
  warmup_steps: 500
  weight_decay: 0.1
model_config:
  block_size: 8192
  vocab_size: 256000
  separator_token: <bos>
  tokenizer_path: "chemlactica/tokenizer/GemmaTokenizer"
3 changes: 2 additions & 1 deletion chemlactica/config/default_train_config.py
@@ -5,6 +5,8 @@
class ModelConfig:
    block_size: int = 2048
    vocab_size: int = 50000
    separator_token: str = "</s>"
    tokenizer_path: str = "chemlactica/tokenizer/ChemLacticaTokenizer66"


@dataclass
@@ -19,7 +21,6 @@ class TrainConfig:
    max_learning_rate: float = 6.0e-4
    warmup_steps: int = 500
    weight_decay: float = 0.1
    tokenizer_path: str = "chemlactica/tokenizer/ChemLacticaTokenizer66"


@dataclass
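Moving tokenizer_path and separator_token from TrainConfig into ModelConfig ties the tokenizer choice to the model rather than to the training run, which is what lets the Galactica and Gemma setups coexist. A hypothetical illustration — the Gemma overrides mirror gemma_2b_pretrain_config.yaml, and the import assumes the chemlactica package resolves from the repository root:

from chemlactica.config.default_train_config import ModelConfig

# Defaults correspond to the ChemLactica/Galactica setup.
galactica_model = ModelConfig()

# Gemma overrides, mirroring gemma_2b_pretrain_config.yaml.
gemma_model = ModelConfig(
    block_size=8192,
    vocab_size=256000,
    separator_token="<bos>",
    tokenizer_path="chemlactica/tokenizer/GemmaTokenizer",
)

for cfg in (galactica_model, gemma_model):
    print(cfg.tokenizer_path, cfg.separator_token)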
25 changes: 25 additions & 0 deletions chemlactica/config/galactica_accelerate_config.yaml
@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_offload_params: false
  fsdp_forward_prefetch: true
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: OPTForCausalLM
  fsdp_activation_checkpointing: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 6
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
main_process_port: 30001
25 changes: 25 additions & 0 deletions chemlactica/config/gemma_accelerate_config.yaml
@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: NO_PREFETCH
  fsdp_offload_params: false
  fsdp_forward_prefetch: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: GemmaForCausalLM
  fsdp_activation_checkpointing: false
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 6
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
main_process_port: 30001
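The Gemma accelerate config is a near copy of the Galactica one; only a few FSDP knobs change. A small sketch that surfaces the differences, assuming both YAML paths above and run from the repository root:

import yaml


def load(path):
    with open(path) as f:
        return yaml.safe_load(f)


galactica = load("chemlactica/config/galactica_accelerate_config.yaml")
gemma = load("chemlactica/config/gemma_accelerate_config.yaml")

# Compare the nested fsdp_config sections key by key.
for key in sorted(set(galactica["fsdp_config"]) | set(gemma["fsdp_config"])):
    a, b = galactica["fsdp_config"].get(key), gemma["fsdp_config"].get(key)
    if a != b:
        print(f"{key}: galactica={a!r} gemma={b!r}")

With the files as committed, this should flag fsdp_backward_prefetch, fsdp_forward_prefetch, fsdp_activation_checkpointing, and fsdp_transformer_layer_cls_to_wrap.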
136 changes: 68 additions & 68 deletions chemlactica/custom_trainer.py
@@ -1,7 +1,6 @@
import submitit
from typing import Any, Dict
import os
import torch
from torch._tensor import Tensor
from torch.nn.modules import Module

@@ -10,14 +9,15 @@
# )
from transformers import Trainer, TrainingArguments
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import is_torch_tpu_available

# from transformers.utils import is_torch_tpuc _available
from trl import IterativeSFTTrainer
from chemlactica.utils.utils import get_tokenizer
from dataclasses import dataclass, field


if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm
# if is_torch_tpu_available(check_device=False):
# import torch_xla.core.xla_model as xm


@dataclass
@@ -70,70 +70,70 @@ def _submitit_slurm_eval_launch(self, eval_function):
exit()
job = executor.submit(eval_function) # noqa

    def _maybe_log_save_evaluate(
        self, tr_loss, model, trial, epoch, ignore_keys_for_eval
    ):
        # this method is being overwritten because it currently
        # runs evaluation prior to saving.
        # For offling evaluation this doesn't work.
        if (
            self.control.should_log
            and self.state.global_step > self._globalstep_last_logged
        ):
            if is_torch_tpu_available():
                xm.mark_step()

            logs: Dict[str, float] = {}

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

            # reset tr_loss to zero
            tr_loss -= tr_loss

            logs["loss"] = round(
                tr_loss_scalar
                / (self.state.global_step - self._globalstep_last_logged),
                4,
            )
            logs["learning_rate"] = self._get_learning_rate()

            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()

            self.log(logs)

        metrics = None

        if self.control.should_save:
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(
                self.args, self.state, self.control
            )

        if self.control.should_evaluate:
            if self.args.slurm_eval:
                # BUILD SLURM EVALUATE COMMAND
                eval_command = self._build_slurm_eval_command(self.args.command, trial)
                print("-----------------------------------------------")
                print("starting slurm eval with command:", eval_command)
                eval_function = submitit.helpers.CommandFunction(
                    eval_command, verbose=True, cwd=os.getcwd()
                )
                self._submitit_slurm_eval_launch(eval_function)

            else:
                metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
                self._report_to_hp_search(trial, self.state.global_step, metrics)
                # Run delayed LR scheduler now that metrics are populated
                if isinstance(
                    self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
                ):
                    metric_to_check = self.args.metric_for_best_model
                    if not metric_to_check.startswith("eval_"):
                        metric_to_check = f"eval_{metric_to_check}"
                    self.lr_scheduler.step(metrics[metric_to_check])
    # def _maybe_log_save_evaluate(
    #     self, tr_loss, model, trial, epoch, ignore_keys_for_eval
    # ):
    #     # this method is being overwritten because it currently
    #     # runs evaluation prior to saving.
    #     # For offling evaluation this doesn't work.
    #     if (
    #         self.control.should_log
    #         and self.state.global_step > self._globalstep_last_logged
    #     ):
    #         if is_torch_tpu_available():
    #             xm.mark_step()

    #         logs: Dict[str, float] = {}

    #         # all_gather + mean() to get average loss over all processes
    #         tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

    #         # reset tr_loss to zero
    #         tr_loss -= tr_loss

    #         logs["loss"] = round(
    #             tr_loss_scalar
    #             / (self.state.global_step - self._globalstep_last_logged),
    #             4,
    #         )
    #         logs["learning_rate"] = self._get_learning_rate()

    #         self._total_loss_scalar += tr_loss_scalar
    #         self._globalstep_last_logged = self.state.global_step
    #         self.store_flos()

    #         self.log(logs)

    #     metrics = None

    #     if self.control.should_save:
    #         self._save_checkpoint(model, trial, metrics=metrics)
    #         self.control = self.callback_handler.on_save(
    #             self.args, self.state, self.control
    #         )

    #     if self.control.should_evaluate:
    #         if self.args.slurm_eval:
    #             # BUILD SLURM EVALUATE COMMAND
    #             eval_command = self._build_slurm_eval_command(self.args.command, trial)
    #             print("-----------------------------------------------")
    #             print("starting slurm eval with command:", eval_command)
    #             eval_function = submitit.helpers.CommandFunction(
    #                 eval_command, verbose=True, cwd=os.getcwd()
    #             )
    #             self._submitit_slurm_eval_launch(eval_function)

    #         else:
    #             metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
    #             self._report_to_hp_search(trial, self.state.global_step, metrics)
    #             # Run delayed LR scheduler now that metrics are populated
    #             if isinstance(
    #                 self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
    #             ):
    #                 metric_to_check = self.args.metric_for_best_model
    #                 if not metric_to_check.startswith("eval_"):
    #                     metric_to_check = f"eval_{metric_to_check}"
    #                 self.lr_scheduler.step(metrics[metric_to_check])


class CustomIterativeSFTTrainer(IterativeSFTTrainer):
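For context, the slurm_eval path above hands evaluation to submitit as a separate SLURM job instead of calling self.evaluate in-process. A minimal sketch of that pattern — the executor parameters and the evaluation command are placeholders, not values taken from this repository:

import os

import submitit

# Placeholder command; in the trainer it is built by _build_slurm_eval_command.
eval_command = ["python", "evaluate.py", "--checkpoint", "checkpoint-256"]

executor = submitit.AutoExecutor(folder="slurm_eval_logs")
executor.update_parameters(timeout_min=60, slurm_partition="gpu", gpus_per_node=1)

# CommandFunction runs the command as a subprocess on the allocated node,
# so training does not block while evaluation runs.
eval_function = submitit.helpers.CommandFunction(
    eval_command, verbose=True, cwd=os.getcwd()
)
job = executor.submit(eval_function)
print("submitted offline eval as SLURM job", job.job_id)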
6 changes: 3 additions & 3 deletions chemlactica/get_trainer.py
@@ -1,7 +1,7 @@
from custom_trainer import CustomTrainer
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

from chemlactica.eval_metrics import compute_metrics, preprocess_logits_for_metrics
# from chemlactica.eval_metrics import compute_metrics, preprocess_logits_for_metrics
from utils.dataset_utils import sft_formatting_prompts_func
from utils.utils import get_tokenizer
from config.default_train_config import SFTTrainConfig
@@ -12,13 +12,13 @@ def get_trainer(train_type, model, dataset, training_args, evaluate_only, slurm_
        trainer = CustomTrainer(
            model=model,
            args=training_args,
            compute_metrics=compute_metrics,
            # compute_metrics=compute_metrics,
            train_dataset=dataset["train"] if not evaluate_only else None,
            eval_dataset=dataset["validation"]["validation"]
            if not evaluate_only or slurm_eval
            else None,
            # optimizers=[optimizer, lr_scheduler],
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            # preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

    elif train_type == "sft":
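The two hooks disabled here follow the standard Hugging Face Trainer pattern: preprocess_logits_for_metrics shrinks what is accumulated during evaluation, and compute_metrics turns the gathered predictions into scalar metrics. A generic sketch of that pattern, not the implementation in chemlactica.eval_metrics:

import torch


def preprocess_logits_for_metrics(logits, labels):
    # Keep only predicted token ids so evaluation does not accumulate full logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    return torch.argmax(logits, dim=-1)


def compute_metrics(eval_pred):
    predictions, labels = eval_pred  # numpy arrays gathered by the Trainer
    # For a causal LM, position i predicts token i + 1, so shift before comparing.
    predictions, labels = predictions[:, :-1], labels[:, 1:]
    mask = labels != -100  # ignore padding positions
    accuracy = (predictions[mask] == labels[mask]).mean()
    return {"token_accuracy": float(accuracy)}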
30 changes: 30 additions & 0 deletions chemlactica/tokenizer/GemmaTokenizer/special_tokens_map.json
@@ -0,0 +1,30 @@
{
  "bos_token": {
    "content": "<bos>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "<eos>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<pad>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
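A quick way to sanity-check the committed tokenizer assets — a sketch assuming the repository-relative path added in this commit, run from the repository root:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("chemlactica/tokenizer/GemmaTokenizer")

print(tokenizer.bos_token, tokenizer.eos_token, tokenizer.pad_token, tokenizer.unk_token)
# Expected from special_tokens_map.json above: <bos> <eos> <pad> <unk>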
