diff --git a/heterocl/ast/ast.py b/heterocl/ast/ast.py index aea6b896..704c5cb4 100644 --- a/heterocl/ast/ast.py +++ b/heterocl/ast/ast.py @@ -170,10 +170,41 @@ def simplify(expr): index = struct.index e = struct.tensor.fcompute(*index)[expr.field] return sp.simplify(simplify(e)) - if isinstance(expr, SelectOp): # pylint: disable=no-else-return + if isinstance(expr, SelectOp): if simplify(expr.cond): return sp.simplify(simplify(expr.true_value)) return sp.simplify(simplify(expr.false_value)) + if isinstance(expr, MathExpOp): + expr = unwrap_sp(simplify(expr.expr)) + return sp.exp(expr) + if isinstance(expr, MathPowOp): + lhs = unwrap_sp(simplify(expr.lhs)) + rhs = unwrap_sp(simplify(expr.rhs)) + return sp.Pow(lhs, rhs) + if isinstance(expr, MathLogOp): + expr = unwrap_sp(simplify(expr.expr)) + return sp.log(expr) + if isinstance(expr, MathLog2Op): + expr = unwrap_sp(simplify(expr.expr)) + return sp.log(expr, 2) + if isinstance(expr, MathLog10Op): + expr = unwrap_sp(simplify(expr.expr)) + return sp.log(expr, 10) + if isinstance(expr, MathSqrtOp): + expr = unwrap_sp(simplify(expr.expr)) + return sp.sqrt(expr) + if isinstance(expr, MathSinOp): + expr = unwrap_sp(simplify(expr.expr)) + return sp.sin(expr) + if isinstance(expr, MathCosOp): + expr = unwrap_sp(simplify(expr.expr)) + return sp.cos(expr) + # if isinstance(expr, MathTanOp): + # expr = unwrap_sp(simplify(expr.expr)) + # return sp.tan(expr) + if isinstance(expr, MathTanhOp): # pylint: disable=no-else-return + expr = unwrap_sp(simplify(expr.expr)) + return sp.tanh(expr) else: raise HCLError(f"Unsupported expression type: {type(expr)}") diff --git a/tests/test_simplify.py b/tests/test_simplify.py index 4ab8c31c..c4baa3bb 100644 --- a/tests/test_simplify.py +++ b/tests/test_simplify.py @@ -364,3 +364,267 @@ def kernel(A): np_B = hcl.asarray([0, 0]) f(np_A, np_B) assert np_B.asnumpy().tolist() == [0b10101, 0b01110] + + +def test_math_exp(): + def kernel(A): + lower_idx = hcl.exp(0) + lower_idx = hcl.cast(hcl.Index(), lower_idx) + upper_idx = hcl.exp(1.5) + upper_idx = hcl.cast(hcl.Index(), upper_idx) # e^1.5 ~4.482 -> 4 + + B = hcl.compute(A.shape, lambda x: A[x][lower_idx:upper_idx]) + return B + + A = hcl.placeholder((2,), "A") + s = hcl.create_schedule([A], kernel) + f = hcl.build(s) + np_A = hcl.asarray([0b10101100, 0b01100101]) + np_B = hcl.asarray([0, 0]) + f(np_A, np_B) + assert np_B.asnumpy().tolist() == [0b110, 0b010] + + +def test_math_pow(): + def kernel(A): + a = hcl.scalar(0) + b = hcl.scalar(2) + + lower_idx = hcl.power(b.v, a.v) + lower_idx = hcl.cast(hcl.Index(), lower_idx) + upper_idx = hcl.power(b.v, b.v) + upper_idx = hcl.cast(hcl.Index(), upper_idx) + + B = hcl.compute(A.shape, lambda x: A[x][lower_idx:upper_idx]) + return B + + A = hcl.placeholder((2,), "A") + s = hcl.create_schedule([A], kernel) + f = hcl.build(s) + np_A = hcl.asarray([0b10101100, 0b01100101]) + np_B = hcl.asarray([0, 0]) + f(np_A, np_B) + assert np_B.asnumpy().tolist() == [0b110, 0b010] + + +def test_math_log_op(): + def kernel(A): + a = hcl.scalar(1) + + lower_idx = hcl.log(a.v) + lower_idx = hcl.cast(hcl.Index(), lower_idx) + upper_idx = hcl.log(12) + upper_idx = hcl.cast(hcl.Index(), upper_idx) + 3 + + B = hcl.compute(A.shape, lambda x: A[x][lower_idx:upper_idx]) + return B + + A = hcl.placeholder((2,), "A") + s = hcl.create_schedule([A], kernel) + f = hcl.build(s) + np_A = hcl.asarray([0b01001111, 0b11101110]) + np_B = hcl.asarray([0, 0]) + f(np_A, np_B) + assert np_B.asnumpy().tolist() == [0b01111, 0b01110] + + +def test_math_log2_op(): + def kernel(A): + a = hcl.scalar(16) + + lower_idx = hcl.log2(2) + lower_idx = hcl.cast(hcl.Index(), lower_idx) + upper_idx = hcl.log2(a.v) + upper_idx = hcl.cast(hcl.Index(), upper_idx) + + B = hcl.compute(A.shape, lambda x: A[x][lower_idx:upper_idx]) + return B + + A = hcl.placeholder((2,), "A") + s = hcl.create_schedule([A], kernel) + f = hcl.build(s) + np_A = hcl.asarray([0b00001010, 0b11101110]) + np_B = hcl.asarray([0, 0]) + f(np_A, np_B) + assert np_B.asnumpy().tolist() == [0b101, 0b111] + + +def test_math_log10_op(): + def kernel(A): + a = hcl.scalar(10) + b = hcl.scalar(10000) + + lower_idx = hcl.log10(a.v) + lower_idx = hcl.cast(hcl.Index(), lower_idx) + upper_idx = hcl.log10(b.v) + upper_idx = hcl.cast(hcl.Index(), upper_idx) + + B = hcl.compute(A.shape, lambda x: A[x][lower_idx:upper_idx]) + return B + + A = hcl.placeholder((2,), "A") + s = hcl.create_schedule([A], kernel) + f = hcl.build(s) + np_A = hcl.asarray([0b01111100, 0b01001010]) + np_B = hcl.asarray([0, 0]) + f(np_A, np_B) + assert np_B.asnumpy().tolist() == [0b110, 0b101] + + +def test_math_sqrt1(): + def kernel(A): + a = hcl.scalar(25) + + upper_idx = hcl.sqrt(a.v) + upper_idx = hcl.cast(hcl.Index(), upper_idx) + + B = hcl.compute(A.shape, lambda x: A[x][0:upper_idx]) + return B + + A = hcl.placeholder((2,), "A") + s = hcl.create_schedule([A], kernel) + f = hcl.build(s) + np_A = hcl.asarray([0b10011101, 0b01001010]) + np_B = hcl.asarray([0, 0]) + f(np_A, np_B) + assert np_B.asnumpy().tolist() == [0b11101, 0b01010] + + +def test_math_sqrt2(): + def kernel(A): + upper_idx = hcl.sqrt(30) + upper_idx = hcl.cast(hcl.Index(), upper_idx) # sqrt(30) ~ 5.477 -> 5 + + B = hcl.compute(A.shape, lambda x: A[x][0:upper_idx]) + return B + + A = hcl.placeholder((2,), "A") + s = hcl.create_schedule([A], kernel) + f = hcl.build(s) + np_A = hcl.asarray([0b10011101, 0b01001010]) + np_B = hcl.asarray([0, 0]) + f(np_A, np_B) + assert np_B.asnumpy().tolist() == [0b11101, 0b01010] + + +def test_math_sin1(): + def kernel(A): + lower_idx = hcl.sin(0) + lower_idx = hcl.cast(hcl.Index(), lower_idx) + + B = hcl.compute(A.shape, lambda x: A[x][lower_idx:3]) + return B + + A = hcl.placeholder((2,), "A") + s = hcl.create_schedule([A], kernel) + f = hcl.build(s) + np_A = hcl.asarray([0b10011101, 0b10110001]) + np_B = hcl.asarray([0, 0]) + f(np_A, np_B) + assert np_B.asnumpy().tolist() == [0b101, 0b001] + + +def test_math_sin2(): + def kernel(A): + lower_idx = hcl.sin(np.pi / 2) + lower_idx = hcl.cast(hcl.Index(), lower_idx) + + B = hcl.compute(A.shape, lambda x: A[x][lower_idx:5]) + return B + + A = hcl.placeholder((2,), "A") + s = hcl.create_schedule([A], kernel) + f = hcl.build(s) + np_A = hcl.asarray([0b10011101, 0b10110001]) + np_B = hcl.asarray([0, 0]) + f(np_A, np_B) + assert np_B.asnumpy().tolist() == [0b1110, 0b1000] + + +def test_math_cos1(): + def kernel(A): + lower_idx = hcl.cos(0) + lower_idx = hcl.cast(hcl.Index(), lower_idx) + + B = hcl.compute(A.shape, lambda x: A[x][lower_idx:4]) + return B + + A = hcl.placeholder((2,), "A") + s = hcl.create_schedule([A], kernel) + f = hcl.build(s) + np_A = hcl.asarray([0b10011101, 0b10110001]) + np_B = hcl.asarray([0, 0]) + f(np_A, np_B) + assert np_B.asnumpy().tolist() == [0b110, 0b000] + + +def test_math_cos2(): + def kernel(A): + lower_idx = hcl.cos(np.pi / 2) + lower_idx = hcl.cast(hcl.Index(), lower_idx) + + B = hcl.compute(A.shape, lambda x: A[x][lower_idx:4]) + return B + + A = hcl.placeholder((2,), "A") + s = hcl.create_schedule([A], kernel) + f = hcl.build(s) + np_A = hcl.asarray([0b10011101, 0b10110001]) + np_B = hcl.asarray([0, 0]) + f(np_A, np_B) + assert np_B.asnumpy().tolist() == [0b1101, 0b0001] + + +# hcl.tan() not supported yet + +# def test_math_tan(): +# def kernel(A): +# lower_idx = hcl.tan(0) +# lower_idx = hcl.cast(hcl.Index(), lower_idx) + +# B = hcl.compute(A.shape, lambda x: A[x][lower_idx:4]) +# return B + +# A = hcl.placeholder((2,), "A") +# s = hcl.create_schedule([A], kernel) +# f = hcl.build(s) +# np_A = hcl.asarray([0b10011101, 0b00101100]) +# np_B = hcl.asarray([0, 0]) +# f(np_A, np_B) +# assert np_B.asnumpy().tolist() == [0b1101, 0b1100] + + +def test_math_tanh1(): + def kernel(A): + lower_idx = hcl.tanh(0) + lower_idx = hcl.cast(hcl.Index(), lower_idx) + + B = hcl.compute(A.shape, lambda x: A[x][lower_idx:4]) + return B + + A = hcl.placeholder((2,), "A") + s = hcl.create_schedule([A], kernel) + f = hcl.build(s) + np_A = hcl.asarray([0b10101010, 0b11001011]) + np_B = hcl.asarray([0, 0]) + f(np_A, np_B) + assert np_B.asnumpy().tolist() == [0b1010, 0b1011] + + +def test_math_tanh2(): + def kernel(A): + a = hcl.scalar(8) + + lower_idx = hcl.tanh(a.v) + lower_idx = hcl.cast(hcl.Index(), lower_idx) + + B = hcl.compute(A.shape, lambda x: A[x][lower_idx:4]) + return B + + A = hcl.placeholder((2,), "A") + s = hcl.create_schedule([A], kernel) + f = hcl.build(s) + np_A = hcl.asarray([0b10101011, 0b11001011]) + np_B = hcl.asarray([0, 0]) + f(np_A, np_B) + assert np_B.asnumpy().tolist() == [0b101, 0b101]