diff --git a/scripts/generate_testscases.py b/scripts/generate_testscases.py index 25f83ead..0640362a 100644 --- a/scripts/generate_testscases.py +++ b/scripts/generate_testscases.py @@ -32,7 +32,7 @@ def generate_example(model_input): dtype = model_input.dtype - if dtype == np.float32: + if dtype == "float32": return np.random.uniform(low=0.0, high=1.0, size=shape).astype(np.float32) else: