From 60d5151933939eb52226e1c2b0c7247fa54eff86 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Thu, 23 Jan 2025 11:07:42 +0000 Subject: [PATCH] dialects: (llvm) Add a bunch of float methods --- tests/filecheck/dialects/llvm/arithmetic.mlir | 22 ++- tests/filecheck/dialects/llvm/example.mlir | 8 + xdsl/dialects/llvm.py | 157 +++++++++++++----- 3 files changed, 148 insertions(+), 39 deletions(-) diff --git a/tests/filecheck/dialects/llvm/arithmetic.mlir b/tests/filecheck/dialects/llvm/arithmetic.mlir index d305c61ab0..bc57548f78 100644 --- a/tests/filecheck/dialects/llvm/arithmetic.mlir +++ b/tests/filecheck/dialects/llvm/arithmetic.mlir @@ -1,6 +1,6 @@ // RUN: XDSL_ROUNDTRIP -%arg0, %arg1 = "test.op"() : () -> (i32, i32) +%arg0, %arg1, %f1 = "test.op"() : () -> (i32, i32, f32) %add_both = llvm.add %arg0, %arg1 {"overflowFlags" = #llvm.overflow} : i32 // CHECK: %add_both = llvm.add %arg0, %arg1 {overflowFlags = #llvm.overflow} : i32 @@ -121,3 +121,23 @@ %icmp_uge = llvm.icmp "uge" %arg0, %arg1 : i32 // CHECK: %icmp_uge = llvm.icmp "uge" %arg0, %arg1 : i32 + +// float arith: + +%fmul = llvm.fmul %f1, %f1 : f32 +// CHECK: %fmul = llvm.fmul %f1, %f1 : f32 + +%fmul_fast = llvm.fmul %f1, %f1 {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: %fmul_fast = llvm.fmul %f1, %f1 {fastmathFlags = #llvm.fastmath} : f32 + +%fdiv = llvm.fdiv %f1, %f1 : f32 +// CHECK: %fdiv = llvm.fdiv %f1, %f1 : f32 + +%fadd = llvm.fadd %f1, %f1 : f32 +// CHECK: %fadd = llvm.fadd %f1, %f1 : f32 + +%fsub = llvm.fsub %f1, %f1 : f32 +// CHECK: %fsub = llvm.fsub %f1, %f1 : f32 + +%frem = llvm.frem %f1, %f1 : f32 +// CHECK: %frem = llvm.frem %f1, %f1 : f32 diff --git a/tests/filecheck/dialects/llvm/example.mlir b/tests/filecheck/dialects/llvm/example.mlir index 0bec0a27c9..c8222ab472 100644 --- a/tests/filecheck/dialects/llvm/example.mlir +++ b/tests/filecheck/dialects/llvm/example.mlir @@ -90,4 +90,12 @@ builtin.module { // CHECK: %val = "test.op"() : () -> i32 // CHECK-NEXT: %fval = llvm.bitcast %val : i32 to f32 + + %fval2 = llvm.sitofp %val : i32 to f32 + +// CHECK-NEXT: %fval2 = llvm.sitofp %val : i32 to f32 + + %fval3 = llvm.fpext %fval : f32 to f64 + +// CHECK-NEXT: %fval3 = llvm.fpext %fval : f32 to f64 } diff --git a/xdsl/dialects/llvm.py b/xdsl/dialects/llvm.py index f1570f057c..861cb27208 100644 --- a/xdsl/dialects/llvm.py +++ b/xdsl/dialects/llvm.py @@ -13,6 +13,9 @@ ContainerType, DenseArrayBase, DenseI64ArrayConstr, + Float16Type, + Float32Type, + Float64Type, IndexType, IntAttr, IntegerAttr, @@ -25,7 +28,7 @@ i32, i64, ) -from xdsl.dialects.utils import FastMathAttrBase +from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag from xdsl.ir import ( Attribute, BitEnumAttribute, @@ -46,6 +49,7 @@ irdl_attr_definition, irdl_op_definition, operand_def, + opt_attr_def, opt_operand_def, opt_prop_def, opt_result_def, @@ -54,6 +58,7 @@ result_def, traits_def, var_operand_def, + AnyOf, ) from xdsl.parser import AttrParser, Parser from xdsl.printer import Printer @@ -1710,10 +1715,7 @@ class ZeroOp(IRDLOperation): res = result_def(LLVMTypeConstr) -@irdl_op_definition -class BitcastOp(IRDLOperation): - name = "llvm.bitcast" - +class GenericCastOp(IRDLOperation, ABC): arg = operand_def(Attribute) """ LLVM-compatible non-aggregate type @@ -1735,56 +1737,135 @@ def __init__(self, val: Operation | SSAValue, res_type: Attribute): ) +floatingPointLike = AnyOf([Float16Type, Float32Type, Float64Type]) + + +class AbstractFloatArithOp(IRDLOperation, ABC): + T: ClassVar = VarConstraint("T", floatingPointLike) + + lhs = operand_def(T) + rhs = operand_def(T) + res = result_def(T) + + fastmathFlags = opt_attr_def(FastMathAttr) + + traits = traits_def(NoMemoryEffect()) + + assembly_format = "$lhs `,` $rhs attr-dict `:` type($lhs)" + + def __init__( + self, + lhs: SSAValue | Operation, + rhs: SSAValue | Operation, + fast_math: FastMathAttr | FastMathFlag | None, + ): + if isinstance(fast_math, FastMathFlag | str): + fast_math = FastMathAttr(fast_math) + + super().__init__( + operands=[lhs, rhs], + result_types=[SSAValue.get(lhs).type], + attributes={"fastmathFlags": fast_math}, + ) + + +@irdl_op_definition +class FAddOp(AbstractFloatArithOp): + name = "llvm.fadd" + + +@irdl_op_definition +class FMulOp(AbstractFloatArithOp): + name = "llvm.fmul" + + +@irdl_op_definition +class FDivOp(AbstractFloatArithOp): + name = "llvm.fdiv" + + +@irdl_op_definition +class FSubOp(AbstractFloatArithOp): + name = "llvm.fsub" + + +@irdl_op_definition +class FRemOp(AbstractFloatArithOp): + name = "llvm.frem" + + +@irdl_op_definition +class BitcastOp(GenericCastOp): + name = "llvm.bitcast" + + +@irdl_op_definition +class SIToFPOp(GenericCastOp): + name = "llvm.sitofp" + + +@irdl_op_definition +class FPExtOp(GenericCastOp): + name = "llvm.fpext" + + LLVM = Dialect( "llvm", [ + AShrOp, AddOp, + AddressOfOp, + AllocaOp, + AndOp, BitcastOp, - SubOp, + CallIntrinsicOp, + CallOp, + ConstantOp, + ExtractValueOp, + FAddOp, + FDivOp, + FMulOp, + FPExtOp, + FRemOp, + FSubOp, + FuncOp, + GEPOp, + GlobalOp, + ICmpOp, + InlineAsmOp, + InsertValueOp, + IntToPtrOp, + LShrOp, + LoadOp, MulOp, - UDivOp, + NullOp, + OrOp, + ReturnOp, SDivOp, - URemOp, + SExtOp, + SIToFPOp, SRemOp, - AndOp, - OrOp, - XOrOp, ShlOp, - LShrOp, - AShrOp, + StoreOp, + SubOp, TruncOp, - ZExtOp, - SExtOp, - ICmpOp, - ExtractValueOp, - InsertValueOp, - InlineAsmOp, + UDivOp, + URemOp, UndefOp, - AllocaOp, - GEPOp, - IntToPtrOp, - NullOp, - LoadOp, - StoreOp, - GlobalOp, - AddressOfOp, - FuncOp, - CallOp, - ReturnOp, - ConstantOp, - CallIntrinsicOp, + XOrOp, + ZExtOp, ZeroOp, ], [ - LLVMStructType, - LLVMPointerType, + CallingConventionAttr, + FastMathAttr, LLVMArrayType, - LLVMVoidType, LLVMFunctionType, + LLVMPointerType, + LLVMStructType, + LLVMVoidType, LinkageAttr, - CallingConventionAttr, - TailCallKindAttr, - FastMathAttr, OverflowAttr, + TailCallKindAttr, ], )