|
7 | 7 | import arrayfire_wrapper.lib as wrapper
|
8 | 8 |
|
9 | 9 | # import arrayfire_wrapper.lib.mathematical_functions as ops
|
10 |
| -from arrayfire_wrapper.lib.create_and_modify_array.helper_functions import array_to_string |
11 |
| - |
12 |
| -dtype_map = { |
13 |
| - "int16": dtype.s16, |
14 |
| - "int32": dtype.s32, |
15 |
| - "int64": dtype.s64, |
16 |
| - "uint8": dtype.u8, |
17 |
| - "uint16": dtype.u16, |
18 |
| - "uint32": dtype.u32, |
19 |
| - "uint64": dtype.u64, |
20 |
| - "float16": dtype.f16, |
21 |
| - "float32": dtype.f32, |
22 |
| - # 'float64': dtype.f64, |
23 |
| - # 'complex64': dtype.c64, |
24 |
| - "complex32": dtype.c32, |
25 |
| - "bool": dtype.b8, |
26 |
| - "s16": dtype.s16, |
27 |
| - "s32": dtype.s32, |
28 |
| - "s64": dtype.s64, |
29 |
| - "u8": dtype.u8, |
30 |
| - "u16": dtype.u16, |
31 |
| - "u32": dtype.u32, |
32 |
| - "u64": dtype.u64, |
33 |
| - "f16": dtype.f16, |
34 |
| - "f32": dtype.f32, |
35 |
| - # 'f64': dtype.f64, |
36 |
| - "c32": dtype.c32, |
37 |
| - # 'c64': dtype.c64, |
38 |
| - "b8": dtype.b8, |
39 |
| -} |
| 10 | + |
| 11 | +from . import utility_functions as util |
40 | 12 |
|
41 | 13 |
|
42 | 14 | @pytest.mark.parametrize(
|
@@ -87,9 +59,10 @@ def test_multiply_negative_shapes() -> None:
|
87 | 59 | ), f"Failed for shapes {lhs_shape} and {rhs_shape}"
|
88 | 60 |
|
89 | 61 |
|
90 |
| -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 62 | +@pytest.mark.parametrize("dtype_name", util.get_all_types()) |
91 | 63 | def test_multiply_supported_dtypes(dtype_name: dtype.Dtype) -> None:
|
92 | 64 | """Test multiplication operation across all supported data types."""
|
| 65 | + util.check_type_supported(dtype_name) |
93 | 66 | shape = (5, 5)
|
94 | 67 | lhs = wrapper.randu(shape, dtype_name)
|
95 | 68 | rhs = wrapper.randu(shape, dtype_name)
|
@@ -201,9 +174,10 @@ def test_divide_negative_shapes() -> None:
|
201 | 174 | ), f"Failed for shapes {lhs_shape} and {rhs_shape}"
|
202 | 175 |
|
203 | 176 |
|
204 |
| -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 177 | +@pytest.mark.parametrize("dtype_name", util.get_all_types()) |
205 | 178 | def test_divide_supported_dtypes(dtype_name: dtype.Dtype) -> None:
|
206 | 179 | """Test division operation across all supported data types."""
|
| 180 | + util.check_type_supported(dtype_name) |
207 | 181 | shape = (5, 5)
|
208 | 182 | lhs = wrapper.randu(shape, dtype_name)
|
209 | 183 | rhs = wrapper.randu(shape, dtype_name)
|
|
0 commit comments