Merge pull request #8 from YerevaNN/fix_ci_tests
Fix ci tests
philippguevorguian authored Feb 27, 2024
2 parents bd0dfa1 + 2a9f3ee commit b2a0fd8
Showing 4 changed files with 9 additions and 30 deletions.
7 changes: 4 additions & 3 deletions chemlactica/config/test_configs/fsdp_config.yaml
@@ -1,13 +1,15 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: false
fsdp_forward_prefetch: false
fsdp_offload_params: false
fsdp_sharding_strategy: 1
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: OPTForCausalLM
machine_rank: 0
main_training_function: main
mixed_precision: bf16
@@ -19,4 +21,3 @@ tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
main_process_port: 30001
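
The sharding-strategy line is the substantive edit in this file: in PyTorch's FSDP ShardingStrategy enum the numeric value 1 names FULL_SHARD, so the YAML change spells the same strategy by name rather than by number. A minimal check of that equivalence, assuming a PyTorch build with FSDP available:

from torch.distributed.fsdp import ShardingStrategy

# Value 1 and the name FULL_SHARD refer to the same enum member, so the config
# edit is a rename of the setting, not a change in sharding behaviour.
assert ShardingStrategy(1) is ShardingStrategy.FULL_SHARD
print(ShardingStrategy(1).name)  # prints "FULL_SHARD"
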
21 changes: 0 additions & 21 deletions chemlactica/custom_trainer.py
@@ -1,14 +1,10 @@
import shutil
import torch
import submitit
from typing import Any, Dict
import os
from torch._tensor import Tensor
from torch.nn.modules import Module
from transformers import Trainer, TrainingArguments
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullyShardedDataParallel as FSDP,
)
from chemlactica.utils.utils import get_tokenizer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import is_torch_tpu_available
@@ -41,23 +37,6 @@ def training_step(self, model: Module, inputs: Dict[str, Tensor | Any]) -> Tensor:
self.num_samples_to_print = None
return super().training_step(model, inputs)

def _save_checkpoint(self, model, trial, metrics=None):
if shutil.disk_usage("/").free > 3 * 1024**3:
super()._save_checkpoint(model, trial, metrics=None)
else:
print("**disk is full didn't save**")

def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
"""
This code was added because we hit a failure when resuming training:
the model was being loaded with FSDP even though it was not FSDP-wrapped.
Future versions of transformers handle this with an extra check,
but version 4.31.0 does not, so this manual check works around the problem.
"""
if type(self.model) != FSDP:
return
return super()._load_from_checkpoint(resume_from_checkpoint, model)

def _build_slurm_eval_command(self, train_command, trial):
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
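
Read together, the two deleted overrides formed a small safety layer around checkpointing. A consolidated sketch of what was removed is shown below; the class name GuardedTrainer is illustrative only, and the method bodies mirror the deleted lines (with an isinstance check in place of the type comparison):

import shutil

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import Trainer


class GuardedTrainer(Trainer):
    def _save_checkpoint(self, model, trial, metrics=None):
        # Only write a checkpoint when at least 3 GiB of disk space is free.
        if shutil.disk_usage("/").free > 3 * 1024**3:
            super()._save_checkpoint(model, trial, metrics=None)
        else:
            print("**disk is full, didn't save**")

    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
        # transformers 4.31.0 attempted an FSDP-style load even when the model
        # was not yet FSDP-wrapped; newer releases perform this check themselves.
        if not isinstance(self.model, FSDP):
            return
        return super()._load_from_checkpoint(resume_from_checkpoint, model)
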
9 changes: 4 additions & 5 deletions chemlactica/utils/model_utils.py
@@ -90,11 +90,10 @@ def load_model(
model = OPTForCausalLM.from_pretrained(
from_pretrained, torch_dtype=dtype, attn_implementation=attn_implementation
)
print(type(model.lm_head))
model.lm_head = float_casting_decorator(model.lm_head.__class__)(
in_features=model.lm_head.in_features,
out_features=model.lm_head.out_features,
)
# model.lm_head = float_casting_decorator(model.lm_head.__class__)(
# in_features=model.lm_head.in_features,
# out_features=model.lm_head.out_features,
# )
# model.lm_head.forward = cast_to_fp32(OPTForCausalLM.lm_head.forward)

if "mistral" in from_pretrained.lower():
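
The disabled block wrapped lm_head so that its logits come out in full precision under bf16 training. chemlactica's float_casting_decorator itself is not part of this diff, so the following is a hypothetical sketch of the technique the commented-out call implies, not the project's implementation:

import torch


def float_casting_decorator(layer_cls):
    """Return a subclass of layer_cls whose forward output is cast to fp32."""

    class FloatCastingLayer(layer_cls):
        def forward(self, *args, **kwargs):
            # Run the original projection, then upcast the logits for a more
            # numerically stable loss under bf16 mixed precision.
            return super().forward(*args, **kwargs).to(torch.float32)

    return FloatCastingLayer


# Usage mirroring the commented-out lines (note that instantiating a fresh
# lm_head this way would also need the pretrained weights copied or re-tied):
# model.lm_head = float_casting_decorator(model.lm_head.__class__)(
#     in_features=model.lm_head.in_features,
#     out_features=model.lm_head.out_features,
# )
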
2 changes: 1 addition & 1 deletion test_status.yaml
@@ -1 +1 @@
86eb89a6651607c56835456eb9d2f0ae6cd222cc: PASS
d438c06d55ece4b8ccd55e414172079d0e98ef7f: PASS
