Skip to content

Commit 061c271

Browse files
author
AzeezIsh
committed
Removed dtype_map
1 parent 1726f3e commit 061c271

File tree

1 file changed

+6
-32
lines changed

1 file changed

+6
-32
lines changed

Diff for: tests/test_muldiv.py

+6-32
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,8 @@
77
import arrayfire_wrapper.lib as wrapper
88

99
# import arrayfire_wrapper.lib.mathematical_functions as ops
10-
from arrayfire_wrapper.lib.create_and_modify_array.helper_functions import array_to_string
11-
12-
dtype_map = {
13-
"int16": dtype.s16,
14-
"int32": dtype.s32,
15-
"int64": dtype.s64,
16-
"uint8": dtype.u8,
17-
"uint16": dtype.u16,
18-
"uint32": dtype.u32,
19-
"uint64": dtype.u64,
20-
"float16": dtype.f16,
21-
"float32": dtype.f32,
22-
# 'float64': dtype.f64,
23-
# 'complex64': dtype.c64,
24-
"complex32": dtype.c32,
25-
"bool": dtype.b8,
26-
"s16": dtype.s16,
27-
"s32": dtype.s32,
28-
"s64": dtype.s64,
29-
"u8": dtype.u8,
30-
"u16": dtype.u16,
31-
"u32": dtype.u32,
32-
"u64": dtype.u64,
33-
"f16": dtype.f16,
34-
"f32": dtype.f32,
35-
# 'f64': dtype.f64,
36-
"c32": dtype.c32,
37-
# 'c64': dtype.c64,
38-
"b8": dtype.b8,
39-
}
10+
11+
from . import utility_functions as util
4012

4113

4214
@pytest.mark.parametrize(
@@ -87,9 +59,10 @@ def test_multiply_negative_shapes() -> None:
8759
), f"Failed for shapes {lhs_shape} and {rhs_shape}"
8860

8961

90-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
62+
@pytest.mark.parametrize("dtype_name", util.get_all_types())
9163
def test_multiply_supported_dtypes(dtype_name: dtype.Dtype) -> None:
9264
"""Test multiplication operation across all supported data types."""
65+
util.check_type_supported(dtype_name)
9366
shape = (5, 5)
9467
lhs = wrapper.randu(shape, dtype_name)
9568
rhs = wrapper.randu(shape, dtype_name)
@@ -201,9 +174,10 @@ def test_divide_negative_shapes() -> None:
201174
), f"Failed for shapes {lhs_shape} and {rhs_shape}"
202175

203176

204-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
177+
@pytest.mark.parametrize("dtype_name", util.get_all_types())
205178
def test_divide_supported_dtypes(dtype_name: dtype.Dtype) -> None:
206179
"""Test division operation across all supported data types."""
180+
util.check_type_supported(dtype_name)
207181
shape = (5, 5)
208182
lhs = wrapper.randu(shape, dtype_name)
209183
rhs = wrapper.randu(shape, dtype_name)

0 commit comments

Comments
 (0)