updates
NinoRisteski committed Nov 12, 2024
1 parent 95b7232 commit 469c271
Showing 3 changed files with 115 additions and 3 deletions.
79 changes: 77 additions & 2 deletions src/shallowflow/trainer/llm_trainer.py
@@ -1,8 +1,11 @@
import torch
import torch.distributed as dist
from transformers import PreTrainedModel, PreTrainedTokenizer
from ..utils.config import TrainingConfig
from ..optimizations import LoRALayer, Quantizer
from ..utils.memory import MemoryTracker
from typing import Union, List, Dict, Any
import boto3

class LLMTrainer:
    def __init__(
@@ -31,6 +34,11 @@ def _apply_lora(self):
            in_features = module.in_features
            out_features = module.out_features

            if not isinstance(in_features, int) or not isinstance(out_features, int):
                raise TypeError("Features must be integers")
            if in_features <= 0 or out_features <= 0:
                raise ValueError("Features must be positive")

            # Replace with LoRA layer
            lora_layer = LoRALayer(
                in_features,
@@ -39,6 +47,18 @@ def _apply_lora(self):
            )
            setattr(self.model, name, lora_layer)

    def _setup_distributed(self):
        """Initialize distributed process group"""
        if not dist.is_initialized():
            dist.init_process_group(
                backend=self.config.backend,
                world_size=self.config.world_size,
                rank=self.config.rank
            )

        torch.cuda.set_device(self.config.rank)
        dist.barrier()  # Add synchronization point

    def train(
        self,
        train_dataset,
@@ -50,10 +70,65 @@ def train(

        for epoch in range(self.config.num_epochs):
            self.model.train()
            for batch_idx, batch in enumerate(train_dataset):
                loss = self._training_step(batch)
                self._optimization_step(loss, optimizer)

                if batch_idx % 100 == 0:
                    torch.cuda.empty_cache()

            if eval_dataset:
                eval_loss = self._evaluate(eval_dataset)
                print(f"Epoch {epoch}: eval_loss = {eval_loss}")

    def process_text(
        self,
        text: Union[str, List[str]]
    ) -> Dict[str, torch.Tensor]:
        """
        Process raw text input
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer not initialized")

        # Tokenize text
        encoded = self.tokenizer(
            text,
            max_length=self.config.max_length,
            padding=self.config.padding,
            truncation=self.config.truncation,
            return_tensors=self.config.return_tensors,
            add_special_tokens=self.config.add_special_tokens
        )
        return encoded

    def cleanup(self):
        torch.cuda.empty_cache()
        if hasattr(self, 'wrapped_model'):
            del self.wrapped_model

    def launch_instance(self) -> Dict:
        launch_template = {
            'InstanceType': self.config.instance_type,
            'ImageId': self._get_deep_learning_ami(),
            'BlockDeviceMappings': [{
                'DeviceName': '/dev/xvda',
                'Ebs': {
                    'VolumeSize': self.config.volume_size,
                    'VolumeType': 'gp3'
                }
            }]
        }

        if self.config.spot_instance:
            try:
                return self._launch_spot_instance(launch_template)
            except boto3.exceptions.Boto3Error as e:
                raise RuntimeError(f"AWS operation failed: {str(e)}")
        return self._launch_on_demand_instance(launch_template)

    def process_dataset(self, dataset):
        if self.max_samples:
            dataset = dataset[:self.max_samples]

        for batch in self._create_batches(dataset):
            yield self.text_processor.process_batch(batch)
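The new LLMTrainer helpers above (process_text, cleanup, launch_instance, process_dataset) are not exercised anywhere in this commit. A minimal usage sketch follows; the constructor signature LLMTrainer(model, tokenizer, config) and the TrainingConfig fields shown are assumptions inferred from the attributes the methods reference, not part of the diff.

    # Hypothetical usage sketch -- constructor and config fields are assumed,
    # not defined by this commit.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from shallowflow.trainer.llm_trainer import LLMTrainer
    from shallowflow.utils.config import TrainingConfig

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    config = TrainingConfig(num_epochs=1, max_length=128)  # assumed fields

    trainer = LLMTrainer(model=model, tokenizer=tokenizer, config=config)

    # process_text() tokenizes raw strings into model-ready tensors
    train_batches = [trainer.process_text("shallowflow trains LLMs on small GPUs")]

    trainer.train(train_dataset=train_batches, eval_dataset=None)
    trainer.cleanup()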
2 changes: 2 additions & 0 deletions src/shallowflow/trainer/local_trainer.py
@@ -21,6 +21,8 @@ def __init__(
        # Initialize wandb
        self.init_wandb()

        self.scaler = torch.cuda.amp.GradScaler()

    def init_wandb(self):
        """Initialize wandb tracking"""
        wandb.init(
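The GradScaler added to this trainer's __init__ implies a mixed-precision training step, but that step lies outside the hunks shown here. For reference, the standard torch.cuda.amp pattern the scaler is meant for looks like the generic sketch below; it is not the repository's _training_step, and model, optimizer, and batch are placeholders.

    # Generic autocast + GradScaler step; not code from this commit.
    import torch

    def amp_training_step(model, optimizer, scaler, batch):
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast():
            outputs = model(**batch)
            loss = outputs.loss
        scaler.scale(loss).backward()  # scale the loss so fp16 gradients do not underflow
        scaler.step(optimizer)         # unscales gradients, skips the step on inf/NaN
        scaler.update()                # adjusts the scale factor for the next iteration
        return loss.detach()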
37 changes: 36 additions & 1 deletion src/shallowflow/utils/gpu_optimizations.py
@@ -3,6 +3,13 @@
from typing import Tuple, Optional
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
import functools

from torch.distributed.fsdp import FullyShardedDataParallel, BackwardPrefetch
# the wrap helpers live in the torch.distributed.fsdp.wrap submodule
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    enable_wrap,
    wrap
)

@dataclass
class GTX1660Config:
@@ -42,4 +49,32 @@ def get_memory_stats(self) -> dict:
"allocated": torch.cuda.memory_allocated() / 1024**2,
"reserved": torch.cuda.memory_reserved() / 1024**2,
"max_allocated": torch.cuda.max_memory_allocated() / 1024**2
}
}

    def _check_memory(self):
        if torch.cuda.memory_allocated() > 0.9 * self.config.memory_limit * 1e9:
            torch.cuda.empty_cache()
            raise RuntimeError("GPU memory nearly exhausted")

    def prepare_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
        """Wrap model in FSDP"""
        # Auto wrapping policy: size_based_auto_wrap_policy is a callback,
        # so bind its argument with functools.partial rather than calling it
        auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=self.config.min_num_params
        )

        # FSDP configuration
        fsdp_config = {
            "auto_wrap_policy": auto_wrap_policy,
            "mixed_precision": self._get_mixed_precision_policy(),
            "cpu_offload": self._get_cpu_offload()
        }

        if self.config.backward_prefetch:
            fsdp_config["backward_prefetch"] = BackwardPrefetch.BACKWARD_PRE

        # Wrap model with FSDP
        with enable_wrap(wrapper_cls=FullyShardedDataParallel, **fsdp_config):
            wrapped_model = wrap(model)

        return wrapped_model
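prepare_model() relies on _get_mixed_precision_policy() and _get_cpu_offload(), neither of which appears in this diff. A plausible shape for them, using the MixedPrecision and CPUOffload dataclasses from torch.distributed.fsdp, is sketched below; in the repository they would be methods on the same class, and the config flags use_fp16 and offload_params are assumed names.

    # Sketch only: written as free functions over a config object; the flags
    # use_fp16 and offload_params are assumptions, not fields from this commit.
    import torch
    from torch.distributed.fsdp import MixedPrecision, CPUOffload

    def get_mixed_precision_policy(config) -> MixedPrecision:
        dtype = torch.float16 if config.use_fp16 else torch.float32
        return MixedPrecision(
            param_dtype=dtype,    # parameter dtype used during compute
            reduce_dtype=dtype,   # dtype used for gradient reduction
            buffer_dtype=dtype,   # dtype used for buffers (e.g. norm statistics)
        )

    def get_cpu_offload(config) -> CPUOffload:
        # Offload sharded parameters to CPU when offload_params is enabled
        return CPUOffload(offload_params=config.offload_params)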
