Skip to content

Commit

Permalink
dialects: (builtin) mimic mlir floating point precision for printing …
Browse files Browse the repository at this point in the history
…and parsing (#3607)

I'm running into some precision issues when printing and parsing floats
after they have been packed/unpacked.
This PR tries to resolve the issues with printing and parsing, trying to
mimic MLIR behaviour as much as possible.
  • Loading branch information
jorendumoulin authored Dec 12, 2024
1 parent 414bcb0 commit 7e38685
Show file tree
Hide file tree
Showing 14 changed files with 117 additions and 56 deletions.
48 changes: 24 additions & 24 deletions tests/filecheck/backend/csl/print_csl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
%int2 = arith.constant 1 : i32

%float1 = arith.constant 0.0 : f32
%float2 = arith.constant 1.1 : f32
%float2 = arith.constant 1.5 : f32

%intEq = arith.cmpi eq, %int1, %int2 : i32
%intNe = arith.cmpi ne, %int1, %int2 : i32
Expand Down Expand Up @@ -204,16 +204,16 @@
}

"memref.global"() {"sym_name" = "uninit_array", "type" = memref<10xf32>, "sym_visibility" = "public", "initial_value"} : () -> ()
"memref.global"() {"sym_name" = "global_array", "type" = memref<10xf32>, "sym_visibility" = "public", "initial_value" = dense<4.2> : tensor<1xf32>} : () -> ()
"memref.global"() {"sym_name" = "global_array", "type" = memref<10xf32>, "sym_visibility" = "public", "initial_value" = dense<4.5> : tensor<1xf32>} : () -> ()
"memref.global"() {"sym_name" = "const_array", "type" = memref<10xi32>, "sym_visibility" = "public", "constant", "initial_value" = dense<10> : tensor<1xi32>} : () -> ()


%uninit_array = memref.get_global @uninit_array : memref<10xf32>
%global_array = memref.get_global @global_array : memref<10xf32>
%const_array = memref.get_global @const_array : memref<10xi32>

%literal_array = arith.constant dense<[1.200000e+00, 2.300000e+00, 3.400000e+00]> : memref<3xf32>
%literal_array_w_zeros = arith.constant dense<[1.200000e+00, 0, 3.400000e+00, 0]> : memref<4xf32>
%literal_array = arith.constant dense<[1.500000e+00, 2.500000e+00, 3.500000e+00]> : memref<3xf32>
%literal_array_w_zeros = arith.constant dense<[1.500000e+00, 0, 3.500000e+00, 0]> : memref<4xf32>

%uninit_ptr = "csl.addressof"(%uninit_array) : (memref<10xf32>) -> !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>
%global_ptr = "csl.addressof"(%global_array) : (memref<10xf32>) -> !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>
Expand Down Expand Up @@ -273,8 +273,8 @@ csl.func @initialize() {
// member access
%11 = "csl.member_access"(%thing) <{field = "some_field"}> : (!csl.imported_module) -> !csl.comptime_struct

%0 = arith.constant 3.14 : f32
%v0 = arith.constant 2.718 : f16
%0 = arith.constant 3.5 : f32
%v0 = arith.constant 2.5 : f16

%u32cst = arith.constant 44 : ui32

Expand Down Expand Up @@ -481,7 +481,7 @@ csl.func @builtins() {

"csl.module"() <{kind=#csl<module_kind layout>}> ({
%x_dim = "csl.param"() <{param_name = "param_1"}> : () -> i32
%init = arith.constant 3.14 : f16
%init = arith.constant 3.5 : f16
%p2 = "csl.param"(%init) <{param_name = "param_2"}> : (f16) -> f16

csl.layout {
Expand Down Expand Up @@ -598,18 +598,18 @@ csl.func @builtins() {
// CHECK-NEXT: const intUgt : bool = 0 > 1;
// CHECK-NEXT: const intUge : bool = 0 >= 1;
// CHECK-NEXT: const floatFalse : bool = false;
// CHECK-NEXT: const floatOeq : bool = 0.0 == 1.1;
// CHECK-NEXT: const floatOgt : bool = 0.0 > 1.1;
// CHECK-NEXT: const floatOlt : bool = 0.0 < 1.1;
// CHECK-NEXT: const floatOle : bool = 0.0 <= 1.1;
// CHECK-NEXT: const floatOne : bool = 0.0 != 1.1;
// CHECK-NEXT: const floatUeq : bool = 0.0 == 1.1;
// CHECK-NEXT: const floatUge : bool = 0.0 >= 1.1;
// CHECK-NEXT: const floatUlt : bool = 0.0 < 1.1;
// CHECK-NEXT: const floatUle : bool = 0.0 <= 1.1;
// CHECK-NEXT: const floatUne : bool = 0.0 != 1.1;
// CHECK-NEXT: const floatOeq : bool = 0.0 == 1.5;
// CHECK-NEXT: const floatOgt : bool = 0.0 > 1.5;
// CHECK-NEXT: const floatOlt : bool = 0.0 < 1.5;
// CHECK-NEXT: const floatOle : bool = 0.0 <= 1.5;
// CHECK-NEXT: const floatOne : bool = 0.0 != 1.5;
// CHECK-NEXT: const floatUeq : bool = 0.0 == 1.5;
// CHECK-NEXT: const floatUge : bool = 0.0 >= 1.5;
// CHECK-NEXT: const floatUlt : bool = 0.0 < 1.5;
// CHECK-NEXT: const floatUle : bool = 0.0 <= 1.5;
// CHECK-NEXT: const floatUne : bool = 0.0 != 1.5;
// CHECK-NEXT: const floatTrue : bool = true;
// CHECK-NEXT: return (((0 <= 1) or (0.0 > 1.1)) and (0.0 >= 1.1));
// CHECK-NEXT: return (((0 <= 1) or (0.0 > 1.5)) and (0.0 >= 1.5));
// CHECK-NEXT: }
// CHECK-NEXT: {{ *}}
// CHECK-NEXT: fn select() mem1d_dsd {
Expand Down Expand Up @@ -673,10 +673,10 @@ csl.func @builtins() {
// CHECK-NEXT: return;
// CHECK-NEXT: }
// CHECK-NEXT: var uninit_array : [10]f32;
// CHECK-NEXT: var global_array : [10]f32 = @constants([10]f32, 4.2);
// CHECK-NEXT: var global_array : [10]f32 = @constants([10]f32, 4.5);
// CHECK-NEXT: const const_array : [10]i32 = @constants([10]i32, 10);
// CHECK-NEXT: const literal_array : [3]f32 = [3]f32 { 1.2, 2.3, 3.4 };
// CHECK-NEXT: const literal_array_w_zeros : [4]f32 = [4]f32 { 1.2, 0.0, 3.4, 0.0 };
// CHECK-NEXT: const literal_array : [3]f32 = [3]f32 { 1.5, 2.5, 3.5 };
// CHECK-NEXT: const literal_array_w_zeros : [4]f32 = [4]f32 { 1.5, 0.0, 3.5, 0.0 };
// CHECK-NEXT: var uninit_ptr : [*]f32 = &uninit_array;
// CHECK-NEXT: var global_ptr : [*]f32 = &global_array;
// CHECK-NEXT: const const_ptr : [*]const i32 = &const_array;
Expand Down Expand Up @@ -712,8 +712,8 @@ csl.func @builtins() {
// CHECK-NEXT: thing.some_func(0, 24);
// CHECK-NEXT: const res : i32 = thing.some_func(0, 24);
// CHECK-NEXT: const v1 : comptime_struct = thing.some_field;
// CHECK-NEXT: const v2 : f32 = 3.14;
// CHECK-NEXT: const v0 : f16 = 2.718;
// CHECK-NEXT: const v2 : f32 = 3.5;
// CHECK-NEXT: const v0 : f16 = 2.5;
// CHECK-NEXT: const u32cst : u32 = 44;
// CHECK-NEXT: {{ *}}
// CHECK-NEXT: for(@range(i16, 0, 24, 1)) |idx| {
Expand Down Expand Up @@ -873,7 +873,7 @@ csl.func @builtins() {
// CHECK-NEXT: // -----
// CHECK-NEXT: // FILE: layout.csl
// CHECK-NEXT: param param_1 : i32;
// CHECK-NEXT: param param_2 : f16 = 3.14;
// CHECK-NEXT: param param_2 : f16 = 3.5;
// CHECK-NEXT: layout {
// CHECK-NEXT: @set_rectangle(param_1, 6);
// CHECK-NEXT: @set_tile_code(0, 0, "file.csl", );
Expand Down
8 changes: 4 additions & 4 deletions tests/filecheck/dialects/arith/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ func.func @test_const_const() {

// CHECK-LABEL: @test_const_const
// CHECK-NEXT: %0 = arith.constant 6.139400e+00 : f32
// CHECK-NEXT: %1 = arith.constant -0.14360000000000017 : f32
// CHECK-NEXT: %2 = arith.constant 9.41790285 : f32
// CHECK-NEXT: %3 = arith.constant 0.9542893522202769 : f32
// CHECK-NEXT: %1 = arith.constant -0.143599987 : f32
// CHECK-NEXT: %2 = arith.constant 9.41790295 : f32
// CHECK-NEXT: %3 = arith.constant 0.954289377 : f32
// CHECK-NEXT: "test.op"(%0, %1, %2, %3) : (f32, f32, f32, f32) -> ()
}

Expand All @@ -57,7 +57,7 @@ func.func @test_const_var_const() {
// CHECK-NEXT: %b = arith.constant 3.141500e+00 : f32
// CHECK-NEXT: %2 = arith.mulf %0, %a : f32
// CHECK-NEXT: %3 = arith.mulf %2, %b : f32
// CHECK-NEXT: %4 = arith.constant 21.29352225 : f32
// CHECK-NEXT: %4 = arith.constant 21.2935219 : f32
// CHECK-NEXT: %5 = arith.mulf %4, %0 fastmath<fast> : f32
// CHECK-NEXT: "test.op"(%3, %5) : (f32, f32) -> ()
}
Expand Down
4 changes: 2 additions & 2 deletions tests/filecheck/dialects/csl/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ csl.func @builtins() {
// CHECK-NEXT: }) {"sym_name" = "program"} : () -> ()
// CHECK-NEXT: "csl.module"() <{"kind" = #csl<module_kind layout>}> ({
// CHECK-NEXT: %comp_const = "csl.param"() <{"param_name" = "comp_constant"}> : () -> i32
// CHECK-NEXT: %init = arith.constant 3.140000e+00 : f16
// CHECK-NEXT: %init = arith.constant 3.140620e+00 : f16
// CHECK-NEXT: %p2 = "csl.param"(%init) <{"param_name" = "param_2"}> : (f16) -> f16
// CHECK-NEXT: csl.layout {
// CHECK-NEXT: %x_dim, %y_dim = "test.op"() : () -> (i32, i32)
Expand Down Expand Up @@ -818,7 +818,7 @@ csl.func @builtins() {
// CHECK-GENERIC-NEXT: }) {"sym_name" = "program"} : () -> ()
// CHECK-GENERIC-NEXT: "csl.module"() <{"kind" = #csl<module_kind layout>}> ({
// CHECK-GENERIC-NEXT: %comp_const = "csl.param"() <{"param_name" = "comp_constant"}> : () -> i32
// CHECK-GENERIC-NEXT: %init = "arith.constant"() <{"value" = 3.140000e+00 : f16}> : () -> f16
// CHECK-GENERIC-NEXT: %init = "arith.constant"() <{"value" = 3.140620e+00 : f16}> : () -> f16
// CHECK-GENERIC-NEXT: %p2 = "csl.param"(%init) <{"param_name" = "param_2"}> : (f16) -> f16
// CHECK-GENERIC-NEXT: "csl.layout"() ({
// CHECK-GENERIC-NEXT: %x_dim, %y_dim = "test.op"() : () -> (i32, i32)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// RUN: mlir-opt %s | filecheck %s
// RUN: xdsl-opt %s | filecheck %s

// CHECK: module {

// f16 is always representable with scientific format
arith.constant 1.1 : f16
// CHECK-NEXT: {{%.*}} = arith.constant 1.099610e+00 : f16

// f32 is represented by scientific format if it is precise enough
arith.constant 3.1415 : f32
// CHECK-NEXT: {{%.*}} = arith.constant 3.141500e+00 : f32

// else, f32 is printed with 9 significant digits
arith.constant 3.141592 : f32
// CHECK-NEXT: {{%.*}} = arith.constant 3.14159203 : f32

// if the decimal separator is within these 9 siginficant non-zero digits, fine
arith.constant 2.997925e+05 : f32
// CHECK-NEXT: {{%.*}} = arith.constant 299792.5 : f32

// else, print hex format
arith.constant 2.997925e+06 : f32
// CHECK-NEXT: {{%.*}} = arith.constant 0x4A36FA94 : f32

// f64 is represented by scientific format if it is precise enough
arith.constant 3.1415 : f64
// CHECK-NEXT: {{%.*}} = arith.constant 3.141500e+00 : f64

// else, f64 is printed with 17 significant digits
arith.constant 3.141592 : f64
// CHECK-NEXT: {{%.*}} = arith.constant 3.1415920000000002 : f64
4 changes: 2 additions & 2 deletions tests/filecheck/parser-printer/builtin_attrs.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@

"func.func"() ({}) {function_type = () -> (), value = dense<"0xEEA7CC3DF47612BE2BA4173E8B75E8BDE0B915BDA3191CBE8388E0BDC826DB3DFE78273E6B037E3DEF140D3EF0B5803D4026693CD6B6E1BCE08B4DBDC3A9E63D943B163EE64E46BD808C253EB8F4893D30270CBE36696C3D045E1DBED06A703DA33EBBBD66D646BD36507BBD764D8FBD7010FA3DB6E1B53D9B83C8BDD33FA73D58AD293EB0A6123EAB2627BA40B4CB3C20E9B6BD805AB2BDE047BDBC809A743DE01ADD3D9B77D5BDCEE7043E00B8C1BDCBA80A3DBB03DA3D787C993D163968BC208510BDABFDB1BD8C07213EA34614BEAB06B73A0091413B8013B3BD768F193E7B6515BE7306833D363183BC36BC8B3CA016B7BD3E05D33DE67C28BDCABB0EBEDA2A013EA67DF6BD007EB5BA782A04BEAB69F73D16DD703D3B93A43D1BE45B3DEBAEE8BD8891F1BDF8B18F3D20EC923CE67101BE8382A8BDAB9EE7BA0006CA3AA3F224BE1B56A5BDC06B8A3DC3E6BE3D562310BB964B713C2CC11FBE4BC68F3DAEACD7BDFB093A3D00070F3EC3E4C93D5BCF0D3D1B01E13D9B7D7F3D537CD43D6BEDFBBC4BD9AEBD17BA023E569906BB86599CBD4E28073E1639F5BDF60909BE8B4727BEE4AD153EDF3C05BEB01913BEEB1A59BD03E8D4BD4BD3123D9EA381BD6058F03CD0EFF73D00747FBADBC5AEBD5054273E204DB4BD00CA683B1E28C93D3BCC2A3D9B0E683D4302923D9A3408BEABC89D3A565336BCC0A7F3BD76D1F93D68A3B93D44891C3E1685243E1B3FDBBD5E06A4BD2B4192BD2B19983C50C97B3D40A808BEC0994C3D4B3435BD0B88293D506749BDFC13063E2B7ADF3CF3B013BE"> : tensor<4x4x3x3xf32>, sym_name = "hex_f32_large_attr"} : () -> ()

// CHECK: "value" = dense<[[[[0.09992967545986176, -0.14303189516067505, 0.14808718860149384], [-0.11350544542074203, -0.03655421733856201, -0.15244154632091522], [-0.1096353754401207, 0.10700756311416626, 0.16354748606681824]], [[0.062014978379011154, 0.13777516782283783, 0.06284701824188232], [0.014230310916900635, -0.027553003281354904, -0.050182223320007324], [0.11262848228216171, 0.14671164751052856, -0.048415087163448334]], [[0.1616687774658203, 0.06736129522323608, -0.13686823844909668], [0.05771752446889877, -0.15367895364761353, 0.05869561433792114], [-0.09142806380987167, -0.04854431003332138, -0.06135579198598862]], [[-0.069971963763237, 0.12210166454315186, 0.08880941569805145], [-0.09790726751089096, 0.08166470378637314, 0.16570031642913818], [0.14321398735046387, -0.0006376306409947574, 0.024866223335266113]]], [[[-0.08931183815002441, -0.08708667755126953, -0.02310556173324585], [0.059717655181884766, 0.10796141624450684, -0.10423203557729721], [0.1297905147075653, -0.0945892333984375, 0.03385237976908684]], [[0.10645242780447006, 0.07494443655014038, -0.01417376659810543], [-0.03528320789337158, -0.08690961450338364, 0.1572553515434265], [-0.14480070769786835, 0.0013963779201731086, 0.0029535889625549316]], [[-0.08743953704833984, 0.14996132254600525, -0.14589492976665497], [0.06397714465856552, -0.01601467654109001, 0.017057519406080246], [-0.08939862251281738, 0.10303734242916107, -0.04113473743200302]], [[-0.13938823342323303, 0.1261400282382965, -0.12035684287548065], [-0.0013846755027770996, -0.1290682554244995, 0.12080701440572739], [0.05880459398031235, 0.08035894483327866, 0.053684335201978683]]], [[[-0.11361487954854965, -0.117953360080719, 0.07016366720199585], [0.017934858798980713, -0.1264110505580902, -0.08228018134832382], [-0.0017671188106760383, 0.0015413165092468262, -0.16108183562755585]], [[-0.08073063939809799, 0.06758832931518555, 0.09321358054876328], [-0.002199371811002493, 0.01472749374806881, -0.15601032972335815], [0.07020243257284164, -0.10530982911586761, 0.04541967436671257]], [[0.13967514038085938, 0.09858085960149765, 0.034621577709913254], [0.10986538976430893, 0.06237564608454704, 0.1037527546286583], [-0.03075285814702511, -0.08537539094686508, 0.1276630014181137]], [[-0.002053817268460989, -0.0763426274061203, 0.13198968768119812], [-0.11973778903484344, -0.1338270604610443, -0.16335885226726532], [0.14617115259170532, -0.13011501729488373, -0.14365267753601074]]], [[[-0.0530041866004467, -0.10395815223455429, 0.0358460359275341], [-0.06330035626888275, 0.0293390154838562, 0.12106287479400635], [-0.0009744763374328613, -0.08533831685781479, 0.163407564163208]], [[-0.08803772926330566, 0.003552079200744629, 0.09822104871273041], [0.041698675602674484, 0.05665455386042595, 0.07129337638616562], [-0.13301315903663635, 0.0012037953129038215, -0.011128267273306847]], [[-0.1189723014831543, 0.12198154628276825, 0.09064370393753052], [0.1528673768043518, 0.1606639325618744, -0.1070539578795433], [-0.08009026944637299, -0.07141336053609848, 0.018566688522696495]], [[0.06147128343582153, -0.1334543228149414, 0.04995131492614746], [-0.04423932358622551, 0.04138950631022453, -0.04917079210281372], [0.13093560934066772, 0.027279933914542198, -0.1442296952009201]]]]> : tensor<4x4x3x3xf32>
// CHECK: "value" = dense<[[[[0.0999296755, -0.143031895, 0.148087189], [-0.113505445, -0.0365542173, -0.152441546], [-0.109635375, 0.107007563, 0.163547486]], [[0.0620149784, 0.137775168, 0.0628470182], [0.0142303109, -0.0275530033, -0.0501822233], [0.112628482, 0.146711648, -0.0484150872]], [[0.161668777, 0.0673612952, -0.136868238], [0.0577175245, -0.153678954, 0.0586956143], [-0.0914280638, -0.04854431, -0.061355792]], [[-0.0699719638, 0.122101665, 0.0888094157], [-0.0979072675, 0.0816647038, 0.165700316], [0.143213987, -0.000637630641, 0.0248662233]]], [[[-0.0893118382, -0.0870866776, -0.0231055617], [0.0597176552, 0.107961416, -0.104232036], [0.129790515, -0.0945892334, 0.0338523798]], [[0.106452428, 0.0749444366, -0.0141737666], [-0.0352832079, -0.0869096145, 0.157255352], [-0.144800708, 0.00139637792, 0.00295358896]], [[-0.087439537, 0.149961323, -0.14589493], [0.0639771447, -0.0160146765, 0.0170575194], [-0.0893986225, 0.103037342, -0.0411347374]], [[-0.139388233, 0.126140028, -0.120356843], [-0.0013846755, -0.129068255, 0.120807014], [0.058804594, 0.0803589448, 0.0536843352]]], [[[-0.11361488, -0.11795336, 0.0701636672], [0.0179348588, -0.126411051, -0.0822801813], [-0.00176711881, 0.00154131651, -0.161081836]], [[-0.0807306394, 0.0675883293, 0.0932135805], [-0.00219937181, 0.0147274937, -0.15601033], [0.0702024326, -0.105309829, 0.0454196744]], [[0.13967514, 0.0985808596, 0.0346215777], [0.10986539, 0.0623756461, 0.103752755], [-0.0307528581, -0.0853753909, 1.276630e-01]], [[-0.00205381727, -0.0763426274, 0.131989688], [-0.119737789, -0.13382706, -0.163358852], [0.146171153, -0.130115017, -0.143652678]]], [[[-0.0530041866, -0.103958152, 0.0358460359], [-0.0633003563, 0.0293390155, 0.121062875], [-0.000974476337, -0.0853383169, 0.163407564]], [[-0.0880377293, 0.0035520792, 0.0982210487], [0.0416986756, 0.0566545539, 0.0712933764], [-0.133013159, 0.00120379531, -0.0111282673]], [[-0.118972301, 0.121981546, 0.0906437039], [0.152867377, 0.160663933, -0.107053958], [-0.0800902694, -0.0714133605, 0.0185666885]], [[0.0614712834, -0.133454323, 0.0499513149], [-0.0442393236, 0.0413895063, -0.0491707921], [0.130935609, 0.0272799339, -0.144229695]]]]> : tensor<4x4x3x3xf32>} : () -> ()

"func.func"() ({}) {function_type = () -> (), value = "foo", sym_name = "string_attr"} : () -> ()

Expand Down Expand Up @@ -141,7 +141,7 @@
value1 = dense<"0xCAFEBABE"> : tensor<2xf32>,
value2 = dense<"0xCAFEBABEB00BAABE"> : tensor<1xf64>,
sym_name = "dense_tensor_attr_hex_float"} : () -> ()
// CHECK: "value1" = dense<-0.3652251362800598> : tensor<2xf32>, "value2" = dense<-7.762213249592702e-07> : tensor<1xf64>
// CHECK: "value1" = dense<-0.365225136> : tensor<2xf32>, "value2" = dense<-7.7622132495927025e-07> : tensor<1xf64>

"func.func"() ({}) {function_type = () -> (),
value1 = dense<[0]> : vector<1xi32>,
Expand Down
2 changes: 1 addition & 1 deletion tests/filecheck/parser-printer/float_parsing.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@

// this should print in full precision
"test.op"() {"value" = 3.141592653589793 : f64} : () -> ()
// CHECK-NEXT: "test.op"() {"value" = 3.141592653589793 : f64} : () -> ()
// CHECK-NEXT: "test.op"() {"value" = 3.1415926535897931 : f64} : () -> ()
}) : () -> ()
Loading

0 comments on commit 7e38685

Please sign in to comment.