Skip to content

Commit

Permalink
dialects: (llvm) Add a bunch of float methods
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonLydike committed Jan 23, 2025
1 parent 9cdc462 commit 60d5151
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 39 deletions.
22 changes: 21 additions & 1 deletion tests/filecheck/dialects/llvm/arithmetic.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: XDSL_ROUNDTRIP

%arg0, %arg1 = "test.op"() : () -> (i32, i32)
%arg0, %arg1, %f1 = "test.op"() : () -> (i32, i32, f32)

%add_both = llvm.add %arg0, %arg1 {"overflowFlags" = #llvm.overflow<nsw, nuw>} : i32
// CHECK: %add_both = llvm.add %arg0, %arg1 {overflowFlags = #llvm.overflow<nsw,nuw>} : i32
Expand Down Expand Up @@ -121,3 +121,23 @@

%icmp_uge = llvm.icmp "uge" %arg0, %arg1 : i32
// CHECK: %icmp_uge = llvm.icmp "uge" %arg0, %arg1 : i32

// float arith:

%fmul = llvm.fmul %f1, %f1 : f32
// CHECK: %fmul = llvm.fmul %f1, %f1 : f32

%fmul_fast = llvm.fmul %f1, %f1 {fastmathFlags = #llvm.fastmath<fast>} : f32
// CHECK: %fmul_fast = llvm.fmul %f1, %f1 {fastmathFlags = #llvm.fastmath<fast>} : f32

%fdiv = llvm.fdiv %f1, %f1 : f32
// CHECK: %fdiv = llvm.fdiv %f1, %f1 : f32

%fadd = llvm.fadd %f1, %f1 : f32
// CHECK: %fadd = llvm.fadd %f1, %f1 : f32

%fsub = llvm.fsub %f1, %f1 : f32
// CHECK: %fsub = llvm.fsub %f1, %f1 : f32

%frem = llvm.frem %f1, %f1 : f32
// CHECK: %frem = llvm.frem %f1, %f1 : f32
8 changes: 8 additions & 0 deletions tests/filecheck/dialects/llvm/example.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,12 @@ builtin.module {

// CHECK: %val = "test.op"() : () -> i32
// CHECK-NEXT: %fval = llvm.bitcast %val : i32 to f32

%fval2 = llvm.sitofp %val : i32 to f32

// CHECK-NEXT: %fval2 = llvm.sitofp %val : i32 to f32

%fval3 = llvm.fpext %fval : f32 to f64

// CHECK-NEXT: %fval3 = llvm.fpext %fval : f32 to f64
}
157 changes: 119 additions & 38 deletions xdsl/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
ContainerType,
DenseArrayBase,
DenseI64ArrayConstr,
Float16Type,
Float32Type,
Float64Type,
IndexType,
IntAttr,
IntegerAttr,
Expand All @@ -25,7 +28,7 @@
i32,
i64,
)
from xdsl.dialects.utils import FastMathAttrBase
from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag
from xdsl.ir import (
Attribute,
BitEnumAttribute,
Expand All @@ -46,6 +49,7 @@
irdl_attr_definition,
irdl_op_definition,
operand_def,
opt_attr_def,
opt_operand_def,
opt_prop_def,
opt_result_def,
Expand All @@ -54,6 +58,7 @@
result_def,
traits_def,
var_operand_def,
AnyOf,
)
from xdsl.parser import AttrParser, Parser
from xdsl.printer import Printer
Expand Down Expand Up @@ -1710,10 +1715,7 @@ class ZeroOp(IRDLOperation):
res = result_def(LLVMTypeConstr)


@irdl_op_definition
class BitcastOp(IRDLOperation):
name = "llvm.bitcast"

class GenericCastOp(IRDLOperation, ABC):
arg = operand_def(Attribute)
"""
LLVM-compatible non-aggregate type
Expand All @@ -1735,56 +1737,135 @@ def __init__(self, val: Operation | SSAValue, res_type: Attribute):
)


floatingPointLike = AnyOf([Float16Type, Float32Type, Float64Type])


class AbstractFloatArithOp(IRDLOperation, ABC):
T: ClassVar = VarConstraint("T", floatingPointLike)

lhs = operand_def(T)
rhs = operand_def(T)
res = result_def(T)

fastmathFlags = opt_attr_def(FastMathAttr)

traits = traits_def(NoMemoryEffect())

assembly_format = "$lhs `,` $rhs attr-dict `:` type($lhs)"

def __init__(
self,
lhs: SSAValue | Operation,
rhs: SSAValue | Operation,
fast_math: FastMathAttr | FastMathFlag | None,
):
if isinstance(fast_math, FastMathFlag | str):
fast_math = FastMathAttr(fast_math)

super().__init__(
operands=[lhs, rhs],
result_types=[SSAValue.get(lhs).type],
attributes={"fastmathFlags": fast_math},
)


@irdl_op_definition
class FAddOp(AbstractFloatArithOp):
name = "llvm.fadd"


@irdl_op_definition
class FMulOp(AbstractFloatArithOp):
name = "llvm.fmul"


@irdl_op_definition
class FDivOp(AbstractFloatArithOp):
name = "llvm.fdiv"


@irdl_op_definition
class FSubOp(AbstractFloatArithOp):
name = "llvm.fsub"


@irdl_op_definition
class FRemOp(AbstractFloatArithOp):
name = "llvm.frem"


@irdl_op_definition
class BitcastOp(GenericCastOp):
name = "llvm.bitcast"


@irdl_op_definition
class SIToFPOp(GenericCastOp):
name = "llvm.sitofp"


@irdl_op_definition
class FPExtOp(GenericCastOp):
name = "llvm.fpext"


LLVM = Dialect(
"llvm",
[
AShrOp,
AddOp,
AddressOfOp,
AllocaOp,
AndOp,
BitcastOp,
SubOp,
CallIntrinsicOp,
CallOp,
ConstantOp,
ExtractValueOp,
FAddOp,
FDivOp,
FMulOp,
FPExtOp,
FRemOp,
FSubOp,
FuncOp,
GEPOp,
GlobalOp,
ICmpOp,
InlineAsmOp,
InsertValueOp,
IntToPtrOp,
LShrOp,
LoadOp,
MulOp,
UDivOp,
NullOp,
OrOp,
ReturnOp,
SDivOp,
URemOp,
SExtOp,
SIToFPOp,
SRemOp,
AndOp,
OrOp,
XOrOp,
ShlOp,
LShrOp,
AShrOp,
StoreOp,
SubOp,
TruncOp,
ZExtOp,
SExtOp,
ICmpOp,
ExtractValueOp,
InsertValueOp,
InlineAsmOp,
UDivOp,
URemOp,
UndefOp,
AllocaOp,
GEPOp,
IntToPtrOp,
NullOp,
LoadOp,
StoreOp,
GlobalOp,
AddressOfOp,
FuncOp,
CallOp,
ReturnOp,
ConstantOp,
CallIntrinsicOp,
XOrOp,
ZExtOp,
ZeroOp,
],
[
LLVMStructType,
LLVMPointerType,
CallingConventionAttr,
FastMathAttr,
LLVMArrayType,
LLVMVoidType,
LLVMFunctionType,
LLVMPointerType,
LLVMStructType,
LLVMVoidType,
LinkageAttr,
CallingConventionAttr,
TailCallKindAttr,
FastMathAttr,
OverflowAttr,
TailCallKindAttr,
],
)

0 comments on commit 60d5151

Please sign in to comment.