diff --git a/chemlactica/config/test_configs/fsdp_config.yaml b/chemlactica/config/test_configs/fsdp_config.yaml
index 8665e42..b7cb6ce 100644
--- a/chemlactica/config/test_configs/fsdp_config.yaml
+++ b/chemlactica/config/test_configs/fsdp_config.yaml
@@ -1,13 +1,15 @@
 compute_environment: LOCAL_MACHINE
+debug: false
 distributed_type: FSDP
 downcast_bf16: 'no'
 fsdp_config:
   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
   fsdp_backward_prefetch: BACKWARD_PRE
+  fsdp_cpu_ram_efficient_loading: false
+  fsdp_forward_prefetch: false
   fsdp_offload_params: false
-  fsdp_sharding_strategy: 1
+  fsdp_sharding_strategy: FULL_SHARD
   fsdp_state_dict_type: FULL_STATE_DICT
-  fsdp_transformer_layer_cls_to_wrap: OPTForCausalLM
 machine_rank: 0
 main_training_function: main
 mixed_precision: bf16
@@ -19,4 +21,3 @@ tpu_env: []
 tpu_use_cluster: false
 tpu_use_sudo: false
 use_cpu: false
-main_process_port: 30001
diff --git a/chemlactica/custom_trainer.py b/chemlactica/custom_trainer.py
index a4b9a94..15fcf8e 100644
--- a/chemlactica/custom_trainer.py
+++ b/chemlactica/custom_trainer.py
@@ -1,4 +1,3 @@
-import shutil
 import torch
 import submitit
 from typing import Any, Dict
@@ -6,9 +5,6 @@
 from torch._tensor import Tensor
 from torch.nn.modules import Module
 from transformers import Trainer, TrainingArguments
-from torch.distributed.fsdp.fully_sharded_data_parallel import (
-    FullyShardedDataParallel as FSDP,
-)
 from chemlactica.utils.utils import get_tokenizer
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from transformers.utils import is_torch_tpu_available
@@ -41,23 +37,6 @@ def training_step(self, model: Module, inputs: Dict[str, Tensor | Any]) -> Tenso
             self.num_samples_to_print = None
         return super().training_step(model, inputs)

-    def _save_checkpoint(self, model, trial, metrics=None):
-        if shutil.disk_usage("/").free > 3 * 1024**3:
-            super()._save_checkpoint(model, trial, metrics=None)
-        else:
-            print("**disk is full didn't save**")
-
-    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
-        """
-        This code is added because we had a failure when resuming training.
-        Basically, we load the model with fsdp when the model is not fsdp wrapped.
-        In the future versions transformers this issue is handled, by adding an extra check,
-        but not in 4.31.0 version. So this is our manual check addition to solve the problem.
-        """
-        if type(self.model) != FSDP:
-            return
-        return super()._load_from_checkpoint(resume_from_checkpoint, model)
-
     def _build_slurm_eval_command(self, train_command, trial):
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
         run_dir = self._get_output_dir(trial=trial)
diff --git a/chemlactica/utils/model_utils.py b/chemlactica/utils/model_utils.py
index 5937666..899651d 100644
--- a/chemlactica/utils/model_utils.py
+++ b/chemlactica/utils/model_utils.py
@@ -90,11 +90,10 @@ def load_model(
         model = OPTForCausalLM.from_pretrained(
             from_pretrained, torch_dtype=dtype, attn_implementation=attn_implementation
         )
-        print(type(model.lm_head))
-        model.lm_head = float_casting_decorator(model.lm_head.__class__)(
-            in_features=model.lm_head.in_features,
-            out_features=model.lm_head.out_features,
-        )
+        # model.lm_head = float_casting_decorator(model.lm_head.__class__)(
+        #     in_features=model.lm_head.in_features,
+        #     out_features=model.lm_head.out_features,
+        # )
         # model.lm_head.forward = cast_to_fp32(OPTForCausalLM.lm_head.forward)

     if "mistral" in from_pretrained.lower():
diff --git a/test_status.yaml b/test_status.yaml
index 7f1325b..1e135cb 100644
--- a/test_status.yaml
+++ b/test_status.yaml
@@ -1 +1 @@
-86eb89a6651607c56835456eb9d2f0ae6cd222cc: PASS
+d438c06d55ece4b8ccd55e414172079d0e98ef7f: PASS