[AutoBump] Merge with c3bd8509 (24) #255

Merged
merged 11 commits on Aug 23, 2024
2 changes: 1 addition & 1 deletion externals/stablehlo
Submodule stablehlo updated 171 files
416 changes: 217 additions & 199 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp

Large diffs are not rendered by default.

166 changes: 104 additions & 62 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -851,9 +851,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.tensorResultType(resultType))
return failure();

Torch::ValueTensorType inputType =
operand.getType().cast<Torch::ValueTensorType>();

Value vAlpha = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), alpha));
@@ -862,31 +859,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), gamma));

Value cstOne = rewriter.create<Torch::ConstantFloatOp>(
Value vInputScale = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), 1.0));

Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value zeroTensor = rewriter.create<Torch::AtenZerosLikeOp>(
binder.getLoc(), resultType, operand, cstNone, cstNone, cstNone,
cstNone, cstNone);
Value exp = rewriter.create<Torch::AtenExpOp>(binder.getLoc(),
resultType, operand);
Value expMulAlpha = rewriter.create<Torch::AtenMulScalarOp>(
binder.getLoc(), resultType, exp, vAlpha);
Value expMulAlphaSubAlpha = rewriter.create<Torch::AtenSubScalarOp>(
binder.getLoc(), resultType, expMulAlpha, vAlpha, cstOne);
Value neg = rewriter.create<Torch::AtenMulScalarOp>(
binder.getLoc(), resultType, expMulAlphaSubAlpha, vScale);
Value pos = rewriter.create<Torch::AtenMulScalarOp>(
binder.getLoc(), resultType, operand, vScale);
Type compareType = inputType.getWithSizesAndDtype(
inputType.getOptionalSizes(), rewriter.getI1Type());
Value xLessThanZero = rewriter.create<Torch::AtenLtTensorOp>(
binder.getLoc(), compareType, operand, zeroTensor);

rewriter.replaceOpWithNewOp<Torch::AtenWhereSelfOp>(
binder.op, resultType, xLessThanZero, neg, pos);
rewriter.replaceOpWithNewOp<Torch::AtenEluOp>(
binder.op, resultType, operand, vAlpha, vScale, vInputScale);
return success();
});
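The rewritten Selu lowering leans on the fact that ONNX Selu(x) is gamma * x for x > 0 and gamma * (alpha * exp(x) - alpha) otherwise, which is exactly what aten.elu computes with alpha, scale = gamma, and input_scale = 1. A minimal PyTorch sketch of that equivalence (illustrative values, not a test from this PR):

```python
import torch

# ONNX Selu reference: gamma * x for x > 0, gamma * (alpha * exp(x) - alpha) otherwise.
alpha, gamma = 1.6732632423543772, 1.0507009873554805
x = torch.randn(16)

selu_ref = torch.where(x > 0, gamma * x, gamma * (alpha * torch.exp(x) - alpha))
# aten.elu(self, alpha, scale, input_scale) -- the op the new lowering emits.
elu = torch.ops.aten.elu(x, alpha, gamma, 1.0)

assert torch.allclose(selu_ref, elu, atol=1e-6)
```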
patterns.onOp("ReduceL1", 1,
@@ -962,6 +940,32 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
/*memory_format=*/noneVal);
return success();
});
patterns.onOp("ReduceLogSum", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value data;
int64_t keepDims, noop_with_empty_axes;
if (binder.tensorOperandAtIndex(data, 0) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
binder.s64IntegerAttr(noop_with_empty_axes,
"noop_with_empty_axes", 0))
return failure();

auto reducedSumBool =
reducedSumImpl(binder, rewriter, data, resultType,
/*storeValue=*/data, keepDims,
noop_with_empty_axes, true);

if (failed(reducedSumBool))
return rewriter.notifyMatchFailure(
binder.op,
"Failed to perform sum operation on square of operand");

rewriter.replaceOpWithNewOp<Torch::AtenLogOp>(
binder.op, resultType, data);
return success();
});
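Semantically, this handler composes the shared sum-reduction helper with a final log, i.e. ReduceLogSum(x, axes) = log(sum(x, axes)). A small PyTorch sketch of the intended result (the axes and keepdims values are assumed examples):

```python
import torch

x = torch.rand(2, 3, 4) + 0.1   # keep values positive so the log is well defined
axes, keepdims = [1, 2], True   # example attribute values

reduce_log_sum = torch.log(torch.sum(x, dim=axes, keepdim=keepdims))
```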
patterns.onOp("ReduceSum", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
@@ -978,7 +982,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
/*storeValue=*/data, keepDims,
noop_with_empty_axes, false);
});
patterns.onOp("ReduceLogSum", 1,
patterns.onOp("ReduceSumSquare", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value data;
@@ -990,19 +994,13 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
"noop_with_empty_axes", 0))
return failure();

auto reducedSumBool =
reducedSumImpl(binder, rewriter, data, resultType,
/*storeValue=*/data, keepDims,
noop_with_empty_axes, true);

if (failed(reducedSumBool))
return rewriter.notifyMatchFailure(
binder.op,
"Failed to perform sum operation on square of operand");
Value dataSquare = rewriter.create<Torch::AtenMulTensorOp>(
binder.getLoc(), data.getType(), data, data);

rewriter.replaceOpWithNewOp<Torch::AtenLogOp>(
binder.op, resultType, data);
return success();
return reducedSumImpl(binder, rewriter, dataSquare,
resultType,
/*storeValue=*/data, keepDims,
noop_with_empty_axes, false);
});
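ReduceSumSquare is handled by squaring the operand with an elementwise self-multiply and then reusing the same sum-reduction helper, so the result is sum(x * x, axes). A sketch with assumed example attributes:

```python
import torch

x = torch.randn(2, 3, 4)
axes, keepdims = [0, 2], False   # example attribute values

reduce_sum_square = torch.sum(x * x, dim=axes, keepdim=keepdims)
```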
patterns.onOp(
"ReduceMean", 1,
@@ -1441,31 +1439,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
return success();
});

patterns.onOp(
"Sinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType))
return failure();
patterns.onOp("Sinh", 9,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType))
return failure();

// 1/2 * (exp(x) – exp(-x))
Value x = rewriter.create<Torch::AtenExpOp>(binder.getLoc(), resultType,
operand);
Value neg = rewriter.create<Torch::AtenNegOp>(binder.getLoc(),
resultType, operand);
Value y =
rewriter.create<Torch::AtenExpOp>(binder.getLoc(), resultType, neg);
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value z = rewriter.create<Torch::AtenSubTensorOp>(
binder.getLoc(), resultType, x, y, cstOne);
Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(2));
rewriter.replaceOpWithNewOp<Torch::AtenDivScalarOp>(
binder.op, resultType, z, cstTwo);
return success();
});
rewriter.replaceOpWithNewOp<Torch::AtenSinhOp>(
binder.op, resultType, operand);
return success();
});
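The removed code expanded sinh by hand as (exp(x) - exp(-x)) / 2; the new lowering emits aten.sinh directly. A quick PyTorch check that the two agree (illustrative input):

```python
import torch

x = torch.randn(8)
manual = (torch.exp(x) - torch.exp(-x)) / 2   # the old decomposition
assert torch.allclose(torch.sinh(x), manual, atol=1e-6)
```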

// split with fixed-size parts
// Arguments:
@@ -2777,4 +2762,61 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
/*generator=*/cstNone);
return success();
});
patterns.onOp(
"SoftmaxCrossEntropyLoss", 12,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
int64_t ignoreIndex;
std::string reduction;
SmallVector<int64_t> shape;
Value scores, labels, weight;
if (binder.tensorOperandAtIndex(scores, 0) ||
binder.tensorOperandAtIndex(labels, 1) ||
binder.s64IntegerAttr(ignoreIndex, "ignore_index ", -100) ||
binder.customOpNameStringAttr(reduction, "reduction", "mean") ||
binder.tensorResultTypeAtIndex(resultType, 0)) {
return failure();
}

if (binder.tensorOperandAtIndex(weight, 2))
weight = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

Value cstIgnoreIndex = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(ignoreIndex));

int64_t reductionInt = reduction == "none" ? 0
: reduction == "mean" ? 1
: 2;
Value cstReductionInt = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(reductionInt));

// The default PyTorch value for label smoothing is "0.0".
// Refer:
// https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
Value cstLabelSmoothing = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), 0.0));

Value loss = rewriter.create<Torch::AtenCrossEntropyLossOp>(
binder.getLoc(), resultType, scores, labels, weight,
cstReductionInt, cstIgnoreIndex, cstLabelSmoothing);

if (binder.op->getNumResults() == 1) {
rewriter.replaceOp(binder.op, loss);
return success();
}

Torch::ValueTensorType resultTypeLogProb;
if (binder.tensorResultTypeAtIndex(resultTypeLogProb, 1))
return failure();

Value dim = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value logProb = rewriter.create<Torch::AtenLogSoftmaxIntOp>(
binder.getLoc(), resultTypeLogProb, scores, dim, /*dtype=*/cstNone);

rewriter.replaceOp(binder.op, {loss, logProb});
return success();
});
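The new handler maps ONNX SoftmaxCrossEntropyLoss onto aten.cross_entropy_loss (reduction none/mean/sum encoded as 0/1/2, label smoothing fixed at 0.0) and, when the op has a second result, recovers the log-probabilities with log_softmax over dim 1. A PyTorch sketch of the same pair of outputs (shapes and reduction are assumed examples):

```python
import torch
import torch.nn.functional as F

scores = torch.randn(4, 5)                       # (batch, classes)
labels = torch.randint(0, 5, (4,))

loss = F.cross_entropy(scores, labels, weight=None,
                       ignore_index=-100, reduction="mean",
                       label_smoothing=0.0)       # first ONNX result
log_prob = F.log_softmax(scores, dim=1)           # optional second result
```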
}
3 changes: 0 additions & 3 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -2813,9 +2813,6 @@
"IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin3DIntAccumulateModule_basic",
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
# Failure - onnx_lowering: onnx.SoftmaxCrossEntropyLoss
"CrossEntropyLossModule_basic",
"CrossEntropyLossNoReductionModule_basic",
# RuntimeError: unsupported input type: Device
"PrimsIotaModule_basic",
# Failure - unknown
@@ -3,24 +3,13 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from typing import Union, Optional, Sequence

import numpy as np
import torch
import torch.utils._pytree as pytree
from torch.export.graph_signature import OutputSpec, OutputKind
from torch.export import ExportedProgram

from torch_mlir import fx
from torch_mlir.compiler_utils import (
run_pipeline_with_repro_report,
lower_mlir_module,
OutputType,
)
from torch_mlir.torchscript import (
BACKEND_LEGAL_OPS,
_canon_extra_library,
)
from torch_mlir_e2e_test.configs.utils import (
recursively_convert_to_numpy,
recursively_convert_from_numpy,
@@ -39,53 +28,6 @@ def refine_result_type(_result):
raise ValueError(f"Unhandled return type {type(_result)}")


def jit(
prog: ExportedProgram,
func_name: str,
output_type: Union[str, "OutputType"] = OutputType.TORCH,
backend_legal_ops: Optional[Sequence[str]] = None,
extra_library=None,
verbose: bool = False,
):
if extra_library is None:
extra_library = []
mlir_module = None

extra_library_file_name = _canon_extra_library(extra_library)
output_type = OutputType.get(output_type)
if backend_legal_ops is not None:
if output_type != OutputType.TORCH:
raise Exception(
"`backend_legal_ops` is only valid with the " "`torch` output type"
)
backend_legal_ops = list(sorted(set(backend_legal_ops)))
else:
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])

option_string = (
"{backend-legal-ops="
+ ",".join(backend_legal_ops)
+ " extra-library="
+ extra_library_file_name
+ "}"
)

mlir_module = fx.export_and_import(prog, func_name=func_name)
assert mlir_module is not None
run_pipeline_with_repro_report(
mlir_module,
f"builtin.module(torch-simplification-pipeline)",
"Simplification pipeline for torch dialect",
)
run_pipeline_with_repro_report(
mlir_module,
f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})",
"Lowering TorchFX IR -> Torch Backend IR",
)

return lower_mlir_module(verbose, output_type, mlir_module)


class FxImporterTestConfig(TestConfig):
"""TestConfig that runs the torch.nn.Module with Fx Importer"""

@@ -100,11 +42,11 @@ def compile(self, program: torch.nn.Module) -> torch.nn.Module:
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
result: Trace = []
for item in trace:
prog = torch.export.export(artifact, tuple(item.inputs))
module = jit(
prog: ExportedProgram = torch.export.export(artifact, tuple(item.inputs))
module = fx.export_and_import(
prog,
func_name=artifact.__class__.__name__,
output_type=self._output_type,
func_name=artifact.__class__.__name__,
)
module = self._backend.compile(module)
backend_module = self._backend.load(module)