Commit

[AutoBump] Merge with c3bd850

mgehre-amd committed Aug 21, 2024
2 parents da6dca5 + c3bd850 commit 327bb5f
Showing 13 changed files with 711 additions and 442 deletions.
2 changes: 1 addition & 1 deletion externals/stablehlo
Submodule stablehlo updated 171 files
416 changes: 217 additions & 199 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp

Large diffs are not rendered by default.

166 changes: 104 additions & 62 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -851,9 +851,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.tensorResultType(resultType))
return failure();

- Torch::ValueTensorType inputType =
-     operand.getType().cast<Torch::ValueTensorType>();
-
Value vAlpha = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), alpha));
@@ -862,31 +859,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), gamma));

- Value cstOne = rewriter.create<Torch::ConstantFloatOp>(
+ Value vInputScale = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), 1.0));

- Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
- Value zeroTensor = rewriter.create<Torch::AtenZerosLikeOp>(
-     binder.getLoc(), resultType, operand, cstNone, cstNone, cstNone,
-     cstNone, cstNone);
- Value exp = rewriter.create<Torch::AtenExpOp>(binder.getLoc(),
-     resultType, operand);
- Value expMulAlpha = rewriter.create<Torch::AtenMulScalarOp>(
-     binder.getLoc(), resultType, exp, vAlpha);
- Value expMulAlphaSubAlpha = rewriter.create<Torch::AtenSubScalarOp>(
-     binder.getLoc(), resultType, expMulAlpha, vAlpha, cstOne);
- Value neg = rewriter.create<Torch::AtenMulScalarOp>(
-     binder.getLoc(), resultType, expMulAlphaSubAlpha, vScale);
- Value pos = rewriter.create<Torch::AtenMulScalarOp>(
-     binder.getLoc(), resultType, operand, vScale);
- Type compareType = inputType.getWithSizesAndDtype(
-     inputType.getOptionalSizes(), rewriter.getI1Type());
- Value xLessThanZero = rewriter.create<Torch::AtenLtTensorOp>(
-     binder.getLoc(), compareType, operand, zeroTensor);
-
- rewriter.replaceOpWithNewOp<Torch::AtenWhereSelfOp>(
-     binder.op, resultType, xLessThanZero, neg, pos);
+ rewriter.replaceOpWithNewOp<Torch::AtenEluOp>(
+     binder.op, resultType, operand, vAlpha, vScale, vInputScale);
return success();
});
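The change above replaces the nine-op SELU decomposition with a single aten.elu, relying on the generalized form scale * elu(input_scale * x, alpha). A minimal PyTorch sketch of that equivalence (the constants are the ONNX Selu defaults; the check itself is ours, not part of the patch):

import torch

# ONNX Selu default attribute values.
alpha, gamma = 1.6732632423543772, 1.0507009873554805

x = torch.randn(1024)

# The deleted decomposition, spelled out elementwise:
# x < 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x
manual = torch.where(x < 0, gamma * (alpha * torch.exp(x) - alpha), gamma * x)

# The new single-op lowering: aten.elu(self, alpha, scale, input_scale).
fused = torch.ops.aten.elu(x, alpha, gamma, 1.0)

assert torch.allclose(manual, fused, atol=1e-6)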
patterns.onOp("ReduceL1", 1,
@@ -962,6 +940,32 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
/*memory_format=*/noneVal);
return success();
});
patterns.onOp("ReduceLogSum", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value data;
int64_t keepDims, noop_with_empty_axes;
if (binder.tensorOperandAtIndex(data, 0) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
binder.s64IntegerAttr(noop_with_empty_axes,
"noop_with_empty_axes", 0))
return failure();

auto reducedSumBool =
reducedSumImpl(binder, rewriter, data, resultType,
/*storeValue=*/data, keepDims,
noop_with_empty_axes, true);

if (failed(reducedSumBool))
return rewriter.notifyMatchFailure(
binder.op,
"Failed to perform sum operation on square of operand");

rewriter.replaceOpWithNewOp<Torch::AtenLogOp>(
binder.op, resultType, data);
return success();
});
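The new ReduceLogSum pattern reuses reducedSumImpl and then takes the log, implementing the ONNX identity ReduceLogSum(data, axes) = Log(ReduceSum(data, axes)). A small eager-mode sketch of what the lowering computes (shape and axis chosen arbitrarily):

import torch

data = torch.rand(2, 3, 4) + 0.1  # positive inputs keep log well-defined
# keepdims=1 is the ONNX default that reducedSumImpl honors above.
out = torch.log(torch.sum(data, dim=(1,), keepdim=True))
print(out.shape)  # torch.Size([2, 1, 4])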
patterns.onOp("ReduceSum", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
@@ -978,7 +982,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
/*storeValue=*/data, keepDims,
noop_with_empty_axes, false);
});
patterns.onOp("ReduceLogSum", 1,
patterns.onOp("ReduceSumSquare", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value data;
@@ -990,19 +994,13 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
"noop_with_empty_axes", 0))
return failure();

- auto reducedSumBool =
-     reducedSumImpl(binder, rewriter, data, resultType,
-                    /*storeValue=*/data, keepDims,
-                    noop_with_empty_axes, true);
-
- if (failed(reducedSumBool))
-   return rewriter.notifyMatchFailure(
-       binder.op,
-       "Failed to perform sum operation on square of operand");
+ Value dataSquare = rewriter.create<Torch::AtenMulTensorOp>(
+     binder.getLoc(), data.getType(), data, data);

- rewriter.replaceOpWithNewOp<Torch::AtenLogOp>(
-     binder.op, resultType, data);
- return success();
+ return reducedSumImpl(binder, rewriter, dataSquare,
+                       resultType,
+                       /*storeValue=*/data, keepDims,
+                       noop_with_empty_axes, false);
});
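ReduceSumSquare follows the same path, but squares the operand first (the AtenMulTensorOp above) and skips the log: ReduceSumSquare(data, axes) = ReduceSum(data * data, axes). A quick numeric sketch of the identity:

import torch

data = torch.randn(2, 3, 4)
via_mul = torch.sum(data * data, dim=(0, 2))  # what the lowering emits
assert torch.allclose(via_mul, (data ** 2).sum(dim=(0, 2)))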
patterns.onOp(
"ReduceMean", 1,
@@ -1441,31 +1439,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
return success();
});

- patterns.onOp(
-     "Sinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-       Torch::ValueTensorType resultType;
-       Value operand;
-       if (binder.tensorOperand(operand) ||
-           binder.tensorResultType(resultType))
-         return failure();
+ patterns.onOp("Sinh", 9,
+               [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                 Torch::ValueTensorType resultType;
+                 Value operand;
+                 if (binder.tensorOperand(operand) ||
+                     binder.tensorResultType(resultType))
+                   return failure();

- // 1/2 * (exp(x) - exp(-x))
- Value x = rewriter.create<Torch::AtenExpOp>(binder.getLoc(), resultType,
-                                             operand);
- Value neg = rewriter.create<Torch::AtenNegOp>(binder.getLoc(),
-                                               resultType, operand);
- Value y =
-     rewriter.create<Torch::AtenExpOp>(binder.getLoc(), resultType, neg);
- Value cstOne = rewriter.create<Torch::ConstantIntOp>(
-     binder.getLoc(), rewriter.getI64IntegerAttr(1));
- Value z = rewriter.create<Torch::AtenSubTensorOp>(
-     binder.getLoc(), resultType, x, y, cstOne);
- Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
-     binder.getLoc(), rewriter.getI64IntegerAttr(2));
- rewriter.replaceOpWithNewOp<Torch::AtenDivScalarOp>(
-     binder.op, resultType, z, cstTwo);
- return success();
- });
+ rewriter.replaceOpWithNewOp<Torch::AtenSinhOp>(
+     binder.op, resultType, operand);
+ return success();
+ });
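The deleted five-op decomposition and the native sinh op agree up to floating-point rounding; a quick check (the tolerance is ours):

import torch

x = torch.randn(64)
decomposed = (torch.exp(x) - torch.exp(-x)) / 2  # the old lowering
assert torch.allclose(decomposed, torch.sinh(x), atol=1e-6)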

// split with fixed-size parts
// Arguments:
@@ -2777,4 +2762,61 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
/*generator=*/cstNone);
return success();
});
+ patterns.onOp(
+     "SoftmaxCrossEntropyLoss", 12,
+     [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+       Torch::ValueTensorType resultType;
+       int64_t ignoreIndex;
+       std::string reduction;
+       SmallVector<int64_t> shape;
+       Value scores, labels, weight;
+       if (binder.tensorOperandAtIndex(scores, 0) ||
+           binder.tensorOperandAtIndex(labels, 1) ||
+           binder.s64IntegerAttr(ignoreIndex, "ignore_index", -100) ||
+           binder.customOpNameStringAttr(reduction, "reduction", "mean") ||
+           binder.tensorResultTypeAtIndex(resultType, 0)) {
+         return failure();
+       }
+
+       if (binder.tensorOperandAtIndex(weight, 2))
+         weight = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+
+       Value cstIgnoreIndex = rewriter.create<Torch::ConstantIntOp>(
+           binder.getLoc(), rewriter.getI64IntegerAttr(ignoreIndex));
+
+       int64_t reductionInt = reduction == "none" ? 0
+                              : reduction == "mean" ? 1
+                                                    : 2;
+       Value cstReductionInt = rewriter.create<Torch::ConstantIntOp>(
+           binder.getLoc(), rewriter.getI64IntegerAttr(reductionInt));
+
+       // The default PyTorch value for label smoothing is "0.0".
+       // Refer:
+       // https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
+       Value cstLabelSmoothing = rewriter.create<Torch::ConstantFloatOp>(
+           binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+           rewriter.getFloatAttr(rewriter.getF64Type(), 0.0));
+
+       Value loss = rewriter.create<Torch::AtenCrossEntropyLossOp>(
+           binder.getLoc(), resultType, scores, labels, weight,
+           cstReductionInt, cstIgnoreIndex, cstLabelSmoothing);
+
+       if (binder.op->getNumResults() == 1) {
+         rewriter.replaceOp(binder.op, loss);
+         return success();
+       }
+
+       Torch::ValueTensorType resultTypeLogProb;
+       if (binder.tensorResultTypeAtIndex(resultTypeLogProb, 1))
+         return failure();
+
+       Value dim = rewriter.create<Torch::ConstantIntOp>(
+           binder.getLoc(), rewriter.getI64IntegerAttr(1));
+       Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+       Value logProb = rewriter.create<Torch::AtenLogSoftmaxIntOp>(
+           binder.getLoc(), resultTypeLogProb, scores, dim, /*dtype=*/cstNone);
+
+       rewriter.replaceOp(binder.op, {loss, logProb});
+       return success();
+     });
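This pattern maps ONNX SoftmaxCrossEntropyLoss onto aten.cross_entropy_loss for the loss output and aten.log_softmax over dim 1 for the optional log_prob output. A sketch of that pairing in eager PyTorch (shapes and values are arbitrary):

import torch
import torch.nn.functional as F

scores = torch.randn(8, 5)           # (N, C) logits
labels = torch.randint(0, 5, (8,))   # class indices

# Mirrors the defaults the pattern emits: reduction="mean",
# ignore_index=-100, label_smoothing=0.0, no class weights.
loss = F.cross_entropy(scores, labels, reduction="mean",
                       ignore_index=-100, label_smoothing=0.0)
log_prob = F.log_softmax(scores, dim=1)

# cross_entropy is nll_loss composed with log_softmax, which is what the
# ONNX spec requires of the two outputs.
assert torch.allclose(loss, F.nll_loss(log_prob, labels, reduction="mean"))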
}
3 changes: 0 additions & 3 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -2813,9 +2813,6 @@
"IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin3DIntAccumulateModule_basic",
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
- # Failure - onnx_lowering: onnx.SoftmaxCrossEntropyLoss
- "CrossEntropyLossModule_basic",
- "CrossEntropyLossNoReductionModule_basic",
# RuntimeError: unsupported input type: Device
"PrimsIotaModule_basic",
# Failure - unknown
@@ -3,24 +3,13 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

- from typing import Union, Optional, Sequence
-
import numpy as np
import torch
import torch.utils._pytree as pytree
from torch.export.graph_signature import OutputSpec, OutputKind
from torch.export import ExportedProgram

from torch_mlir import fx
- from torch_mlir.compiler_utils import (
-     run_pipeline_with_repro_report,
-     lower_mlir_module,
-     OutputType,
- )
- from torch_mlir.torchscript import (
-     BACKEND_LEGAL_OPS,
-     _canon_extra_library,
- )
from torch_mlir_e2e_test.configs.utils import (
recursively_convert_to_numpy,
recursively_convert_from_numpy,
@@ -39,53 +28,6 @@ def refine_result_type(_result):
raise ValueError(f"Unhandled return type {type(_result)}")


- def jit(
-     prog: ExportedProgram,
-     func_name: str,
-     output_type: Union[str, "OutputType"] = OutputType.TORCH,
-     backend_legal_ops: Optional[Sequence[str]] = None,
-     extra_library=None,
-     verbose: bool = False,
- ):
-     if extra_library is None:
-         extra_library = []
-     mlir_module = None
-
-     extra_library_file_name = _canon_extra_library(extra_library)
-     output_type = OutputType.get(output_type)
-     if backend_legal_ops is not None:
-         if output_type != OutputType.TORCH:
-             raise Exception(
-                 "`backend_legal_ops` is only valid with the " "`torch` output type"
-             )
-         backend_legal_ops = list(sorted(set(backend_legal_ops)))
-     else:
-         backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
-
-     option_string = (
-         "{backend-legal-ops="
-         + ",".join(backend_legal_ops)
-         + " extra-library="
-         + extra_library_file_name
-         + "}"
-     )
-
-     mlir_module = fx.export_and_import(prog, func_name=func_name)
-     assert mlir_module is not None
-     run_pipeline_with_repro_report(
-         mlir_module,
-         f"builtin.module(torch-simplification-pipeline)",
-         "Simplification pipeline for torch dialect",
-     )
-     run_pipeline_with_repro_report(
-         mlir_module,
-         f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})",
-         "Lowering TorchFX IR -> Torch Backend IR",
-     )
-
-     return lower_mlir_module(verbose, output_type, mlir_module)
-
-
class FxImporterTestConfig(TestConfig):
"""TestConfig that runs the torch.nn.Module with Fx Importer"""

@@ -100,11 +42,11 @@ def compile(self, program: torch.nn.Module) -> torch.nn.Module:
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
result: Trace = []
for item in trace:
- prog = torch.export.export(artifact, tuple(item.inputs))
- module = jit(
+ prog: ExportedProgram = torch.export.export(artifact, tuple(item.inputs))
+ module = fx.export_and_import(
prog,
- func_name=artifact.__class__.__name__,
output_type=self._output_type,
+ func_name=artifact.__class__.__name__,
)
module = self._backend.compile(module)
backend_module = self._backend.load(module)
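With the local jit() helper removed, the test config calls the library entry point directly; fx.export_and_import accepts an ExportedProgram and handles lowering to the requested output type itself. A minimal standalone usage sketch (the module, shapes, and names here are ours, assuming the torch_mlir Python API as used in the diff above):

import torch
from torch_mlir import fx
from torch_mlir.compiler_utils import OutputType

class Model(torch.nn.Module):
    def forward(self, x):
        return torch.sinh(x)

prog = torch.export.export(Model(), (torch.randn(4),))
module = fx.export_and_import(prog, output_type=OutputType.TORCH,
                              func_name="Model")
print(module)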
