diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index e558966cc2..7eb9f1936d 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -166,6 +166,7 @@ def get_avg_pool2d_inputs(): ), ] test_suite = VkTestSuite([tuple(tc) for tc in test_cases]) + test_suite.dtypes = ["at::kFloat"] return test_suite