Fix CI tests #8

Merged
4 commits merged on Feb 27, 2024
chemlactica/config/test_configs/fsdp_config.yaml (7 changes: 4 additions & 3 deletions)
@@ -1,13 +1,15 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: false
fsdp_forward_prefetch: false
fsdp_offload_params: false
- fsdp_sharding_strategy: 1
+ fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: OPTForCausalLM
machine_rank: 0
main_training_function: main
mixed_precision: bf16
@@ -19,4 +21,3 @@ tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
- main_process_port: 30001
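
The substantive change in this config is replacing the bare integer sharding strategy with its symbolic name and dropping the hard-coded main_process_port. As a side note (not part of this PR, and assuming a PyTorch recent enough to ship torch.distributed.fsdp), both forms ultimately name the same torch FSDP strategy, though how accelerate parses the YAML value can differ by version:

```python
# Not from this PR: the accelerate config value maps onto torch's FSDP
# ShardingStrategy enum, where FULL_SHARD is enum member number 1.
from torch.distributed.fsdp import ShardingStrategy

assert ShardingStrategy(1) is ShardingStrategy.FULL_SHARD
print(ShardingStrategy.FULL_SHARD)  # ShardingStrategy.FULL_SHARD
```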
chemlactica/custom_trainer.py (21 changes: 0 additions & 21 deletions)
@@ -1,14 +1,10 @@
- import shutil
import torch
import submitit
from typing import Any, Dict
import os
from torch._tensor import Tensor
from torch.nn.modules import Module
from transformers import Trainer, TrainingArguments
- from torch.distributed.fsdp.fully_sharded_data_parallel import (
-     FullyShardedDataParallel as FSDP,
- )
from chemlactica.utils.utils import get_tokenizer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import is_torch_tpu_available
@@ -41,23 +37,6 @@ def training_step(self, model: Module, inputs: Dict[str, Tensor | Any]) -> Tensor:
self.num_samples_to_print = None
return super().training_step(model, inputs)

- def _save_checkpoint(self, model, trial, metrics=None):
-     if shutil.disk_usage("/").free > 3 * 1024**3:
-         super()._save_checkpoint(model, trial, metrics=None)
-     else:
-         print("**disk is full didn't save**")
-
- def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
-     """
-     This code is added because we had a failure when resuming training.
-     Basically, we load the model with fsdp when the model is not fsdp wrapped.
-     In the future versions transformers this issue is handled, by adding an extra check,
-     but not in 4.31.0 version. So this is our manual check addition to solve the problem.
-     """
-     if type(self.model) != FSDP:
-         return
-     return super()._load_from_checkpoint(resume_from_checkpoint, model)
-
def _build_slurm_eval_command(self, train_command, trial):
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
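
The overrides removed from this file were workarounds: a disk-space guard around checkpoint saving and a manual FSDP-wrapping check that, per its own docstring, was only needed because transformers 4.31.0 lacked the upstream fix when resuming training. If such a guard were ever needed again, a hypothetical sketch (not part of this PR) could gate it on the library version instead of carrying it unconditionally:

```python
# Hypothetical sketch, not from this PR: only apply the manual FSDP resume guard
# on the transformers release it was written for (4.31.0 or earlier, per the
# removed docstring); newer releases are assumed to handle the check upstream.
from packaging import version
import transformers

NEEDS_MANUAL_FSDP_RESUME_GUARD = (
    version.parse(transformers.__version__) <= version.parse("4.31.0")
)
```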
chemlactica/utils/model_utils.py (9 changes: 4 additions & 5 deletions)
@@ -90,11 +90,10 @@ def load_model(
model = OPTForCausalLM.from_pretrained(
from_pretrained, torch_dtype=dtype, attn_implementation=attn_implementation
)
- print(type(model.lm_head))
- model.lm_head = float_casting_decorator(model.lm_head.__class__)(
-     in_features=model.lm_head.in_features,
-     out_features=model.lm_head.out_features,
- )
+ # model.lm_head = float_casting_decorator(model.lm_head.__class__)(
+ #     in_features=model.lm_head.in_features,
+ #     out_features=model.lm_head.out_features,
+ # )
# model.lm_head.forward = cast_to_fp32(OPTForCausalLM.lm_head.forward)

if "mistral" in from_pretrained.lower():
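
This hunk disables the fp32-casting wrapper around lm_head and drops a leftover debug print. float_casting_decorator is defined elsewhere in chemlactica and is not shown in this diff, so the sketch below is only an assumption of what such a wrapper could look like: a class decorator that subclasses the head's layer class and upcasts its output to float32.

```python
# Hypothetical sketch for illustration; the real float_casting_decorator in
# chemlactica is not part of this diff and may differ.
import torch
import torch.nn as nn


def float_casting_decorator(linear_cls):
    class FloatCastingHead(linear_cls):
        def forward(self, x):
            # Run the wrapped layer's projection, then upcast the logits to fp32.
            return super().forward(x).to(torch.float32)

    return FloatCastingHead


# Usage mirroring the commented-out call (dimensions are arbitrary placeholders):
lm_head = float_casting_decorator(nn.Linear)(in_features=768, out_features=50272)
```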
test_status.yaml (2 changes: 1 addition & 1 deletion)
@@ -1 +1 @@
- 86eb89a6651607c56835456eb9d2f0ae6cd222cc: PASS
+ d438c06d55ece4b8ccd55e414172079d0e98ef7f: PASS