From 1e75320d592e4cd4c1b2a076178db3e2e33527ed Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Wed, 21 Feb 2024 23:06:49 -0800 Subject: [PATCH] remove deepspeed, some fixes, and llava --- README.md | 2 +- open_flamingo/__init__.py | 1 + open_flamingo/eval/README.md | 2 - open_flamingo/eval/eval_models/eval_model.py | 33 +------- open_flamingo/eval/evaluate.py | 10 +-- open_flamingo/scripts/run_eval_deepspeed.sh | 77 ------------------- open_flamingo/scripts/run_train_deepspeed.sh | 41 ---------- open_flamingo/src/__init__.py | 1 + open_flamingo/src/cross_attn_lm.py | 21 +++++- open_flamingo/src/factory.py | 63 ++++++---------- open_flamingo/src/flamingo.py | 19 ++++- open_flamingo/src/helpers.py | 20 ++--- open_flamingo/src/llava.py | 55 ++++++++++++++ open_flamingo/src/vlm.py | 61 ++++++++++----- open_flamingo/train/README.md | 3 +- open_flamingo/train/data.py | 4 +- open_flamingo/train/distributed.py | 50 +------------ open_flamingo/train/train.py | 17 ++--- open_flamingo/train/train_utils.py | 79 ++++---------------- requirements-training.txt | 3 +- requirements.txt | 2 +- setup.py | 31 ++------ 22 files changed, 203 insertions(+), 392 deletions(-) delete mode 100644 open_flamingo/scripts/run_eval_deepspeed.sh delete mode 100644 open_flamingo/scripts/run_train_deepspeed.sh create mode 100644 open_flamingo/src/llava.py diff --git a/README.md b/README.md index de751f92..87a89bb9 100644 --- a/README.md +++ b/README.md @@ -102,7 +102,7 @@ To instantiate an OpenFlamingo model with one of our released weights, initializ from huggingface_hub import hf_hub_download import torch -checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt") +checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-4B-vitl-rpj3b", "checkpoint.pt") model.load_state_dict(torch.load(checkpoint_path), strict=False) ``` diff --git a/open_flamingo/__init__.py b/open_flamingo/__init__.py index ab77c87d..681f4f19 100644 --- a/open_flamingo/__init__.py +++ b/open_flamingo/__init__.py @@ -1,4 +1,5 @@ from .src.flamingo import Flamingo from .src.kosmos import Kosmos from .src.blip import BLIP +from .src.llava import Llava from .src.factory import create_model_and_transforms, SUPPORTED_MODEL_FAMILIES diff --git a/open_flamingo/eval/README.md b/open_flamingo/eval/README.md index 8835ac24..bbe6f665 100644 --- a/open_flamingo/eval/README.md +++ b/open_flamingo/eval/README.md @@ -30,8 +30,6 @@ To help standardize VLM evaluations, we have implemented EvalModel wrappers for ## Distributed evaluation Our codebase uses DistributedDataParallel to parallelize evaluation by default, so please make sure to set the `MASTER_ADDR` and `MASTER_PORT` environment variables or use `torchrun` (see sample scripts section below). -We have also implemented distributed evaluation using Deepspeed, which additionally shards model parameters across GPUs for memory savings. To use Deepspeed instead of DDP, use the `--deepspeed` flag. - We also support evaluating at a lower precision using the `--precision` flag. We find minimal difference between evaluating at full precision vs. amp_bf16. 
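For reference, a minimal `torchrun` launch looks roughly like the following (an illustrative sketch only; the model and benchmark flags are omitted here, so use the sample scripts below for complete commands):

```bash
# Illustrative sketch: single-node DDP evaluation on 8 GPUs.
# torchrun sets MASTER_ADDR, MASTER_PORT, and the per-process rank variables automatically;
# append the model and benchmark flags shown in the sample scripts below.
torchrun --nnodes=1 --nproc_per_node=8 open_flamingo/eval/evaluate.py \
    --precision amp_bf16
```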
## Sample scripts diff --git a/open_flamingo/eval/eval_models/eval_model.py b/open_flamingo/eval/eval_models/eval_model.py index c3cebb6d..d61739a3 100644 --- a/open_flamingo/eval/eval_models/eval_model.py +++ b/open_flamingo/eval/eval_models/eval_model.py @@ -31,7 +31,7 @@ def get_eval_model(name, *args, **kwargs): class BaseEvalModel(abc.ABC): """Base class encapsulating functionality needed to evaluate a model.""" - def __init__(self, model_args: List[str], init_on_device=False): + def __init__(self, model_args: List[str]): """Initialize model. Args: @@ -59,17 +59,6 @@ def __init__(self, model_args: List[str], init_on_device=False): self.autocast = get_autocast(self.precision) self.cast_dtype = get_cast_dtype(self.precision) - # initialization context - if init_on_device: - # for deepspeed, must init on device, or likely CPU OOM - import deepspeed - - self.init_ctx = deepspeed.OnDevice( - dtype=self.cast_dtype, device=self.device - ) - else: - self.init_ctx = suppress() - @property def required_args(self): """Return list of required arguments to initialize model.""" @@ -83,23 +72,9 @@ def _check_init(self): assert hasattr(self, "tokenizer"), "Tokenizer has not been initialized" self.tokenizer.padding_side = "left" - def init_distributed(self, world_size=None, use_deepspeed=False): - """Wrap model as DDP or deepspeed.""" - if use_deepspeed: - assert "amp" not in self.precision, "Deepspeed does not support amp" - import deepspeed - - self.ds_engine = deepspeed.init_inference( - self.model, - mp_size=world_size, - dtype=self.cast_dtype, - checkpoint=None, - replace_with_kernel_inject=True, - ) - self.model = self.ds_engine.module - self.autocast = get_autocast(None) - else: - self.model = DDP(self.model, device_ids=[self.device]) + def init_distributed(self): + """Wrap model as DDP.""" + self.model = DDP(self.model, device_ids=[self.device]) def __call__( self, diff --git a/open_flamingo/eval/evaluate.py b/open_flamingo/eval/evaluate.py index ed094e81..26dcfa94 100644 --- a/open_flamingo/eval/evaluate.py +++ b/open_flamingo/eval/evaluate.py @@ -394,12 +394,6 @@ action="store_true", help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", ) -parser.add_argument( - "--deepspeed", - default=False, - action="store_true", - help="Whether to use deepspeed for distributed inference.", -) def main(): @@ -414,11 +408,9 @@ def main(): model_args["device"] = device_id # initialize model - eval_model = get_eval_model(args.model, model_args, init_on_device=args.deepspeed) + eval_model = get_eval_model(args.model, model_args, init_on_device=False) eval_model.init_distributed( local_rank=args.local_rank, - world_size=args.world_size, - use_deepspeed=args.deepspeed, ) # Validate args diff --git a/open_flamingo/scripts/run_eval_deepspeed.sh b/open_flamingo/scripts/run_eval_deepspeed.sh deleted file mode 100644 index fbba58ff..00000000 --- a/open_flamingo/scripts/run_eval_deepspeed.sh +++ /dev/null @@ -1,77 +0,0 @@ -#!/bin/bash -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=2 -#SBATCH --gpus-per-task=1 - -< bool: def clear_conditioned_layers(self): for layer in self._get_decoder_layers(): layer.condition_vis_x(None) - layer.condition_media_locations(None) + layer.condition_media_locations(None) \ No newline at end of file diff --git a/open_flamingo/src/factory.py b/open_flamingo/src/factory.py index 98acb61e..b96cc075 100644 --- a/open_flamingo/src/factory.py +++ b/open_flamingo/src/factory.py @@ -1,5 +1,4 @@ from typing import Optional -import torch.nn as nn 
from transformers import AutoModelForCausalLM, AutoTokenizer import open_clip @@ -7,10 +6,16 @@ from .flamingo import Flamingo from .kosmos import Kosmos from .blip import BLIP +from .llava import Llava from .utils import hasattr_recursive, setattr_recursive -SUPPORTED_MODEL_FAMILIES = ("flamingo", "kosmos", "blip") - +SUPPORTED_MODEL_FAMILIES = ("flamingo", "kosmos", "blip", "llava") +MODEL_FAMILY_TO_CLASS = { + "flamingo": Flamingo, + "kosmos": Kosmos, + "blip": BLIP, + "llava": Llava, +} def create_model_and_transforms( clip_vision_encoder_path: str, @@ -83,41 +88,16 @@ def create_model_and_transforms( if decoder_layers_attr_name is None: decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_model) - if model_family == "flamingo": - model = Flamingo( - vision_encoder=vision_encoder, - lang_model=lang_model, - vis_feature_dim=vis_hidden_dim, - initial_tokenizer_len=len(text_tokenizer), - gradient_checkpointing=gradient_checkpointing, - decoder_layers_attr_name=decoder_layers_attr_name, - pad_token_id=text_tokenizer.pad_token_id, - **model_kwargs, - ) - - elif model_family == "kosmos": - model = Kosmos( - vision_encoder=vision_encoder, - lang_model=lang_model, - vis_feature_dim=vis_hidden_dim, - initial_tokenizer_len=len(text_tokenizer), - gradient_checkpointing=gradient_checkpointing, - pad_token_id=text_tokenizer.pad_token_id, - decoder_layers_attr_name=decoder_layers_attr_name, - **model_kwargs, - ) - - elif model_family == "blip": - model = BLIP( - vision_encoder=vision_encoder, - lang_model=lang_model, - vis_feature_dim=vis_hidden_dim, - initial_tokenizer_len=len(text_tokenizer), - gradient_checkpointing=gradient_checkpointing, - pad_token_id=text_tokenizer.pad_token_id, - decoder_layers_attr_name=decoder_layers_attr_name, - **model_kwargs, - ) + model = MODEL_FAMILY_TO_CLASS[model_family]( + vision_encoder=vision_encoder, + lang_model=lang_model, + vis_feature_dim=vis_hidden_dim, + initial_tokenizer_len=len(text_tokenizer), + gradient_checkpointing=gradient_checkpointing, + decoder_layers_attr_name=decoder_layers_attr_name, + pad_token_id=text_tokenizer.pad_token_id, + **model_kwargs, + ) # add special tokens to the tokenizer and language models text_tokenizer.add_special_tokens( @@ -130,7 +110,6 @@ def create_model_and_transforms( for v in model.special_tokens.values() } ) - # freeze appropriate parameters model.set_trainable() @@ -139,8 +118,8 @@ def create_model_and_transforms( print( f"{model_family} model initialized with {model.num_trainable_params:,} trainable parameters" ) - print(f"========== Trainable Parameters\n{model.num_trainable_params_per_module}") - print(f"========== Total Parameters\n{model.num_params_per_module}\n==========") + print(f"==========Trainable Parameters\n{model.num_trainable_params_per_module}") + print(f"==========Total Parameters\n{model.num_params_per_module}\n==========") return model, image_processor, text_tokenizer @@ -220,4 +199,4 @@ def has_fn(model, fn_name): getattr(model, fn_name)() return True except: - return False + return False \ No newline at end of file diff --git a/open_flamingo/src/flamingo.py b/open_flamingo/src/flamingo.py index 15829361..89d7c40e 100644 --- a/open_flamingo/src/flamingo.py +++ b/open_flamingo/src/flamingo.py @@ -1,8 +1,9 @@ +from typing import List, Optional, Tuple, Union from torch import nn -from .helpers import PerceiverResampler, GatedCrossAttentionBlock +import torch +from .helpers import PerceiverResampler from .vlm import VLMWithCrossAttention - class Flamingo(VLMWithCrossAttention): def 
__init__( self, @@ -61,3 +62,17 @@ def _should_apply_weight_decay(self, parameter_name): Flamingo applies 0.1 weight decay to cross attention parameters """ return "gated_cross_attn" in parameter_name + + def generate( + self, + vision_x: torch.Tensor, + lang_x: torch.Tensor, + attention_mask: torch.Tensor = None, + past_key_values: Optional[ + List[Union[torch.Tensor, Tuple[torch.Tensor]]] + ] = None, + past_media_locations: Optional[torch.Tensor] = None, + past_vision_tokens: Optional[torch.Tensor] = None, + **kwargs, + ): + return super().generate(vision_x, lang_x, attention_mask, past_key_values, past_media_locations, past_vision_tokens, eos_token_id=self.eoc_token_id, **kwargs) \ No newline at end of file diff --git a/open_flamingo/src/helpers.py b/open_flamingo/src/helpers.py index 759fd3ce..c4ab859e 100644 --- a/open_flamingo/src/helpers.py +++ b/open_flamingo/src/helpers.py @@ -184,19 +184,19 @@ def forward(self, x): latents = ff(latents) + latents return self.norm(latents) - -class LinearProjection(nn.Module): +class LinearPatchProjection(VisionTokenizer): """Linear projection from patch features to image tokens.""" - def __init__(self, *, dim, dim_out): - super().__init__() - self.proj = nn.Linear(dim, dim_out) - self.out_dim = dim_out + def __init__(self, *, dim_visual, dim_out, num_patches): + super().__init__(dim_media=dim_visual, num_tokens_per_media=num_patches) + self.proj = nn.Linear(dim_visual, dim_out) def forward(self, x): - return self.proj(x) - - + B = x.shape[0] + x = rearrange(x, "b T F v d -> (b T) (F v) d") + x = self.proj(x) + return rearrange(x, "(b T) n d -> b T n d", b=B) + # gated cross attention class MaskedCrossAttention(nn.Module): def __init__( @@ -362,7 +362,7 @@ def __init__( self.query_tokens = nn.Parameter( torch.zeros(1, num_query_tokens, dim_inner) ) - self.proj = LinearProjection(dim=dim_inner, dim_out=dim_out) + self.proj = nn.Linear(dim_inner, dim_out) else: model = Blip2Model.from_pretrained( pretrained_path, diff --git a/open_flamingo/src/llava.py b/open_flamingo/src/llava.py new file mode 100644 index 00000000..df4349a9 --- /dev/null +++ b/open_flamingo/src/llava.py @@ -0,0 +1,55 @@ +from torch import nn +from .helpers import LinearPatchProjection +from .vlm import VLMWithLanguageStream + + +class Llava(VLMWithLanguageStream): + def __init__( + self, + vision_encoder: nn.Module, + lang_model: nn.Module, + vis_feature_dim: int, + initial_tokenizer_len: int, + pad_token_id: int, + decoder_layers_attr_name: str = None, + gradient_checkpointing: bool = False, + ): + """ + Language stream VLM that uses a linear projection, similar to LLaVA. + + Args: + vision_encoder (nn.Module): HF CLIPModel + lang_model (nn.Module): HF causal language model + vis_feature_dim (int): final dimension of the visual features output by the vision_encoder + initial_tokenizer_len (int): size of the tokenizer vocab + pad_token_id (int): id of the padding token. None if no padding token; then a padding token + will be inserted into self.special_tokens, which factory.py fills after creating new tokens + decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None. + gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
+ """ + self._special_tokens = { + "media_token": "", + } + lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1] + super().__init__( + vision_encoder=vision_encoder, + vision_tokenizer=LinearPatchProjection(dim_visual=vis_feature_dim, + dim_out=lang_embedding_dim, + num_patches=vision_encoder.grid_size[0] * vision_encoder.grid_size[1]), + lang_model=lang_model, + initial_tokenizer_len=initial_tokenizer_len, + gradient_checkpointing=gradient_checkpointing, + decoder_layers_attr_name=decoder_layers_attr_name, + pad_token_id=pad_token_id, + ) + + def set_trainable(self): + """ + Freeze everything except the Q-former and the inserted LM embeddings + """ + self.requires_grad_(False) + self.vision_tokenizer.requires_grad_(True) + self.lang_model.requires_grad_(True) + + def _should_apply_weight_decay(self, parameter_name): + return True diff --git a/open_flamingo/src/vlm.py b/open_flamingo/src/vlm.py index 05434cb8..c55e014a 100644 --- a/open_flamingo/src/vlm.py +++ b/open_flamingo/src/vlm.py @@ -58,6 +58,7 @@ def __init__( # lm embeddings self.pad_token_id = pad_token_id + self.initial_tokenizer_len = initial_tokenizer_len input_embeds = DecoupledEmbedding( max_original_id=initial_tokenizer_len - 1, num_additional_embeddings=len(self.special_tokens), @@ -71,14 +72,13 @@ def __init__( if hasattr(self.lang_model.config, "initializer_range") else 0.02, ) - self.lang_model.set_input_embeddings(input_embeds) out_embeds = DecoupledLinear( max_original_id=initial_tokenizer_len - 1, additional_out_features=len(self.special_tokens), _weight=self.lang_model.get_output_embeddings().weight, - _bias=self.lang_model.get_output_embeddings().bias, + _bias=self.lang_model.get_output_embeddings().bias if hasattr(self.lang_model.get_output_embeddings(), "bias") else None, ) if hasattr(out_embeds, "additional_fc"): out_embeds.additional_fc.weight.data.normal_( @@ -194,7 +194,9 @@ def _encode_vision_x(self, vision_x: torch.Tensor): def _concat_vision_cache( self, lang_x, vision_tokens, past_vision_tokens, past_media_locations, use_cache ): - """Helper function to include the past vision tokens and past media locations in the output""" + """ + Helper function to include the past vision tokens and past media locations in the output. + """ if use_cache: if past_media_locations is not None and past_vision_tokens is not None: if vision_tokens is not None: @@ -253,14 +255,14 @@ def generate( # convert pixels to vision tokens if vision_x is not None: - if num_beams > 1: - vision_x = vision_x.repeat_interleave(num_beams, dim=0) vision_features = self._encode_vision_x(vision_x=vision_x) vision_tokens = self.vision_tokenizer(vision_features) else: vision_tokens = None # fuse the vision and language tokens + # for xattn, vision_x and media_location are repeat_interleaved s.t. 
+ # the total batch size is B * num_beams new_inputs = self._prepare_inputs_for_forward( vision_tokens=vision_tokens, lang_x=lang_x, @@ -268,17 +270,16 @@ def generate( past_key_values=past_key_values, past_media_locations=past_media_locations, past_vision_tokens=past_vision_tokens, - generating=True, + padding_side="right", + num_beams=num_beams, ) output = self.lang_model.generate( **new_inputs, past_key_values=past_key_values, - past_media_locations=past_media_locations, - past_vision_tokens=past_vision_tokens, num_beams=num_beams, + use_cache=True, **kwargs, ) - self._post_forward_hook() return output @@ -403,7 +404,8 @@ def _prepare_inputs_for_forward( past_key_values=None, past_media_locations: torch.Tensor = None, past_vision_tokens: torch.Tensor = None, - generating: bool = False, # Not used for cross-attention models + padding_side: str = "right", # noop for cross-attention models + num_beams: int = 1, ): """Each xattn layer needs to save the vision tokens and the locations of the media tokens in the language sequence""" self.lang_model._condition_media_before_forward( @@ -411,7 +413,17 @@ def _prepare_inputs_for_forward( vision_tokens=vision_tokens, past_media_locations=past_media_locations, past_vision_tokens=past_vision_tokens, + num_beams=num_beams, ) + if past_key_values is not None: + past_key_values = [ + ( + k.repeat_interleave(num_beams, dim=0), + v.repeat_interleave(num_beams, dim=0) + ) + for k, v in past_key_values + ] + return { "input_ids": lang_x, "attention_mask": attention_mask, @@ -460,7 +472,7 @@ def get_fsdp_lambda_fn(self): ) from .helpers import GatedCrossAttentionBlock - original_decoder_block_class = self.lang_model.old_decoder_blocks[0].__class__ + original_decoder_block_class = self.lang_model.decoder_block_class def lambda_fn(module: nn.Module): # we want FSDP(ckpt(module)), not ckpt(FSDP(module)) @@ -532,9 +544,6 @@ def __init__( self.decoder_layers_attr_name = decoder_layers_attr_name for block in getattr_recursive(self.lang_model, self.decoder_layers_attr_name): block._use_gradient_checkpointing = gradient_checkpointing - assert ( - self.vis_embedding_dim == self.lang_embedding_dim - ), "To place visual tokens directly in the language stream, the visual and language tokens need to be the same dim." def _prepare_inputs_for_forward( self, @@ -545,12 +554,20 @@ def _prepare_inputs_for_forward( past_key_values=None, past_media_locations: torch.Tensor = None, past_vision_tokens: torch.Tensor = None, - generating: bool = False, # whether we're generating to decide on padding side + padding_side: str = "left", + num_beams: int = 1, ): """ Insert the vision tokens directly into the language stream/ This requires us to modify the input_ids, attention_mask, and labels. """ + if past_key_values is not None: + past_len = past_key_values[0][0].shape[2] + assert attention_mask.shape[1] == past_len + lang_x.shape[1], ( + "Attention_mask must be as long as the entire past len (including image tokens) and current input IDs. " + + "Check that you've expanded the attention mask to account for past image tokens." 
+ ) + if vision_tokens is None: return { "input_ids": lang_x, @@ -570,7 +587,7 @@ def _prepare_inputs_for_forward( for i in range(B): # get index of tokens in lang_x[i] image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0] - + if len(image_token_idxs) == 0: multimodal_embeds.append(lang_embeds[i].clone()) multimodal_attention_mask.append(attention_mask[i].clone()) @@ -628,14 +645,20 @@ def _prepare_inputs_for_forward( # stack multimodal_embeds = stack_with_padding( - multimodal_embeds, padding_value=self.pad_token_id, padding_side="left" if generating else "right" + multimodal_embeds, + padding_value=self.pad_token_id, + padding_side=padding_side, ) multimodal_attention_mask = stack_with_padding( - multimodal_attention_mask, padding_value=0, padding_side="left" if generating else "right" + multimodal_attention_mask, + padding_value=0, + padding_side=padding_side, ) if has_labels: multimodal_labels = stack_with_padding( - multimodal_labels, padding_value=-100, padding_side="left" if generating else "right" + multimodal_labels, + padding_value=-100, + padding_side=padding_side, ) return { diff --git a/open_flamingo/train/README.md b/open_flamingo/train/README.md index 47a029a8..86a8a61d 100644 --- a/open_flamingo/train/README.md +++ b/open_flamingo/train/README.md @@ -70,13 +70,12 @@ Our codebase supports distributed training using three frameworks: * Pytorch's [DistributedDataParallel](https://pytorch.org/docs/stable/torch.nn.parallel.DistributedDataParallel.html). This is the default method used by `train.py`. * Pytorch's [FullyShardedDataParallel](https://pytorch.org/docs/stable/fsdp.html) (FSDP). Use the `--fsdp` flag. -* [DeepSpeed](https://github.com/microsoft/DeepSpeed) stages 1-3. Use the `--deepspeed` flag. Note that you should use exactly one of these training methods. `train/distributed.py` contains utilities to help with setting up distributed training using Slurm / `torchrun`. See example scripts in the `scripts` directory. ### FSDP notes -To use FSDP, make sure to use Pytorch Nightly (> 2.0.1). +To use FSDP, make sure to use Pytorch (> 2.0.1). We support two sharding strategies for FSDP: full sharding (model sharing across all nodes and GPUs) or hybrid sharding (model sharding across GPUs within nodes, data parallel between nodes). The former saves GPU memory; the latter saves on communication costs. diff --git a/open_flamingo/train/data.py b/open_flamingo/train/data.py index 384b5242..1a5140ea 100644 --- a/open_flamingo/train/data.py +++ b/open_flamingo/train/data.py @@ -21,7 +21,7 @@ from einops import rearrange from scipy.optimize import linear_sum_assignment -from data_utils import * +from open_flamingo.train.data_utils import * SUPPORTED_DATASETS = ["laion", "mmc4"] @@ -518,4 +518,4 @@ def get_data(args, image_processor, tokenizer, dataset_type, epoch=0): args, image_processor=image_processor, epoch=epoch, tokenizer=tokenizer ) else: - raise ValueError(f"Unsupported dataset: {dataset_type}") + raise ValueError(f"Unsupported dataset: {dataset_type}") \ No newline at end of file diff --git a/open_flamingo/train/distributed.py b/open_flamingo/train/distributed.py index 2cd88c59..4b558ed2 100644 --- a/open_flamingo/train/distributed.py +++ b/open_flamingo/train/distributed.py @@ -1,10 +1,9 @@ """ -Util functions for distributed training, FSDP, and Deepspeed. +Util functions for distributed training and FSDP. 
""" import os import torch -from data import SUPPORTED_DATASETS ################################## # SLURM setup; Credit: open_clip # @@ -137,7 +136,7 @@ def init_distributed_device(args): ##################################### -# FSDP and Deepspeed util functions # +# FSDP util functions # ##################################### @@ -225,48 +224,3 @@ def get_fsdp_checkpoint_config(args): rank0_only=True, offload_to_cpu=True ), ) - - -def get_deepspeed_config( - args, -): - """ - Return kwargs for Deepspeed config. - """ - zero_opt_dict = { - "stage": args.deepspeed_stage, - "overlap_comm": True, - "contiguous_gradients": True, - "offload_param": {"device": "none"}, # TODO: Support CPU offload - "offload_optimizer": {"device": "none"}, - "stage3_param_persistence_threshold": 1e4, - "stage3_max_live_parameters": 3e7, - "stage3_prefetch_bucket_size": 3e7, - "memory_efficient_linear": False, - } - # sum all the args that start with batch_size_ to get the total batch size - total_batch_size = sum( - [getattr(args, arg) for arg in vars(args) if arg.startswith("batch_size_")] - ) - ds_config = { - "train_batch_size": total_batch_size - * args.world_size - * args.gradient_accumulation_steps, - "train_micro_batch_size_per_gpu": total_batch_size - * args.gradient_accumulation_steps, - "steps_per_print": args.logging_steps, - "zero_optimization": zero_opt_dict, - "gradient_clipping": 1.0, - "prescale_gradients": False, - "wall_clock_breakdown": False, - } - - if args.precision == "fp16": - ds_config["fp16"] = {"enabled": True} - elif args.precision == "bf16": - ds_config["bf16"] = {"enabled": True} - # amp not supported with DeepSpeed - elif "amp" in args.precision: - raise ValueError("amp not supported with DeepSpeed") - - return ds_config diff --git a/open_flamingo/train/train.py b/open_flamingo/train/train.py index 77d0451c..1ee6692f 100644 --- a/open_flamingo/train/train.py +++ b/open_flamingo/train/train.py @@ -10,15 +10,15 @@ from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy from open_flamingo import create_model_and_transforms, SUPPORTED_MODEL_FAMILIES -from data import get_data, SUPPORTED_DATASETS -from distributed import ( +from open_flamingo.train.data import get_data, SUPPORTED_DATASETS +from open_flamingo.train.distributed import ( init_distributed_device, world_info_from_env, get_fsdp_config, get_fsdp_checkpoint_config, get_deepspeed_config, ) -from train_utils import ( +from open_flamingo.train.train_utils import ( train_one_epoch, random_seed, load_deepspeed_checkpoint, @@ -27,7 +27,7 @@ save_checkpoint, save_deepspeed_checkpoint, ) -from losses import ( +from open_flamingo.train.losses import ( SUPPORTED_LOSSES, get_loss_fn, ) @@ -261,11 +261,6 @@ def main(): assert ( "dev" in torch.__version__ and torch.__version__ > "2.0.1" ), "FSDP requires torch nightly > 2.0.1" - - if args.deepspeed and args.gradient_checkpointing: - print( - "Gradient checkpointing with Deepspeed will cause all parameters to be saved for each checkpoint." 
- ) # Set up distributed training args.local_rank, args.rank, args.world_size = world_info_from_env() @@ -375,7 +370,7 @@ def main(): ] total_training_steps = ( getattr(args, f"train_num_samples_{datasets_to_train_on[0]}") - // (getattr(args, f"batch_size_{datasets_to_train_on[0]}") * args.world_size) + // getattr(args, f"batch_size_{datasets_to_train_on[0]}") ) * args.num_epochs if args.rank == 0: @@ -447,4 +442,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/open_flamingo/train/train_utils.py b/open_flamingo/train/train_utils.py index 58ec5ab9..badc6bb6 100644 --- a/open_flamingo/train/train_utils.py +++ b/open_flamingo/train/train_utils.py @@ -30,7 +30,7 @@ def train_one_epoch( Handles logging, calling forward, backward, gradient clipping, and optimizer step. Args: args (argparse.Namespace): arguments from command line - model: DDP / FSDP / Deepspeed wrapped model + model: DDP / FSDP wrapped model epoch (int): epoch number datasets (list): list of DataInfos, one for each dataset, to train on compute_loss_fn (callable): function that given the model and inputs, calls forward @@ -98,27 +98,21 @@ def train_one_epoch( dataset_loss *= ( datasets[dataset_ix].loss_multiplier / args.gradient_accumulation_steps ) - if args.deepspeed: - model.backward(dataset_loss) - else: - (dataset_loss).backward() + (dataset_loss).backward() # clip gradient norm if args.fsdp: model.clip_grad_norm_(1.0) - elif not args.deepspeed: # deepspeed handles clipping internally + else: torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # step optimizer and log if (((step_num + 1) % args.gradient_accumulation_steps) == 0) or ( step_num == num_batches_per_epoch - 1 ): - if args.deepspeed: - model.step() - else: - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) # step time and reset end outside of rank 0 step_time_m.update(time.time() - end) @@ -269,27 +263,20 @@ def find_most_recent_checkpoint(args): """ Returns the path of the most recent checkpoint for a given run name. """ - if args.deepspeed: - if os.path.exists(f"{args.run_name}/latest"): - resume_from_checkpoint = args.run_name - print(f"Found checkpoint {resume_from_checkpoint} for run {args.run_name}.") - else: - print(f"Found no checkpoints for run {args.run_name}.") + checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt") + if len(checkpoint_list) == 0: + print(f"Found no checkpoints for run {args.run_name}.") else: - checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt") - if len(checkpoint_list) == 0: - print(f"Found no checkpoints for run {args.run_name}.") - else: - resume_from_checkpoint = sorted( - checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0]) - )[-1] - print(f"Found checkpoint {resume_from_checkpoint} for run {args.run_name}.") + resume_from_checkpoint = sorted( + checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0]) + )[-1] + print(f"Found checkpoint {resume_from_checkpoint} for run {args.run_name}.") return resume_from_checkpoint def load_checkpoint(args, model): """ - Loads a (non-Deepspeed) checkpoint into the model and returns the checkpoint + epoch to resume from. + Loads a checkpoint into the model and returns the checkpoint + epoch to resume from. Does not load the optimizer or learning rate checkpoints, but these are included in the returned checkpoint dict. 
""" if args.rank == 0: @@ -306,23 +293,6 @@ def load_checkpoint(args, model): model.load_state_dict(msd, False) return resume_from_epoch, checkpoint - -def load_deepspeed_checkpoint(args, model): - """Loads a deepspeed checkpoint and returns the epoch to resume from.""" - if args.rank == 0: - print(f"Loading checkpoint from {args.resume_from_checkpoint}") - # We will not pass in a 'tag' and instead rely on 'latest' file in the checkpoint directory - model.load_checkpoint( - load_dir=args.resume_from_checkpoint, # Note: this is the dir, not the file - load_module_strict=False, - ) - # read latest file to get epoch - latest_file = os.path.join(args.resume_from_checkpoint, "latest") - with open(latest_file, "r") as f: - checkpoint_epoch = int(f.read().split("_")[-1]) - return checkpoint_epoch + 1 - - def filter_state_dict_to_trainable(model, state_dict): """ Remove non-trainable parameters from model state dict. @@ -387,24 +357,3 @@ def save_checkpoint(model, optimizer, lr_scheduler, epoch, args): if args.delete_previous_checkpoint: if epoch > 0: os.remove(f"{args.run_name}/checkpoint_{epoch-1}.pt") - - -def save_deepspeed_checkpoint(model, epoch, args): - """ - Save training checkpoint for deepspeed. - """ - print(f"Saving checkpoint to {args.run_name}") - model.save_checkpoint( - save_dir=args.run_name, - save_latest=True, - tag=f"epoch_{epoch}", - exclude_frozen_parameters=not args.gradient_checkpointing, # Save all parameters if gradient checkpointing is enabled - ) - - if args.rank == 0: - if args.report_to_wandb and args.save_checkpoints_to_wandb: - wandb.save(f"{args.run_name}/epoch_{epoch}/mp_rank_00_model_states.pt") - - if args.delete_previous_checkpoint: - if epoch > 0: # remove checkpoint dir epoch_{epoch-1} - shutil.rmtree(f"{args.run_name}/epoch_{epoch-1}") diff --git a/requirements-training.txt b/requirements-training.txt index 8b46a831..3e4d87b9 100644 --- a/requirements-training.txt +++ b/requirements-training.txt @@ -2,5 +2,4 @@ torchvision braceexpand webdataset tqdm -wandb -deepspeed \ No newline at end of file +wandb \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 54bdd40c..4412420b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ einops einops-exts -transformers>=4.28.1 +transformers==4.28.1 torch>=2.0.1 pillow open_clip_torch>=2.16.0 diff --git a/setup.py b/setup.py index a1f62969..e4760b91 100644 --- a/setup.py +++ b/setup.py @@ -6,33 +6,14 @@ with Path(Path(__file__).parent, "README.md").open(encoding="utf-8") as file: long_description = file.read() - REQUIREMENTS = [ - "einops", - "einops-exts", - "transformers>=4.28.1", - "torch==2.0.1", - "pillow", - "open_clip_torch>=2.16.0", - "sentencepiece", - ] + with open("requirements.txt") as f: + REQUIREMENTS = f.read().splitlines() - EVAL = [ - "scipy", - "torchvision", - "nltk", - "inflection", - "pycocoevalcap", - "pycocotools", - "tqdm", - ] + with open("requirements-eval.txt") as f: + EVAL = f.read().splitlines() - TRAINING = [ - "wandb", - "torchvision", - "braceexpand", - "webdataset", - "tqdm", - ] + with open("requirements-training.txt") as f: + TRAINING = f.read().splitlines() setup( name="open_flamingo",