Skip to content

Commit

Permalink
fix(compiler/tests): Fixing the generation of dot/matmul signed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
BourgerieQuentin committed Aug 1, 2023
1 parent bd45401 commit 8e8b2dd
Showing 1 changed file with 24 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import argparse
import numpy as np

PRECISIONS_TO_BENCH = [(6, 2)]#, (16, 7)]
PRECISIONS_TO_BENCH = [
# output, input
(6, 2),
#, (16, 7)
]
SHAPES = [((2, 3, 4), (2, 4, 2)), ((3, 4), (4, 2)), ((3,), (3,)), ((3,), (3, 2)), ((3,), (4, 3, 2)), ((3,4), (4,)), ((2,3,4), (4,)), ((2, 1, 3, 4), (5, 4, 2))]
P_ERROR = 1.0 / 1e6

Expand All @@ -26,11 +30,15 @@ def generate(op):
for p, p_inputs in PRECISIONS_TO_BENCH:
for shapes in SHAPES:
for signed in [False, True]:
min_value = 0
max_value = (2 ** p_inputs) - 1

inp_0 = np.random.randint(min_value, max_value+1, size=shapes[0])
inp_1 = np.random.randint(min_value, max_value+1, size=shapes[1])
if signed:
min_value = - 2 ** (p_inputs - 1)
max_value = 2 ** (p_inputs - 1) - 1
else:
min_value = 0
max_value = 2 ** p_inputs - 1

inp_0 = np.random.randint(min_value, max_value/2, size=shapes[0])
inp_1 = np.random.randint(min_value, max_value/2, size=shapes[1])

expected_result = inp_0 @ inp_1

Expand Down Expand Up @@ -75,29 +83,33 @@ def generate(op):
shape_1_str_yaml = ",".join(map(str, shapes[1]))
expected_shape_yaml = ",".join(map(str, out_shape))

program += (f"p-error: {P_ERROR}\n"
if signed:
signed_line = " signed: True\n"
else:
signed_line = ""
program += (
f"p-error: {P_ERROR}\n"
"tests:\n"
" - inputs: \n"
f" - tensor: {inp_0_str}\n"
f" shape: [{shape_0_str_yaml}]\n"
f"{signed_line}"
f" - tensor: {inp_1_str}\n"
f" shape: [{shape_1_str_yaml}]\n"
f"{signed_line}"
f" outputs:\n"
)

if op_outputs_scalar:
program += (
f" - scalar: {expected_str}\n"
f"{signed_line}"
)
else:
program += (
f" - tensor: {expected_str}\n"
f" shape: [{expected_shape_yaml}]\n"
)

if signed:
program += (
f" signed: True\n"
f"{signed_line}"
)

program += f"---"
Expand Down

0 comments on commit 8e8b2dd

Please sign in to comment.