Skip to content

Commit

Permalink
Scalar typetracer after reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
martindurant committed Jul 29, 2024
1 parent a3c201d commit 0187dcb
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,19 +500,17 @@ def f(self, other):
deps = [self]
plns = [self.name]
if is_dask_collection(other):
task = (op, self.key, *other.__dask_keys__())
deps.append(other)
plns.append(other.name)
if inv:
plns.insert(0, other.name)
task = (op, *other.__dask_keys__(), self.key)
else:
plns.append(other.name)
task = (op, self.key, *other.__dask_keys__())
else:
if inv:
task = (op, other, self.key)
else:
task = (op, self.key, other)
if inv:
plns.reverse()
graph = HighLevelGraph.from_collections(
name,
layer=AwkwardMaterializedLayer(
Expand All @@ -532,6 +530,11 @@ def f(self, other):
meta = op(other, self._meta)
else:
meta = op(self._meta, other)
if meta.ndim:
divisions = other.divisions if is_dask_collection(other) else [0, 1]
return new_array_object(
graph, name, meta=ak.Array(meta), divisions=divisions
)
return new_scalar_object(graph, name, meta=meta)

return f
Expand Down Expand Up @@ -570,6 +573,15 @@ def f(*args):
args = tuple(
ak.Array(arg.content) if isinstance(arg, MaybeNone) else arg for arg in args
)
args = tuple(
(
ak.Array(arg)
if isinstance(arg, ak._nplikes.typetracer.TypeTracerArray)
else arg
)
for arg in args
)

result = op(*args)
return result

Expand Down Expand Up @@ -2598,6 +2610,8 @@ def typetracer_array(a: ak.Array | Array) -> ak.Array:
behavior=a._behavior,
attrs=a._attrs,
)
elif isinstance(a, numbers.Number):
return ak.Array([a]).layout.to_typetracer()
else:
msg = (
"`a` should be an awkward array or a Dask awkward collection.\n"
Expand Down

0 comments on commit 0187dcb

Please sign in to comment.