Custom modeling for training #801

Open · wants to merge 3 commits into main

6 changes: 3 additions & 3 deletions docs/source/guides/distributed_training.mdx
@@ -183,11 +183,11 @@ Just as for ZeRO-1, it is possible to wrap the optimizer class to make it lazy.
 ```python
 from torch.optim import AdamW
 from optimum.neuron import NeuronAccelerator
-from optimum.neuron.accelerate.utils import ModelParallelismPlugin
+from optimum.neuron.accelerate.utils import ModelParallelismConfig
 from optimum.neuron.distributed import lazy_load_for_parallelism
 
 tensor_parallel_size = 8
-mp_plugin = ModelParallelismPlugin(
+mp_config = ModelParallelismConfig(
     tensor_parallel_size,
     parallelize_embeddings=True,
     sequence_parallel_enabled=True,
@@ -196,7 +196,7 @@ mp_plugin = ModelParallelismPlugin(
 
 accelerator = NeuronAccelerator(
     ...
-    mp_plugin=mp_plugin,
+    mp_config=mp_config,
 )
 
 with lazy_load_for_parallelism(tensor_parallel_size=tensor_parallel_size):
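
For context, a minimal end-to-end sketch of the renamed API described by this docs change; the model checkpoint, learning rate, and the `accelerator.prepare(...)` call are illustrative placeholders rather than part of the diff:

```python
from torch.optim import AdamW
from transformers import AutoModelForCausalLM

from optimum.neuron import NeuronAccelerator
from optimum.neuron.accelerate.utils import ModelParallelismConfig
from optimum.neuron.distributed import lazy_load_for_parallelism

tensor_parallel_size = 8
mp_config = ModelParallelismConfig(
    tensor_parallel_size,
    parallelize_embeddings=True,
    sequence_parallel_enabled=True,
)

accelerator = NeuronAccelerator(mp_config=mp_config)

# Load lazily so the full model is not materialized before it is parallelized.
with lazy_load_for_parallelism(tensor_parallel_size=tensor_parallel_size):
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # placeholder checkpoint

optimizer = AdamW(model.parameters(), lr=5e-5)
model, optimizer = accelerator.prepare(model, optimizer)
```
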
4 changes: 2 additions & 2 deletions optimum/neuron/__init__.py
@@ -68,14 +68,14 @@
         "NeuronAccelerator",
         "NeuronAcceleratorState",
         "NeuronPartialState",
-        "ModelParallelismPlugin",
+        "ModelParallelismConfig",
     ],
     "pipelines": ["pipeline"],
     "utils": ["NeuronSFTConfig", "NeuronORPOConfig", "get_peft_model"],
 }
 
 if TYPE_CHECKING:
-    from .accelerate import ModelParallelismPlugin, NeuronAccelerator, NeuronAcceleratorState, NeuronPartialState
+    from .accelerate import ModelParallelismConfig, NeuronAccelerator, NeuronAcceleratorState, NeuronPartialState
     from .hf_argparser import NeuronHfArgumentParser
     from .modeling import (
         NeuronModelForAudioClassification,
2 changes: 1 addition & 1 deletion optimum/neuron/accelerate/__init__.py
@@ -15,4 +15,4 @@
 
 from .accelerator import NeuronAccelerator
 from .state import NeuronAcceleratorState, NeuronPartialState
-from .utils.dataclasses import ModelParallelismPlugin, NeuronDistributedType
+from .utils.dataclasses import ModelParallelismConfig, NeuronDistributedType
16 changes: 8 additions & 8 deletions optimum/neuron/accelerate/accelerator.py
@@ -56,7 +56,7 @@
 from .state import NeuronAcceleratorState
 from .utils import (
     AutocastBackend,
-    ModelParallelismPlugin,
+    ModelParallelismConfig,
     NeuronDistributedType,
     patch_accelerate_is_torch_xla_available,
 )
@@ -98,7 +98,7 @@ class NeuronAccelerator(Accelerator):
     def __init__(
         self,
         *args,
-        mp_plugin: Optional[ModelParallelismPlugin] = None,
+        mp_config: Optional[ModelParallelismConfig] = None,
         zero_1: bool = False,
         autocast_backend: Union[str, AutocastBackend] = "xla",
         **kwargs,
@@ -146,7 +146,7 @@ def patched_is_torch_xla_available(check_is_tpu: bool = False, check_is_gpu: boo
         accelerate.state.is_torch_xla_available = patched_is_torch_xla_available
 
         patched_accelerator_state = partial(
-            NeuronAcceleratorState, mp_plugin=mp_plugin, autocast_backend=autocast_backend
+            NeuronAcceleratorState, mp_config=mp_config, autocast_backend=autocast_backend
         )
         with Patcher([("accelerate.accelerator.AcceleratorState", patched_accelerator_state)]):
             super().__init__(**full_kwargs)
@@ -225,7 +225,7 @@ def prepare_data_loader(
                 data_loader, num_replicas=num_replicas, rank=rank, force_drop_last=force_drop_last
             )
         # No need to wrap the dataloader if we are using pipeline parallelism.
-        if use_mp_device_loader and self.state.mp_plugin.pipeline_parallel_size == 1:
+        if use_mp_device_loader and self.state.mp_config.pipeline_parallel_size == 1:
             data_loader = MpDeviceLoader(data_loader, self.device)
         return data_loader
 
@@ -384,7 +384,7 @@ def _prepare_model_for_mp(
 
         tied_parameters_dict = get_tied_parameters_dict(model)
         model_main_input_name = getattr(model, "main_input_name", None)
-        model = self.state.mp_plugin.parallelize_model(model, device=self.device)
+        model = self.state.mp_config.parallelize_model(model, device=self.device)
 
         if model_main_input_name is not None:
             setattr(model, "main_input_name", model_main_input_name)
@@ -628,9 +628,9 @@ def save_optimizer_func(accelerator, optimizer, model, output_dir, i):
                 model,
                 output_dir,
                 optimizer=optimizer,
-                use_xser=self.state.mp_plugin.use_xser,
-                async_save=self.state.mp_plugin.async_save,
-                num_local_ranks_per_step=self.state.mp_plugin.num_local_ranks_per_step,
+                use_xser=self.state.mp_config.use_xser,
+                async_save=self.state.mp_config.async_save,
+                num_local_ranks_per_step=self.state.mp_config.num_local_ranks_per_step,
             )
             logger.info(f"Parallel model and optimizer saved to the directory {output_dir}")
 
16 changes: 8 additions & 8 deletions optimum/neuron/accelerate/state.py
@@ -38,7 +38,7 @@
     set_neuron_cc_flags_for_torch_amp,
 )
 from .utils import NeuronDistributedType
-from .utils.dataclasses import AutocastBackend, ModelParallelismPlugin
+from .utils.dataclasses import AutocastBackend, ModelParallelismConfig
 
 
 if is_torch_xla_available():
@@ -127,7 +127,7 @@ def __init__(
         deepspeed_plugin=None,
         fsdp_plugin=None,
         megatron_lm_plugin=None,
-        mp_plugin: Optional[ModelParallelismPlugin] = None,
+        mp_config: Optional[ModelParallelismConfig] = None,
         autocast_backend: Optional[Union[str, AutocastBackend]] = None,
         _from_accelerator: bool = False,
         **kwargs,
@@ -183,18 +183,18 @@ def __init__(
             if mixed_precision == "bf16":
                 os.environ["NEURON_RT_STOCHASTIC_ROUNDING_EN"] = "1"
 
-            if mp_plugin is None:
-                mp_plugin = ModelParallelismPlugin()
+            if mp_config is None:
+                mp_config = ModelParallelismConfig()
 
-            if mp_plugin.should_parallelize:
+            if mp_config.should_parallelize:
                 self.distributed_type = NeuronDistributedType.MODEL_PARALLELISM
 
-            self.mp_plugin = mp_plugin
+            self.mp_config = mp_config
 
             if torch.distributed.is_initialized() and not parallel_state.model_parallel_is_initialized():
                 parallel_state.initialize_model_parallel(
-                    tensor_model_parallel_size=self.mp_plugin.tensor_parallel_size,
-                    pipeline_model_parallel_size=self.mp_plugin.pipeline_parallel_size,
+                    tensor_model_parallel_size=self.mp_config.tensor_parallel_size,
+                    pipeline_model_parallel_size=self.mp_config.pipeline_parallel_size,
                 )
 
             if self.distributed_type is DistributedType.NO:
2 changes: 1 addition & 1 deletion optimum/neuron/accelerate/utils/__init__.py
@@ -15,7 +15,7 @@
 
 from .dataclasses import (
     AutocastBackend,
-    ModelParallelismPlugin,
+    ModelParallelismConfig,
     NeuronDistributedType,
 )
 from .misc import patch_accelerate_is_torch_xla_available
27 changes: 26 additions & 1 deletion optimum/neuron/accelerate/utils/dataclasses.py
@@ -22,6 +22,11 @@
 import torch
 
 from ...distributed import ParallelizersManager
+from ...utils.torch_xla_and_neuronx_initialization import init_process_group
+from ...utils import is_neuronx_distributed_available
+
+if is_neuronx_distributed_available():
+    from neuronx_distributed.parallel_layers import parallel_state
 
 
 if TYPE_CHECKING:
@@ -49,7 +54,7 @@ class AutocastBackend(str, enum.Enum):
 
 
 @dataclass
-class ModelParallelismPlugin:
+class ModelParallelismConfig:
     tensor_parallel_size: int = 1
     parallelize_embeddings: bool = True
     sequence_parallel_enabled: bool = False
@@ -62,6 +67,8 @@ class ModelParallelismPlugin:
     num_local_ranks_per_step: int = 8
     use_xser: bool = True
     async_save: bool = False
+    fuse_qkv: bool = False
+    use_flash_attention: bool = True
 
     def __post_init__(self):
         if self.tensor_parallel_size < 1:
@@ -73,6 +80,24 @@ def __post_init__(self):
         if isinstance(self.checkpoint_dir, str):
             self.checkpoint_dir = Path(self.checkpoint_dir)
 
+        if not torch.distributed.is_initialized():
+            init_process_group()
+
+        if not parallel_state.model_parallel_is_initialized():
+            parallel_state.initialize_model_parallel(
+                tensor_model_parallel_size=self.tensor_parallel_size,
+                pipeline_model_parallel_size=self.pipeline_parallel_size,
+            )
+
+    def auto_kv_size_multiplier(self, num_key_value_heads: int) -> int:
+        kv_size_multiplier = self.tensor_parallel_size // num_key_value_heads
+        if self.kv_size_multiplier is not None and self.kv_size_multiplier != kv_size_multiplier:
+            raise ValueError(
+                "A kv size multiplier was already specified and is different from the inferred one: "
+                f"{self.kv_size_multiplier}"
+            )
+        return kv_size_multiplier
+
     @property
     def should_parallelize(self):
         return self.tensor_parallel_size > 1 or self.pipeline_parallel_size > 1
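
As a quick illustration of the new `auto_kv_size_multiplier` helper and the added `fuse_qkv` / `use_flash_attention` flags: because `__post_init__` now initializes the process group and model-parallel state, the sketch below assumes it runs inside a launched Neuron training job (e.g. via `torchrun`), and the GQA head count is made up:

```python
from optimum.neuron.accelerate.utils import ModelParallelismConfig

# Assumes torch.distributed / neuronx_distributed can be initialized in this process.
mp_config = ModelParallelismConfig(
    tensor_parallel_size=8,
    fuse_qkv=False,
    use_flash_attention=True,
)

# With 4 KV heads and tensor_parallel_size=8, each KV head must be replicated
# 8 // 4 = 2 times so that every tensor-parallel rank gets a KV shard.
assert mp_config.auto_kv_size_multiplier(num_key_value_heads=4) == 2
```
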
6 changes: 6 additions & 0 deletions optimum/neuron/distributed/base.py
@@ -18,6 +18,7 @@
 import gc
 import math
 from abc import ABC, abstractclassmethod
+from dataclasses import dataclass
 from collections import defaultdict
 from dataclasses import asdict, replace
 from pathlib import Path
@@ -559,6 +560,10 @@ def parallelize(
         orig_model, peft_prefix = get_base_model_and_peft_prefix(model)
         model_class = orig_model.__class__
 
+        import inspect
+        if inspect.getmodule(orig_model.__class__).__name__.startswith("optimum.neuron.models.training"):
+            return orig_model
+
         if peft_prefix:
             # We update the weight_map to contain both the original parameter names, and the ones in the PeftModel.
             # The reason we keep both is because depending on the context during parallelization one or the other name
@@ -1047,3 +1052,4 @@ def load_model_sharded_checkpoint(cls, model: "PreTrainedModel", load_dir: Union
     @classmethod
     def load_optimizer_sharded_checkpoint(cls, optimizer: "torch.optim.Optimizer", load_dir: Union[str, Path]):
         return cls.load_sharded_checkpoint(load_dir, optimizer=optimizer)
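
The early return added to `parallelize` means that modeling code living under `optimum.neuron.models.training` is treated as already parallelism-aware and is returned unchanged. A standalone sketch of that check, factored into a hypothetical helper for clarity:

```python
import inspect


def is_custom_training_model(model) -> bool:
    # Custom training classes ship their own parallel layout under
    # optimum.neuron.models.training, so the parallelizer leaves them untouched.
    module = inspect.getmodule(model.__class__)
    return module is not None and module.__name__.startswith("optimum.neuron.models.training")
```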

52 changes: 52 additions & 0 deletions optimum/neuron/models/loss_utils.py
@@ -0,0 +1,52 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, MSELoss

from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_size

from ..distributed.utils import parallel_cross_entropy


def fixed_cross_entropy(source, target, num_items_in_batch: Optional[int] = None, ignore_index: int = -100, **kwargs):
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss_function = parallel_cross_entropy if get_tensor_model_parallel_size() > 1 else nn.functional.cross_entropy
    loss = loss_function(source, target, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss


def ForCausalLMLoss(
    logits, labels, vocab_size: int, num_items_in_batch: Optional[int] = None, ignore_index: int = -100, **kwargs
):
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()
    labels = labels.to(logits.device)
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    loss = fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
    return loss
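
A small usage sketch of the new loss helper. The shapes are made up, and since `fixed_cross_entropy` queries the tensor-parallel group size, this assumes an initialized Neuron training job; with a tensor-parallel size of 1 it simply falls back to `nn.functional.cross_entropy`:

```python
import torch

from optimum.neuron.models.loss_utils import ForCausalLMLoss

batch_size, seq_len, vocab_size = 2, 16, 32000  # illustrative shapes
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

# Passing num_items_in_batch switches the reduction to "sum" followed by a manual
# division, which keeps the loss scale consistent under gradient accumulation.
num_items = (labels[..., 1:] != -100).sum().item()
loss = ForCausalLMLoss(logits, labels, vocab_size=vocab_size, num_items_in_batch=num_items)
print(loss)
```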