
Commit

[WIP]

michaelbenayoun committed Feb 19, 2025
1 parent 0639f91 commit e1b2591
Showing 5 changed files with 330 additions and 61 deletions.
1 change: 1 addition & 0 deletions optimum/neuron/accelerate/optimizer.py
@@ -109,6 +109,7 @@ def step(self, closure=None):
                 self.accelerator_state.distributed_type is DistributedType.XLA
                 or self.accelerator_state.distributed_type is NeuronDistributedType.MODEL_PARALLELISM
             ):
+
                 if parallel_layers.parallel_state.get_data_parallel_size() > 1:
                     bucket_allreduce_gradients(xm._fetch_gradients(self.optimizer))
                 if self.clip_grad_norm_to_perform is not None:
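
For context, the block guarded above all-reduces gradients across data-parallel ranks before any gradient clipping and the optimizer step. Below is a minimal, hypothetical sketch of that pattern, not the repository's implementation: plain `xm.all_reduce` stands in for `bucket_allreduce_gradients`, and the `as_list=True` keyword on `get_data_parallel_group` is an assumption.

```python
from typing import Optional

import torch
import torch_xla.core.xla_model as xm
from neuronx_distributed.parallel_layers import parallel_state


def step_with_data_parallel_allreduce(optimizer: torch.optim.Optimizer, max_grad_norm: Optional[float] = None):
    # Only all-reduce when there is more than one data-parallel rank.
    if parallel_state.get_data_parallel_size() > 1:
        gradients = xm._fetch_gradients(optimizer)  # same private torch_xla helper used in the diff above
        xm.all_reduce(
            xm.REDUCE_SUM,
            gradients,
            scale=1.0 / parallel_state.get_data_parallel_size(),
            groups=parallel_state.get_data_parallel_group(as_list=True),  # assumed keyword
        )
    # Clip only after gradients have been synchronized across data-parallel ranks, then step.
    if max_grad_norm is not None:
        parameters = [p for group in optimizer.param_groups for p in group["params"]]
        torch.nn.utils.clip_grad_norm_(parameters, max_grad_norm)
    optimizer.step()
```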
47 changes: 4 additions & 43 deletions optimum/neuron/distributed/parallel_layers.py
@@ -33,6 +33,7 @@
     FakeProj,
     OptimumGQAQKVColumnParallelLinear,
     WeightInformation,
+    parallel_cross_entropy,
     embedding_to_parallel_embedding,
     get_linear_weight_info,
     linear_to_parallel_linear,
@@ -49,11 +50,6 @@
 logger = logging.get_logger()
 
 
-# Just for testing purposes, setting that to True will feed a copy of the input to `parallel_cross_entropy` which
-# changes inputs inplace. This way the original input is not transformed and can be used in tests for comparison.
-_PARALLEL_CROSS_ENTROPY_SHOULD_PRESERVE_INPUT: bool = False
-
-
 class ParallelLayer(ABC):
     PARALLEL_LAYER_SPECIFIC_KWARGS: Optional[Dict[str, Any]] = None
 
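The deleted comment and flag above exist because `neuronx_distributed`'s `parallel_cross_entropy` transforms the logits it is given in place and returns unreduced per-token losses (the deleted wrapper further below applies `mean`/`sum` itself). A minimal sketch, not the repository's test code, of how a caller can keep the original logits around for a reference comparison:

```python
import torch
from neuronx_distributed.parallel_layers.loss_functions import parallel_cross_entropy


def parallel_loss_preserving_input(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Pass a copy so the caller's `logits` stay untouched: `parallel_cross_entropy`
    # modifies its input in place (this is what the deleted flag toggled in tests).
    logits_copy = logits.clone()
    per_token_loss = parallel_cross_entropy(logits_copy, target)
    # The parallel loss comes back unreduced, so reduce it explicitly.
    return per_token_loss.mean()
```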
@@ -871,40 +867,6 @@ def _transform(
         return layer
 
 
-@requires_neuronx_distributed
-@torch.fx.wrap
-def safe_parallel_cross_entropy(*args, **kwargs):
-    if kwargs.pop("weight", None) is not None:
-        raise ValueError("The weight keyword argument is not supported when using parallel cross entropy")
-    if kwargs.pop("size_average", None) is not None:
-        raise ValueError("The size_average keyword argument is not supported when using parallel cross entropy")
-    if kwargs.pop("ignore_index", -100) != -100:
-        raise ValueError("The ignore_index keyword argument is not supported when using parallel cross entropy")
-    if kwargs.pop("reduce", None) is not None:
-        raise ValueError("The reduce keyword argument is not supported when using parallel cross entropy")
-    reduction = kwargs.pop("reduction", "mean")
-    if reduction not in ["mean", "sum", "none"]:
-        raise ValueError(
-            f'The reduction parameter only accepts 3 values: "mean", "sum" and "none", but {reduction} was provided '
-            "here."
-        )
-
-    from neuronx_distributed.parallel_layers.loss_functions import parallel_cross_entropy
-
-    input_ = args[0]
-    if _PARALLEL_CROSS_ENTROPY_SHOULD_PRESERVE_INPUT:
-        input_ = input_.clone()
-
-    loss = parallel_cross_entropy(input_, *args[1:], **kwargs)
-
-    if reduction == "mean":
-        loss = loss.mean()
-    elif reduction == "sum":
-        loss = loss.sum()
-
-    return loss
-
-
 class ParallelCrossEntropyLoss(_WeightedLoss):
     """
     Same as `torch.nn.CrossEntropyLoss` except that it uses
@@ -932,10 +894,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         # Original way of computing the cross-entropy in `torch_neuronx`:
         # from torch_neuronx.xla_impl.ops import SimpleCrossEntropyLoss
         # output = SimpleCrossEntropyLoss.gen_override().forward(self, input, target)
-        output = safe_parallel_cross_entropy(
+        output = parallel_cross_entropy(
             input,
             target,
-            weight=self.weight,
             ignore_index=self.ignore_index,
             reduction=self.reduction,
             label_smoothing=self.label_smoothing,
@@ -961,8 +922,8 @@ def patch_cross_entropy(cls, model: "PreTrainedModel"):
         orig_forward = model.forward
         patcher = patch_within_function(
             [
-                ("torch.nn.functional.cross_entropy", safe_parallel_cross_entropy),
-                ("torch.nn.modules.loss.F.cross_entropy", safe_parallel_cross_entropy),
+                ("torch.nn.functional.cross_entropy", parallel_cross_entropy),
+                ("torch.nn.modules.loss.F.cross_entropy", parallel_cross_entropy),
             ]
         )
         model.forward = patcher(orig_forward)
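
`patch_within_function`, re-targeted above to the new `parallel_cross_entropy` import, swaps the cross-entropy entry points while the wrapped `model.forward` runs. A rough, hypothetical equivalent written with `unittest.mock`, only to illustrate the effect; the function name and the calling convention assumed for the replacement are not the repository's API.

```python
from typing import Any, Callable
from unittest import mock

import torch


def forward_with_patched_cross_entropy(model: torch.nn.Module, replacement: Callable, *args: Any, **kwargs: Any):
    # `replacement` stands for the `parallel_cross_entropy` imported at the top of the file;
    # it is assumed to accept the `(input, target, **kwargs)` convention of `F.cross_entropy`.
    with mock.patch("torch.nn.functional.cross_entropy", new=replacement), \
         mock.patch("torch.nn.modules.loss.F.cross_entropy", new=replacement):
        # Any cross-entropy call made while the loss is computed inside `model.forward`
        # (including one issued through `torch.nn.CrossEntropyLoss`) now routes to `replacement`.
        return model(*args, **kwargs)
```

Patching both dotted paths mirrors the diff: models that build their loss via `torch.nn.CrossEntropyLoss` resolve `F.cross_entropy` through `torch.nn.modules.loss`, so both lookups have to land on the parallel implementation.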
