diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index d2dd579f..b35527a1 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -24,7 +24,7 @@ repos:
     - --target-version=py312
 
 - repo: https://github.com/astral-sh/ruff-pre-commit
-  rev: v0.4.3
+  rev: v0.5.2
   hooks:
   - id: ruff
@@ -35,7 +35,7 @@ repos:
     language_version: python3
 
 - repo: https://github.com/asottile/pyupgrade
-  rev: v3.15.2
+  rev: v3.16.0
   hooks:
   - id: pyupgrade
     args:
@@ -52,16 +52,17 @@ repos:
   - id: yesqa
 
 - repo: https://github.com/adamchainz/blacken-docs
-  rev: 1.16.0
+  rev: 1.18.0
   hooks:
   - id: blacken-docs
     additional_dependencies:
     - black
 
 - repo: https://github.com/pre-commit/mirrors-mypy
-  rev: v1.10.0
+  rev: v1.10.1
   hooks:
   - id: mypy
+    files: "src/"
     args: [--ignore-missing-imports]
     additional_dependencies:
     - dask
diff --git a/pyproject.toml b/pyproject.toml
index 2409cd51..55cb6c97 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -111,7 +111,8 @@ src_paths = ["src", "tests"]
 
 [tool.mypy]
 python_version = "3.9"
-files = ["src", "tests"]
+files = ["src"]
+exclude = ["tests/"]
 strict = false
 warn_unused_configs = true
 show_error_codes = true
diff --git a/src/dask_awkward/__init__.py b/src/dask_awkward/__init__.py
index bc445813..5d1c00f8 100644
--- a/src/dask_awkward/__init__.py
+++ b/src/dask_awkward/__init__.py
@@ -94,6 +94,7 @@
     with_field,
     with_name,
     with_parameter,
+    without_field,
     without_parameters,
     zeros_like,
     zip,
diff --git a/src/dask_awkward/lib/__init__.py b/src/dask_awkward/lib/__init__.py
index a66177d2..879568cc 100644
--- a/src/dask_awkward/lib/__init__.py
+++ b/src/dask_awkward/lib/__init__.py
@@ -83,6 +83,7 @@
     with_field,
     with_name,
     with_parameter,
+    without_field,
     without_parameters,
     zeros_like,
     zip,
diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py
index 973bf622..40898b11 100644
--- a/src/dask_awkward/lib/core.py
+++ b/src/dask_awkward/lib/core.py
@@ -27,7 +27,6 @@
     TypeTracerArray,
     create_unknown_scalar,
     is_unknown_scalar,
-    touch_data,
 )
 from dask.base import (
     DaskMethodsMixin,
@@ -48,7 +47,6 @@
 
 from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardMaterializedLayer
 from dask_awkward.lib.optimize import all_optimizations
-from dask_awkward.lib.utils import commit_to_reports
 from dask_awkward.utils import (
     DaskAwkwardNotImplemented,
     IncompatiblePartitions,
@@ -400,10 +398,6 @@ def name(self) -> str:
     def key(self) -> Key:
         return (self._name, 0)
 
-    @property
-    def report(self):
-        return getattr(self._meta, "_report", set())
-
     def _check_meta(self, m):
         if isinstance(m, MaybeNone):
             return ak.Array(m.content)
@@ -524,7 +518,6 @@ def f(self, other):
             meta = op(self._meta, other._meta)
         else:
             meta = op(self._meta, other)
-        commit_to_reports(name, self.report)
         return new_scalar_object(graph, name, meta=meta)
 
     return f
@@ -720,9 +713,7 @@ def _check_meta(self, m: Any | None) -> Any | None:
     def __getitem__(self, where):
         token = tokenize(self, where)
         new_name = f"{where}-{token}"
-        report = self.report
         new_meta = self._meta[where]
-        commit_to_reports(new_name, report)
 
         # first check for array type return
         if isinstance(new_meta, ak.Array):
@@ -732,8 +723,6 @@ def __getitem__(self, where):
                 graphlayer,
                 dependencies=[self],
             )
-            new_meta._report = report
-            hlg.layers[new_name].meta = new_meta
             return new_array_object(hlg, new_name, meta=new_meta, npartitions=1)
 
         # then check for scalar (or record) type
@@ -744,8 +733,6 @@ def __getitem__(self, where):
             dependencies=[self],
         )
         if isinstance(new_meta, ak.Record):
-            new_meta._report = report
-            hlg.layers[new_name].meta = new_meta
             return new_record_object(hlg, new_name, meta=new_meta)
         else:
             return new_scalar_object(hlg, new_name, meta=new_meta)
@@ -819,7 +806,7 @@ def new_record_object(dsk: HighLevelGraph, name: str, *, meta: Any) -> Record:
         raise TypeError(
             f"meta Record must have a typetracer backend, not {ak.backend(meta)}"
         )
-    return out
+    return Record(dsk, name, meta)
 
 
 def _is_numpy_or_cupy_like(arr: Any) -> bool:
@@ -950,10 +937,6 @@ def reset_meta(self) -> None:
         """Assign an empty typetracer array as the collection metadata."""
         self._meta = empty_typetracer()
 
-    @property
-    def report(self):
-        return getattr(self._meta, "_report", set())
-
     def repartition(
         self,
         npartitions: int | None = None,
@@ -989,7 +972,6 @@ def repartition(
         new_graph = HighLevelGraph.from_collections(
             key, new_layer, dependencies=(self,)
         )
-        commit_to_reports(key, self.report)
         return new_array_object(
             new_graph,
             key,
@@ -1175,13 +1157,11 @@ def _partitions(self, index: Any) -> Array:
         name = f"partitions-{token}"
         new_keys = self.keys_array[index].tolist()
         dsk = {(name, i): tuple(key) for i, key in enumerate(new_keys)}
-        layer = AwkwardMaterializedLayer(dsk, previous_layer_names=[self.name])
         graph = HighLevelGraph.from_collections(
             name,
-            layer,
+            AwkwardMaterializedLayer(dsk, previous_layer_names=[self.name]),
             dependencies=(self,),
         )
-        layer.meta = self._meta
 
         # if a single partition was requested we trivially know the new divisions.
         if len(raw) == 1 and isinstance(raw[0], int) and self.known_divisions:
@@ -1193,7 +1173,7 @@ def _partitions(self, index: Any) -> Array:
         # otherwise nullify the known divisions
         else:
             new_divisions = (None,) * (len(new_keys) + 1)  # type: ignore
-        commit_to_reports(name, self.report)
+
         return new_array_object(
             graph, name, meta=self._meta, divisions=tuple(new_divisions)
         )
@@ -1415,7 +1395,6 @@ def _getitem_slice_on_zero(self, where):
             AwkwardMaterializedLayer(dask, previous_layer_names=[self.name]),
             dependencies=[self],
         )
-        commit_to_reports(name, self.report)
         return new_array_object(
             hlg,
             name,
@@ -1526,14 +1505,9 @@ def __getitem__(self, where):
             raise RuntimeError("Lists containing integers are not supported.")
 
         if isinstance(where, tuple):
-            out = self._getitem_tuple(where)
-        else:
-            out = self._getitem_single(where)
-        if self.report:
-            commit_to_reports(out.name, self.report)
-            out._meta._report = self._meta._report
-            out.dask.layers[out.name].meta = out._meta
-        return out
+            return self._getitem_tuple(where)
+
+        return self._getitem_single(where)
 
     def _is_method_heuristic(self, resolved: Any) -> bool:
         return callable(resolved)
@@ -1860,12 +1834,10 @@ def partitionwise_layer(
     """
     pairs: list[Any] = []
     numblocks: dict[str, tuple[int, ...]] = {}
-    reps = set()
    for arg in args:
         if isinstance(arg, Array):
             pairs.extend([arg.name, "i"])
             numblocks[arg.name] = (arg.npartitions,)
-            reps.update(arg.report)
         elif isinstance(arg, BlockwiseDep):
             if len(arg.numblocks) == 1:
                 pairs.extend([arg, "i"])
@@ -1885,8 +1857,6 @@ def partitionwise_layer(
             )
         else:
             pairs.extend([arg, None])
-    commit_to_reports(name, reps)
-
     layer = dask_blockwise(
         func,
         name,
@@ -1970,23 +1940,8 @@ def _map_partitions(
         **kwargs,
     )
 
-    reps = set()
-    try:
-        if meta is None:
-            meta = map_meta(fn, *args, **kwargs)
-        else:
-            # To do any touching??
-            map_meta(fn, *args, **kwargs)
-        meta._report = reps
-        lay.meta = meta
-    except (AssertionError, TypeError, NotImplementedError):
-        [touch_data(_._meta) for _ in dak_arrays]
-
-    for dep in dak_arrays:
-        for rep in dep.report:
-            if rep not in reps:
-                rep.commit(name)
-                reps.add(rep)
+    if meta is None:
+        meta = map_meta(fn, *args, **kwargs)
 
     hlg = HighLevelGraph.from_collections(
         name,
@@ -2009,6 +1964,7 @@ def _map_partitions(
         new_divisions = tuple(map(lambda x: x * output_divisions, in_divisions))
     else:
         new_divisions = in_divisions
+
     if output_divisions is not None:
         return new_array_object(
             hlg,
@@ -2239,6 +2195,10 @@ def non_trivial_reduction(
     if combiner is None:
         combiner = reducer
 
+    # is_positional == True is not implemented
+    # if is_positional:
+    #     assert combiner is reducer
+
     # For `axis=None`, we prepare each array to have the following structure:
     #   [[[ ... [x1 x2 x3 ... xN] ... ]]] (length-1 outer lists)
     # This makes the subsequent reductions an `axis=-1` reduction
@@ -2313,16 +2273,14 @@ def non_trivial_reduction(
     )
 
     graph = HighLevelGraph.from_collections(name_finalize, trl, dependencies=(chunked,))
+
     meta = reducer(
         array._meta,
         axis=axis,
         keepdims=keepdims,
         mask_identity=mask_identity,
     )
-    trl.meta = meta
-    commit_to_reports(name_finalize, array.report)
     if isinstance(meta, ak.highlevel.Array):
-        meta._report = array.report
         return new_array_object(graph, name_finalize, meta=meta, npartitions=1)
     else:
         return new_scalar_object(graph, name_finalize, meta=meta)
diff --git a/src/dask_awkward/lib/structure.py b/src/dask_awkward/lib/structure.py
index f86ac1fd..2e14cd48 100644
--- a/src/dask_awkward/lib/structure.py
+++ b/src/dask_awkward/lib/structure.py
@@ -71,6 +71,7 @@
     "values_astype",
     "where",
     "with_field",
+    "without_field",
     "with_name",
     "with_parameter",
     "without_parameters",
@@ -602,26 +603,20 @@ def mask(
 @borrow_docstring(ak.nan_to_num)
 def nan_to_num(
     array: Array,
-    copy: bool = True,
     nan: float = 0.0,
     posinf: Any | None = None,
     neginf: Any | None = None,
-    highlevel: bool = True,
     behavior: Any | None = None,
-    attrs: Mapping[str, Any] | None = None,
 ) -> Array:
-    # return map_partitions(
-    #     ak.nan_to_num,
-    #     array,
-    #     output_partitions=1,
-    #     copy=copy,
-    #     nan=nan,
-    #     posinf=posinf,
-    #     neginf=neginf,
-    #     highlevel=highlevel,
-    #     behavior=behavior,
-    # )
-    raise DaskAwkwardNotImplemented("TODO")
+    return map_partitions(
+        ak.nan_to_num,
+        array,
+        nan=nan,
+        posinf=posinf,
+        neginf=neginf,
+        highlevel=True,
+        behavior=behavior,
+    )
 
 
 def _numaxis0(*integers):
@@ -1093,6 +1088,46 @@ def with_field(
     )
 
 
+class _WithoutFieldFn:
+    def __init__(
+        self,
+        highlevel: bool,
+        behavior: Mapping | None = None,
+        attrs: Mapping[str, Any] | None = None,
+    ) -> None:
+        self.highlevel = highlevel
+        self.behavior = behavior
+        self.attrs = attrs
+
+    def __call__(self, array: ak.Array, where: str) -> ak.Array:
+        return ak.without_field(
+            array, where=where, behavior=self.behavior, attrs=self.attrs
+        )
+
+
+@borrow_docstring(ak.without_field)
+def without_field(
+    base: Array,
+    where: str,
+    highlevel: bool = True,
+    behavior: Mapping | None = None,
+    attrs: Mapping[str, Any] | None = None,
+) -> Array:
+    if not highlevel:
+        raise ValueError("Only highlevel=True is supported")
+
+    if not isinstance(base, Array):
+        raise ValueError("Base argument in without_field must be a dask_awkward.Array")
+
+    return map_partitions(
+        _WithoutFieldFn(highlevel=highlevel, behavior=behavior, attrs=attrs),
+        base,
+        where,
+        label="without-field",
+        output_divisions=1,
+    )
+
+
 class _WithNameFn:
     def __init__(
         self,
@@ -1360,6 +1395,54 @@ def repartition_layer(arr: Array, key: str, divisions: tuple[int, ...]) -> dict:
     return layer
 
 
+def _subpart(data: ak.Array, parts: int, part: int) -> ak.Array:
+    from dask_awkward.lib.core import is_typetracer
+
+    if is_typetracer(data):
+        return data
+    rows_per = len(data) // parts
+    return data[
+        part * rows_per : None if part == (parts - 1) else (part + 1) * rows_per
+    ]
+
+
+def _subcat(*arrs: tuple[ak.Array, ...]) -> ak.Array:
+    return ak.concatenate(arrs)
+
+
+def simple_repartition_layer(
+    arr: Array, n_to_one: int | None, one_to_n: int | None, key: str
+) -> tuple[dict, tuple[Any, ...]]:
+    layer: dict[tuple[str, int], tuple[Any, ...]] = {}
+    new_divisions: tuple[Any, ...]
+    if n_to_one:
+        for i0, i in enumerate(range(0, arr.npartitions, n_to_one)):
+            layer[(key, i0)] = (_subcat,) + tuple(
+                (arr.name, part)
+                for part in range(i, min(i + n_to_one, arr.npartitions))
+            )
+        new_divisions = arr.divisions[::n_to_one]
+        if arr.npartitions % n_to_one:
+            new_divisions = new_divisions + (arr.divisions[-1],)
+            layer[(key, i0 + 1)] = (_subcat,) + tuple(
+                (arr.name, part0) for part0 in range(len(layer), arr.npartitions)
+            )
+    elif one_to_n:
+        for i in range(arr.npartitions):
+            for part in range(one_to_n):
+                layer[(key, (i * one_to_n + part))] = (
+                    _subpart,
+                    (arr.name, i),
+                    one_to_n,
+                    part,
+                )
+        # TODO: if arr.known_divisions:
+        new_divisions = (None,) * (arr.npartitions * one_to_n + 1)
+    else:
+        raise ValueError
+    return layer, new_divisions
+
+
 @borrow_docstring(ak.enforce_type)
 def enforce_type(
     array: Array,
diff --git a/tests/conftest.py b/tests/conftest.py
index 97785744..f34977f8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -13,10 +13,14 @@
 @pytest.fixture(autouse=True)
 def clear_cache():
+    # as of 2024-07-23, this causes 5 tests to pass; but this fixture should
+    # not be required
+    dak.lib.core.dak_cache.clear()
+
     yield
     dak.lib.core.dak_cache.clear()
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def single_record_file(tmp_path_factory: pytest.TempPathFactory) -> str:
     fname = tmp_path_factory.mktemp("data") / "single_record.json"
     record = {"record": [1, 2, 3]}
@@ -25,7 +29,7 @@ def single_record_file(tmp_path_factory: pytest.TempPathFactory) -> str:
     return str(fname)
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def ndjson_points1(tmp_path_factory: pytest.TempPathFactory) -> str:
     array = daktu.awkward_xy_points()
     fname = tmp_path_factory.mktemp("data") / "points_ndjson1.json"
@@ -35,7 +39,7 @@ def ndjson_points1(tmp_path_factory: pytest.TempPathFactory) -> str:
     return str(fname)
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def ndjson_points1_str(tmp_path_factory: pytest.TempPathFactory) -> str:
     array = daktu.awkward_xy_points_str()
     fname = tmp_path_factory.mktemp("data") / "points_ndjson1.json"
@@ -45,7 +49,7 @@ def ndjson_points1_str(tmp_path_factory: pytest.TempPathFactory) -> str:
     return str(fname)
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def ndjson_points2(tmp_path_factory: pytest.TempPathFactory) -> str:
     array = daktu.awkward_xy_points()
     fname = tmp_path_factory.mktemp("data") / "points_ndjson2.json"
@@ -55,77 +59,77 @@ def ndjson_points2(tmp_path_factory: pytest.TempPathFactory) -> str:
     return str(fname)
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def ndjson_points_file(ndjson_points1: str) -> str:
     return ndjson_points1
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def ndjson_points_file_str(ndjson_points1_str: str) -> str:
     return ndjson_points1_str
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def daa(ndjson_points1: str) -> dak.Array:
     return dak.from_json([ndjson_points1] * 3)
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def pq_points_dir(daa: dak.Array, tmp_path_factory: pytest.TempPathFactory) -> str:
     pqdir = tmp_path_factory.mktemp("pqfiles")
     dak.to_parquet(daa, str(pqdir))
     return str(pqdir)
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def daa_parquet(pq_points_dir: str) -> dak.Array:
     return cast(dak.Array, dak.from_parquet(pq_points_dir))
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def daa_str(ndjson_points1_str: str) -> dak.Array:
     return dak.from_json([ndjson_points1_str] * 3)
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def caa(ndjson_points1: str) -> ak.Array:
     with open(ndjson_points1, "rb") as f:
         a = ak.from_json(f, line_delimited=True)
     return ak.concatenate([a, a, a])
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def caa_str(ndjson_points1_str: str) -> ak.Array:
     with open(ndjson_points1_str, "rb") as f:
         a = ak.from_json(f, line_delimited=True)
     return ak.concatenate([a, a, a])
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def daa_p1(ndjson_points1: str) -> dak.Array:
     return dak.from_json([ndjson_points1] * 3)
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def daa_p2(ndjson_points2: str) -> dak.Array:
     return dak.from_json([ndjson_points2] * 3)
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def caa_p1(ndjson_points1: str) -> ak.Array:
     with open(ndjson_points1) as f:
         lines = [json.loads(line) for line in f]
     return ak.Array(lines * 3)
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def caa_p2(ndjson_points2: str) -> ak.Array:
     with open(ndjson_points2) as f:
         lines = [json.loads(line) for line in f]
     return ak.Array(lines * 3)
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def L1() -> list[list[dict[str, float]]]:
     return [
         [{"x": 1.0, "y": 1.1}, {"x": 2.0, "y": 2.2}, {"x": 3, "y": 3.3}],
@@ -136,7 +140,7 @@ def L1() -> list[list[dict[str, float]]]:
     ]
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def L2() -> list[list[dict[str, float]]]:
     return [
         [{"x": 0.9, "y": 1.0}, {"x": 2.0, "y": 2.2}, {"x": 2.9, "y": 3.0}],
@@ -147,7 +151,7 @@ def L2() -> list[list[dict[str, float]]]:
     ]
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def L3() -> list[list[dict[str, float]]]:
     return [
         [{"x": 1.9, "y": 9.0}, {"x": 2.0, "y": 8.2}, {"x": 9.9, "y": 9.0}],
@@ -158,7 +162,7 @@ def L3() -> list[list[dict[str, float]]]:
     ]
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def L4() -> list[list[dict[str, float]] | None]:
     return [
         [{"x": 1.9, "y": 9.0}, {"x": 2.0, "y": 8.2}, {"x": 9.9, "y": 9.0}],
@@ -169,14 +173,14 @@ def L4() -> list[list[dict[str, float]] | None]:
     ]
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def caa_parquet(caa: ak.Array, tmp_path_factory: pytest.TempPathFactory) -> str:
     fname = tmp_path_factory.mktemp("parquet_data") / "caa.parquet"
     ak.to_parquet(caa, str(fname), extensionarray=False)
     return str(fname)
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def unnamed_root_parquet_file(tmp_path_factory: pytest.TempPathFactory) -> str:
     from dask_awkward.lib.testutils import unnamed_root_ds
 
diff --git a/tests/test_core.py b/tests/test_core.py
index 7aced6bf..ca2e4c71 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -115,7 +115,7 @@ def test_len(ndjson_points_file: str) -> None:
     assert len(daa) == 10
     daa.eager_compute_divisions()
     assert daa.known_divisions
-    assert len(daa) == 10  # type: ignore
+    assert len(daa) == 10
 
 
 def test_meta_exists(daa: Array) -> None:
@@ -157,7 +157,7 @@ def test_partitions_divisions(ndjson_points_file: str) -> None:
     assert not t1.known_divisions
     t2 = daa.partitions[1]
     assert t2.known_divisions
-    assert t2.divisions == (0, divs[2] - divs[1])  # type: ignore
+    assert t2.divisions == (0, divs[2] - divs[1])
 
 
 def test_array_rebuild(ndjson_points_file: str) -> None:
@@ -384,13 +384,13 @@ def test_to_meta(daa: Array) -> None:
 
 def test_record_str(daa: Array) -> None:
     r = daa[0]
-    assert type(r) == dak.Record
+    assert isinstance(r, dak.Record)
     assert str(r) == "dask.awkward"
 
 
 def test_record_to_delayed(daa: Array) -> None:
     r = daa[0]
-    assert type(r) == dak.Record
+    assert isinstance(r, dak.Record)
     d = r.to_delayed()
     x = r.compute().tolist()
     y = d.compute().tolist()
@@ -399,7 +399,7 @@ def test_record_to_delayed(daa: Array) -> None:
 
 def test_record_fields(daa: Array) -> None:
     r = daa[0]
-    assert type(r) == dak.Record
+    assert isinstance(r, dak.Record)
     r._meta = None
     with pytest.raises(TypeError, match="metadata is missing"):
         assert not r.fields
@@ -407,7 +407,7 @@ def test_record_fields(daa: Array) -> None:
 
 def test_record_dir(daa: Array) -> None:
     r = daa["points"][0][0]
-    assert type(r) == dak.Record
+    assert isinstance(r, dak.Record)
     d = dir(r)
     for f in r.fields:
         assert f in d
@@ -418,7 +418,7 @@ def test_record_dir(daa: Array) -> None:
 #     import pickle
 
 #     r = daa[0]
-#     assert type(r) == dak.Record
+#     assert isinstance(r, dak.Record)
 #     assert isinstance(r._meta, ak.Record)
 
 #     dumped = pickle.dumps(r)
@@ -537,7 +537,7 @@ def test_compatible_partitions_after_slice() -> None:
     assert_eq(lazy, ccrt)
 
     # sanity
-    assert dak.compatible_partitions(lazy, lazy + 2)  # type: ignore
+    assert dak.compatible_partitions(lazy, lazy + 2)
     assert dak.compatible_partitions(lazy, dak.num(lazy, axis=1) > 2)
 
     assert not dak.compatible_partitions(lazy[:-2], lazy)
@@ -646,6 +646,14 @@ def test_scalar_divisions(daa: Array) -> None:
     assert s.divisions == (None, None)
 
 
+def test_scalar_binop_inv() -> None:
+    # GH #515
+    x = dak.from_lists([[1]])
+    y = x[0]  # scalar
+    assert (0 - y) == -1
+    assert (y - 0) == 1
+
+
 def test_array_persist(daa: Array) -> None:
     daa2 = daa["points"]["x"].persist()
     assert_eq(daa["points"]["x"], daa2)
@@ -886,7 +894,7 @@ def test_shape_only_ops(fn: Callable, tmp_path_factory: pytest.TempPathFactory)
     p = tmp_path_factory.mktemp("zeros-like-flat")
     ak.to_parquet(a, str(p / "file.parquet"))
     lazy = dak.from_parquet(str(p))
-    result = fn(lazy.b)  # type: ignore
+    result = fn(lazy.b)
     with dask.config.set({"awkward.optimization.enabled": True}):
         result.compute()
 
@@ -898,7 +906,7 @@ def test_assign_behavior() -> None:
     with pytest.raises(
         TypeError, match="'mappingproxy' object does not support item assignment"
     ):
-        dx.behavior["should_fail"] = None  # type: ignore
+        dx.behavior["should_fail"] = None
     assert dx.behavior == behavior
 
 
@@ -909,7 +917,7 @@ def test_assign_attrs() -> None:
     with pytest.raises(
         TypeError, match="'mappingproxy' object does not support item assignment"
     ):
-        dx.attrs["should_fail"] = None  # type: ignore
+        dx.attrs["should_fail"] = None
     assert dx.attrs == attrs
 
 
diff --git a/tests/test_optimize.py b/tests/test_optimize.py
index 52f082b8..d86040cb 100644
--- a/tests/test_optimize.py
+++ b/tests/test_optimize.py
@@ -28,7 +28,7 @@ def test_multiple_computes(ndjson_points_file: str) -> None:
     assert len(things3[1]) < len(things3[0])
 
     things = dask.compute(ds1.points, ds2.points.x, ds2.points.y, ds1.points.y, ds3)
-    assert things[-1].tolist() == ak.Array(lists[0] + lists[1]).tolist()  # type: ignore
+    assert things[-1].tolist() == ak.Array(lists[0] + lists[1]).tolist()
 
 
 def identity(x):
diff --git a/tests/test_structure.py b/tests/test_structure.py
index 180e922e..9f86b747 100644
--- a/tests/test_structure.py
+++ b/tests/test_structure.py
@@ -157,6 +157,15 @@ def test_drop_none(axis: int) -> None:
     assert_eq(d, e)
 
 
+def test_nan_to_num():
+    a = [[1, 2, np.nan], [], [np.nan], [5, 6, 7, np.nan], [1, 2], np.nan]
+    b = [[np.nan, 2, 1], [np.nan], [], np.nan, [7, 6, np.nan, 5], [np.nan, np.nan]]
+    c = dak.from_lists([a, b])
+    d = dak.nan_to_num(c, nan=5)
+    e = ak.nan_to_num(ak.from_iter(a + b), nan=5)
+    assert_eq(d, e)
+
+
 @pytest.mark.parametrize("axis", [0, 1, -1])
 def test_is_none(axis: int) -> None:
     a: list[Any] = [[1, 2, None], None, None, [], [None], [5, 6, 7, None], [1, 2], None]
@@ -203,20 +212,20 @@ def test_pad_none(axis: int, target: int) -> None:
 
 
 def test_with_field(caa: ak.Array, daa: dak.Array) -> None:
-    assert_eq(
-        ak.with_field(caa["points"], caa["points"]["x"], where="xx"),
-        dak.with_field(daa["points"], daa["points"]["x"], where="xx"),
-    )
+    new_caa = ak.with_field(caa["points"], caa["points"]["x"], where="xx")
+    new_daa = dak.with_field(daa["points"], daa["points"]["x"], where="xx")
+    assert_eq(new_caa, new_daa)
+    assert_eq(ak.without_field(new_caa, "xx"), ak.without_field(new_daa, "xx"))
 
-    assert_eq(
-        ak.with_field(caa["points"], 1, where="xx"),
-        dak.with_field(daa["points"], 1, where="xx"),
-    )
+    new_caa = ak.with_field(caa["points"], 1, where="xx")
+    new_daa = dak.with_field(daa["points"], 1, where="xx")
+    assert_eq(new_caa, new_daa)
+    assert_eq(ak.without_field(new_caa, "xx"), ak.without_field(new_daa, "xx"))
 
-    assert_eq(
-        ak.with_field(caa["points"], 1.0, where="xx"),
-        dak.with_field(daa["points"], 1.0, where="xx"),
-    )
+    new_caa = ak.with_field(caa["points"], 1.0, where="xx")
+    new_daa = dak.with_field(daa["points"], 1.0, where="xx")
+    assert_eq(new_caa, new_daa)
+    assert_eq(ak.without_field(new_caa, "xx"), ak.without_field(new_daa, "xx"))
 
     with pytest.raises(
         ValueError,
@@ -224,6 +233,15 @@ def test_with_field(caa: ak.Array, daa: dak.Array) -> None:
     ):
         _ = dak.with_field([{"foo": 1.0}, {"foo": 2.0}], daa.points.x, where="x")
 
+    with pytest.raises(
+        ValueError,
+        match="Base argument in without_field must be a dask_awkward.Array",
+    ):
+        _ = dak.without_field(
+            [{"foo": [1.0, 2.0], "bar": [3.0, 4.0]}],
+            "bar",
+        )
+
     with pytest.raises(
         ValueError,
         match="with_field cannot accept string, bytes, list, or dict values yet",
@@ -530,6 +548,31 @@ def test_repartition_whole(daa):
     assert_eq(daa, daa1, check_divisions=False)
 
 
+def test_repartition_one_to_n(daa):
+    daa1 = daa.repartition(one_to_n=2)
+    assert daa1.npartitions == daa.npartitions * 2
+    assert_eq(daa, daa1, check_divisions=False)
+
+
+def test_repartition_n_to_one():
+    daa = dak.from_lists([[[1, 2, 3], [], [4, 5]] * 2] * 52)
+    daa2 = daa.repartition(n_to_one=52)
+    assert daa2.npartitions == 1
+    assert daa.compute().to_list() == daa2.compute().to_list()
+    daa2 = daa.repartition(n_to_one=53)
+    assert daa2.npartitions == 1
+    assert daa.compute().to_list() == daa2.compute().to_list()
+    daa2 = daa.repartition(n_to_one=2)
+    assert daa2.npartitions == 26
+    assert daa.compute().to_list() == daa2.compute().to_list()
+    daa2 = daa.repartition(n_to_one=10)
+    assert daa2.npartitions == 6
+    assert daa.compute().to_list() == daa2.compute().to_list()
+    daa._divisions = (None,) * len(daa.divisions)
+    assert daa2.npartitions == 6
+    assert daa.compute().to_list() == daa2.compute().to_list()
+
+
 def test_repartition_no_change(daa):
     daa1 = daa.repartition(divisions=(0, 5, 10, 15))
     assert daa1.npartitions == 3
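For reference, a minimal usage sketch of the two user-facing additions in this diff (`dak.without_field` and the `one_to_n`/`n_to_one` repartition modes), mirroring the new tests; the toy records below are illustrative, not taken from the test suite:

```python
import dask_awkward as dak

# Two partitions of records with fields "x" and "y" (toy data).
daa = dak.from_lists(
    [
        [{"x": 1.0, "y": 2.0}, {"x": 3.0, "y": 4.0}],
        [{"x": 5.0, "y": 6.0}],
    ]
)

# New in this diff: dak.without_field maps ak.without_field over partitions.
no_y = dak.without_field(daa, "y")
assert no_y.fields == ["x"]

# New repartition modes: split each partition into n, or merge n into one.
split = daa.repartition(one_to_n=2)
assert split.npartitions == daa.npartitions * 2

merged = split.repartition(n_to_one=2)
assert merged.npartitions == daa.npartitions
```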