diff --git a/src/shallowflow/trainer/llm_trainer.py b/src/shallowflow/trainer/llm_trainer.py
index 8c90c59..1f90a24 100644
--- a/src/shallowflow/trainer/llm_trainer.py
+++ b/src/shallowflow/trainer/llm_trainer.py
@@ -1,8 +1,11 @@
 import torch
+import torch.distributed as dist
 from transformers import PreTrainedModel, PreTrainedTokenizer
 from ..utils.config import TrainingConfig
 from ..optimizations import LoRALayer, Quantizer
 from ..utils.memory import MemoryTracker
+from typing import Union, List, Dict, Any
+import boto3
 
 class LLMTrainer:
     def __init__(
@@ -31,6 +34,11 @@ def _apply_lora(self):
             in_features = module.in_features
             out_features = module.out_features
 
+            if not isinstance(in_features, int) or not isinstance(out_features, int):
+                raise TypeError("Features must be integers")
+            if in_features <= 0 or out_features <= 0:
+                raise ValueError("Features must be positive")
+
             # Replace with LoRA layer
             lora_layer = LoRALayer(
                 in_features,
@@ -39,6 +47,18 @@
             )
             setattr(self.model, name, lora_layer)
 
+    def _setup_distributed(self):
+        """Initialize distributed process group"""
+        if not dist.is_initialized():
+            dist.init_process_group(
+                backend=self.config.backend,
+                world_size=self.config.world_size,
+                rank=self.config.rank
+            )
+
+        torch.cuda.set_device(self.config.rank)
+        dist.barrier()  # Synchronize all ranks before training starts
+
     def train(
         self,
         train_dataset,
@@ -50,10 +70,66 @@
         for epoch in range(self.config.num_epochs):
             self.model.train()
 
-            for batch in train_dataset:
+            for batch_idx, batch in enumerate(train_dataset):
                 loss = self._training_step(batch)
                 self._optimization_step(loss, optimizer)
 
+                if batch_idx % 100 == 0:
+                    torch.cuda.empty_cache()
+
             if eval_dataset:
                 eval_loss = self._evaluate(eval_dataset)
-                print(f"Epoch {epoch}: eval_loss = {eval_loss}")
\ No newline at end of file
+                print(f"Epoch {epoch}: eval_loss = {eval_loss}")
+
+    def process_text(
+        self,
+        text: Union[str, List[str]]
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Process raw text input
+        """
+        if self.tokenizer is None:
+            raise ValueError("Tokenizer not initialized")
+
+        # Tokenize text
+        encoded = self.tokenizer(
+            text,
+            max_length=self.config.max_length,
+            padding=self.config.padding,
+            truncation=self.config.truncation,
+            return_tensors=self.config.return_tensors,
+            add_special_tokens=self.config.add_special_tokens
+        )
+        return encoded
+
+    def cleanup(self):
+        torch.cuda.empty_cache()
+        if hasattr(self, 'wrapped_model'):
+            del self.wrapped_model
+
+    def launch_instance(self) -> Dict:
+        launch_template = {
+            'InstanceType': self.config.instance_type,
+            'ImageId': self._get_deep_learning_ami(),
+            'BlockDeviceMappings': [{
+                'DeviceName': '/dev/xvda',
+                'Ebs': {
+                    'VolumeSize': self.config.volume_size,
+                    'VolumeType': 'gp3'
+                }
+            }]
+        }
+
+        if self.config.spot_instance:
+            try:
+                return self._launch_spot_instance(launch_template)
+            except boto3.exceptions.Boto3Error as e:
+                raise RuntimeError(f"AWS operation failed: {str(e)}") from e
+        return self._launch_on_demand_instance(launch_template)
+
+    def process_dataset(self, dataset):
+        if self.max_samples:
+            dataset = dataset[:self.max_samples]
+
+        for batch in self._create_batches(dataset):
+            yield self.text_processor.process_batch(batch)
\ No newline at end of file
diff --git a/src/shallowflow/trainer/local_trainer.py b/src/shallowflow/trainer/local_trainer.py
index d38185d..fd0c6c1 100644
--- a/src/shallowflow/trainer/local_trainer.py
+++ b/src/shallowflow/trainer/local_trainer.py
@@ -21,6 +21,8 @@ def __init__(
         # Initialize wandb
         self.init_wandb()
 
+        self.scaler = torch.cuda.amp.GradScaler()
+
     def init_wandb(self):
         """Initialize wandb tracking"""
         wandb.init(
diff --git a/src/shallowflow/utils/gpu_optimizations.py b/src/shallowflow/utils/gpu_optimizations.py
index ab6c0c4..91e1c02 100644
--- a/src/shallowflow/utils/gpu_optimizations.py
+++ b/src/shallowflow/utils/gpu_optimizations.py
@@ -3,6 +3,13 @@
 from typing import Tuple, Optional
 from torch.cuda.amp import autocast, GradScaler
 from torch.utils.data import DataLoader
+import functools
+from torch.distributed.fsdp import FullyShardedDataParallel, BackwardPrefetch
+from torch.distributed.fsdp.wrap import (
+    size_based_auto_wrap_policy,
+    enable_wrap,
+    wrap
+)
 
 @dataclass
 class GTX1660Config:
@@ -42,4 +49,32 @@ def get_memory_stats(self) -> dict:
             "allocated": torch.cuda.memory_allocated() / 1024**2,
             "reserved": torch.cuda.memory_reserved() / 1024**2,
             "max_allocated": torch.cuda.max_memory_allocated() / 1024**2
-        }
\ No newline at end of file
+        }
+
+    def _check_memory(self):
+        if torch.cuda.memory_allocated() > 0.9 * self.config.memory_limit * 1e9:
+            torch.cuda.empty_cache()
+            raise RuntimeError("GPU memory nearly exhausted")
+
+    def prepare_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
+        """Wrap model in FSDP"""
+        # Auto wrapping policy: partial binds the size threshold; FSDP supplies the rest
+        auto_wrap_policy = functools.partial(
+            size_based_auto_wrap_policy, min_num_params=self.config.min_num_params
+        )
+
+        # FSDP configuration
+        fsdp_config = {
+            "auto_wrap_policy": auto_wrap_policy,
+            "mixed_precision": self._get_mixed_precision_policy(),
+            "cpu_offload": self._get_cpu_offload()
+        }
+
+        if self.config.backward_prefetch:
+            fsdp_config["backward_prefetch"] = BackwardPrefetch.BACKWARD_PRE
+
+        # Wrap model with FSDP
+        with enable_wrap(wrapper_cls=FullyShardedDataParallel, **fsdp_config):
+            wrapped_model = wrap(model)
+
+        return wrapped_model
\ No newline at end of file
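
For reviewers who want to sanity-check the new FSDP path in gpu_optimizations.py, here is a minimal standalone sketch of the same auto-wrap pattern written against plain PyTorch rather than ShallowFlow's config objects. It is illustrative only: the toy model, the parameter threshold, the file name in the launch comment, and the torchrun invocation are assumptions and not part of this change.

import functools
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy


def wrap_with_fsdp(model: torch.nn.Module, min_num_params: int = 100_000):
    # FSDP invokes the policy as policy(module, recurse, nonwrapped_numel),
    # so the size threshold is bound with functools.partial rather than called here.
    policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params)
    return FullyShardedDataParallel(
        model,
        auto_wrap_policy=policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    )


if __name__ == "__main__":
    # Launch with: torchrun --nproc_per_node=<num_gpus> fsdp_sketch.py
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Toy model for illustration; submodules above the threshold get their own FSDP unit
    model = torch.nn.Sequential(
        torch.nn.Linear(512, 512), torch.nn.ReLU(), torch.nn.Linear(512, 8)
    ).cuda()
    fsdp_model = wrap_with_fsdp(model)
    if dist.get_rank() == 0:
        print(fsdp_model)
    dist.destroy_process_group()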