Skip to content

Commit

Permalink
Add where, gt, bucketize and reshape ops to Torch dialect
Browse files Browse the repository at this point in the history
This patch adds the where, gt, bucketize and reshape
ops to the Torch dialect. These ops are present in the histogram
calibration module.

TEST: Successfully lowers ops to Torch dialect in histogram module.
  • Loading branch information
harsh-nod authored and silvasean committed Dec 10, 2021
1 parent cfc8de3 commit 03b6edc
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 17 deletions.
112 changes: 96 additions & 16 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,36 @@ def Torch_AtenEq_TensorOp : Torch_Op<"aten.eq_.Tensor", [
let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}

// Elementwise tensor-vs-tensor greater-than comparison.
// Auto-generated from the JIT operator schema
//   `aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)`
// by the emitter script (torch_ods_gen.py); regenerate rather than editing
// this record by hand. Added for the histogram-calibration module (see
// commit message).
def Torch_AtenGtTensorOp : Torch_Op<"aten.gt.Tensor", [
    AllowsTypeRefinement,
    HasValueSemantics  // out-of-place variant: no operand is mutated
  ]> {
  let summary = "Generated op for `aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  // Custom assembly: `aten.gt.Tensor %self, %other : t1, t2 -> t3`.
  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}

// In-place (trailing-underscore) variant of `aten.gt.Tensor`, generated from
//   `aten::gt_.Tensor : (Tensor, Tensor) -> (Tensor)`
// by torch_ods_gen.py's emit_with_mutating_variants; do not edit by hand.
// Unlike the out-of-place form it carries IsTrailingUnderscoreInplaceVariant
// instead of HasValueSemantics, since `$self` is mutated in place.
def Torch_AtenGt_TensorOp : Torch_Op<"aten.gt_.Tensor", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::gt_.Tensor : (Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  // Custom assembly: `aten.gt_.Tensor %self, %other : t1, t2 -> t3`.
  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}

def Torch_AtenNeTensorOp : Torch_Op<"aten.ne.Tensor", [
AllowsTypeRefinement,
HasValueSemantics
Expand Down Expand Up @@ -1071,22 +1101,6 @@ def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [
let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}

// Ternary select, generated from
//   `aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)`
// by torch_ods_gen.py; regenerate rather than editing this record by hand.
// Presumably picks elements from $self/$other based on $condition, per the
// usual aten semantics — the schema alone does not encode this; confirm
// against PyTorch's op registry.
def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [
    AllowsTypeRefinement,
    HasValueSemantics  // pure: no operand is mutated
  ]> {
  let summary = "Generated op for `aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$condition,
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  // Custom assembly: `aten.where.self %cond, %self, %other : t1, t2, t3 -> t4`.
  let assemblyFormat = "$condition `,` $self `,` $other attr-dict `:` type($condition) `,` type($self) `,` type($other) `->` type($result)";
}

def Torch_AtenMinimumOp : Torch_Op<"aten.minimum", [
AllowsTypeRefinement,
HasValueSemantics
Expand Down Expand Up @@ -1942,6 +1956,23 @@ def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [
let assemblyFormat = "$self `,` $dim `,` $keepdim attr-dict `:` type($self) `,` type($dim) `,` type($keepdim) `->` type($result)";
}

// Bucket-index computation, generated from
//   `aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)`
// by torch_ods_gen.py; regenerate rather than editing this record by hand.
// Added because the op appears in the histogram-calibration module (see
// commit message). NOTE(review): the schema shows a $boundaries tensor plus
// two flags ($out_int32, $right); their exact semantics are defined by
// PyTorch's `torch.bucketize` — confirm there, the ODS record only encodes
// the signature.
def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [
    AllowsTypeRefinement,
    HasValueSemantics  // pure: no operand is mutated
  ]> {
  let summary = "Generated op for `aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$boundaries,
    Torch_BoolType:$out_int32,
    Torch_BoolType:$right
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  // Custom assembly lists all four operands, then their types.
  let assemblyFormat = "$self `,` $boundaries `,` $out_int32 `,` $right attr-dict `:` type($self) `,` type($boundaries) `,` type($out_int32) `,` type($right) `->` type($result)";
}

def Torch_AtenContiguousOp : Torch_Op<"aten.contiguous", [
AllowsTypeRefinement
]> {
Expand Down Expand Up @@ -2002,6 +2033,25 @@ def Torch_AtenEmbeddingOp : Torch_Op<"aten.embedding", [
let assemblyFormat = "$weight `,` $indices `,` $padding_idx `,` $scale_grad_by_freq `,` $sparse attr-dict `:` type($weight) `,` type($indices) `,` type($padding_idx) `,` type($scale_grad_by_freq) `,` type($sparse) `->` type($result)";
}

// Uninitialized-tensor factory shaped like an existing tensor, generated from
//   `aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`
// by torch_ods_gen.py; regenerate rather than editing this record by hand.
// The optional operands mirror the TensorOptions-style trailing arguments of
// the aten schema (dtype, layout, device, pin_memory, memory_format); `None`
// at each position presumably means "inherit from $self" — confirm against
// PyTorch's documentation.
def Torch_AtenEmptyLikeOp : Torch_Op<"aten.empty_like", [
    AllowsTypeRefinement,
    HasValueSemantics  // pure: no operand is mutated
  ]> {
  let summary = "Generated op for `aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    TorchOptionalIntType:$dtype,
    TorchOptionalIntType:$layout,
    TorchOptionalDeviceType:$device,
    TorchOptionalBoolType:$pin_memory,
    TorchOptionalIntType:$memory_format
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  // Custom assembly lists all six operands, then their types.
  let assemblyFormat = "$self `,` $dtype `,` $layout `,` $device `,` $pin_memory `,` $memory_format attr-dict `:` type($self) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `,` type($memory_format) `->` type($result)";
}

def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
AllowsTypeRefinement,
HasValueSemantics
Expand Down Expand Up @@ -2139,6 +2189,20 @@ def Torch_AtenRepeatOp : Torch_Op<"aten.repeat", [
let assemblyFormat = "$self `,` $repeats attr-dict `:` type($self) `,` type($repeats) `->` type($result)";
}

// Tensor reshape, generated from
//   `aten::reshape : (Tensor, int[]) -> (Tensor)`
// by torch_ods_gen.py; regenerate rather than editing this record by hand.
// Carries AllowsTypeRefinement but NOT HasValueSemantics — consistent with
// aten::reshape possibly returning a view aliasing $self rather than a fresh
// tensor (compare the value-semantic ops above).
def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::reshape : (Tensor, int[]) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    TorchIntListType:$shape  // target shape as a !torch.list<int>
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  // Custom assembly: `aten.reshape %self, %shape : t1, t2 -> t3`.
  let assemblyFormat = "$self `,` $shape attr-dict `:` type($self) `,` type($shape) `->` type($result)";
}

def Torch_AtenResize_Op : Torch_Op<"aten.resize_", [
AllowsTypeRefinement
]> {
Expand Down Expand Up @@ -2312,6 +2376,22 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [
let assemblyFormat = "$self `,` $size attr-dict `:` type($self) `,` type($size) `->` type($result)";
}

// Ternary select, generated from
//   `aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)`
// by torch_ods_gen.py; regenerate rather than editing this record by hand.
// Presumably picks elements from $self/$other based on $condition, per the
// usual aten semantics — the schema alone does not encode this; confirm
// against PyTorch's op registry.
def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [
    AllowsTypeRefinement,
    HasValueSemantics  // pure: no operand is mutated
  ]> {
  let summary = "Generated op for `aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$condition,
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  // Custom assembly: `aten.where.self %cond, %self, %other : t1, t2, t3 -> t4`.
  let assemblyFormat = "$condition `,` $self `,` $other attr-dict `:` type($condition) `,` type($self) `,` type($other) `->` type($result)";
}

def Torch_AtenSliceTensorOp : Torch_Op<"aten.slice.Tensor", [
AllowsTypeRefinement
]> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
"aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::ne.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
"aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
Expand All @@ -479,7 +480,6 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
emit("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
emit("aten::gelu : (Tensor) -> (Tensor)")
Expand Down Expand Up @@ -550,10 +550,12 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
emit("aten::contiguous : (Tensor, int) -> (Tensor)")
emit("aten::copy_ : (Tensor, Tensor, bool) -> (Tensor)")
emit("aten::detach : (Tensor) -> (Tensor)")
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)")
Expand All @@ -563,6 +565,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
emit("aten::numel : (Tensor) -> (int)")
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
emit("aten::select.int : (Tensor, int, int) -> (Tensor)")
emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)
Expand All @@ -574,6 +577,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
emit("aten::view : (Tensor, int[]) -> (Tensor)")
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)")
emit("aten::len.Tensor : (Tensor) -> (int)")
emit("aten::cpu : (Tensor) -> (Tensor)")
Expand Down

0 comments on commit 03b6edc

Please sign in to comment.