diff --git a/tests/interpreters/test_riscv_interpreter.py b/tests/interpreters/test_riscv_interpreter.py index 5fda87f959..7cde5183b9 100644 --- a/tests/interpreters/test_riscv_interpreter.py +++ b/tests/interpreters/test_riscv_interpreter.py @@ -157,6 +157,16 @@ def my_custom_instruction( # D extension arithmetic + assert interpreter.run_op( + riscv.FMAddDOp( + TestSSAValue(fregister), + TestSSAValue(fregister), + TestSSAValue(fregister), + rd=riscv.FloatRegisterType.unallocated(), + ), + (3.0, 4.0, 5.0), + ) == (17.0,) + assert interpreter.run_op( riscv.FAddDOp( TestSSAValue(fregister), diff --git a/xdsl/interpreters/riscv.py b/xdsl/interpreters/riscv.py index ec8347f833..2d98872437 100644 --- a/xdsl/interpreters/riscv.py +++ b/xdsl/interpreters/riscv.py @@ -428,6 +428,17 @@ def run_fmv( # region D extension + @impl(riscv.FMAddDOp) + def run_fmadd_d( + self, + interpreter: Interpreter, + op: riscv.FMAddDOp, + args: tuple[Any, ...], + ): + args = RiscvFunctions.get_reg_values(interpreter, op.operands, args) + results = (args[0] * args[1] + args[2],) + return RiscvFunctions.set_reg_values(interpreter, op.results, results) + @impl(riscv.FAddDOp) def run_fadd_d( self,