Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: typetracer 'under-touching' #542

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 33 additions & 25 deletions src/vector/backends/awkward.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,18 @@
# NOTE(review): presumably verifies that Awkward Array is importable before `ak` is used below — confirm
vector._import_awkward()

# Type variable bounded to the Awkward high-level containers (an array or a single record).
ArrayOrRecord = typing.TypeVar("ArrayOrRecord", bound=typing.Union[ak.Array, ak.Record])
# Unconstrained type variable; lets `_touch` declare that it returns the same type it receives.
Array = typing.TypeVar("Array")

# Module-level mapping, initially empty.
# NOTE(review): presumably the Awkward behavior registry filled in further down the module — confirm
behavior: typing.Any = {}


def _touch(array: Array) -> Array:
    """Touch the data of ``array`` when it lives on the typetracer backend.

    Parameters:
        array: Any object. Only ``ak.Array`` / ``ak.Record`` instances whose
            ``ak.backend`` is ``"typetracer"`` are processed.

    Returns:
        The result of ``ak.typetracer.touch_data(array)`` for typetracer-backed
        Awkward arrays and records; otherwise ``array`` itself, unchanged.
    """
    # Anything that is not an Awkward array/record is handed back untouched,
    # so touching is only ever attempted on Awkward objects.
    if not isinstance(array, (ak.Array, ak.Record)):
        return array

    # Arrays on any other backend are also returned as-is.
    if ak.backend(array) != "typetracer":
        return array

    return ak.typetracer.touch_data(array)


# coordinates classes are a formality for Awkward #############################


Expand Down Expand Up @@ -126,9 +134,9 @@ def from_fields(cls, array: ak.Array) -> AzimuthalAwkward:
"""
fields = ak.fields(array)
if "x" in fields and "y" in fields:
return AzimuthalAwkwardXY(array["x"], array["y"])
return AzimuthalAwkwardXY(_touch(array["x"]), _touch(array["y"]))
elif "rho" in fields and "phi" in fields:
return AzimuthalAwkwardRhoPhi(array["rho"], array["phi"])
return AzimuthalAwkwardRhoPhi(_touch(array["rho"]), _touch(array["phi"]))
else:
raise ValueError(
"array does not have azimuthal coordinates (x, y or rho, phi): "
Expand All @@ -154,17 +162,17 @@ def from_momentum_fields(cls, array: ak.Array) -> AzimuthalAwkward:
"""
fields = ak.fields(array)
if "x" in fields and "y" in fields:
return AzimuthalAwkwardXY(array["x"], array["y"])
return AzimuthalAwkwardXY(_touch(array["x"]), _touch(array["y"]))
elif "x" in fields and "py" in fields:
return AzimuthalAwkwardXY(array["x"], array["py"])
return AzimuthalAwkwardXY(_touch(array["x"]), _touch(array["py"]))
elif "px" in fields and "y" in fields:
return AzimuthalAwkwardXY(array["px"], array["y"])
return AzimuthalAwkwardXY(_touch(array["px"]), _touch(array["y"]))
elif "px" in fields and "py" in fields:
return AzimuthalAwkwardXY(array["px"], array["py"])
return AzimuthalAwkwardXY(_touch(array["px"]), _touch(array["py"]))
elif "rho" in fields and "phi" in fields:
return AzimuthalAwkwardRhoPhi(array["rho"], array["phi"])
return AzimuthalAwkwardRhoPhi(_touch(array["rho"]), _touch(array["phi"]))
elif "pt" in fields and "phi" in fields:
return AzimuthalAwkwardRhoPhi(array["pt"], array["phi"])
return AzimuthalAwkwardRhoPhi(_touch(array["pt"]), _touch(array["phi"]))
else:
raise ValueError(
"array does not have azimuthal coordinates (x/px, y/py or rho/pt, phi): "
Expand Down Expand Up @@ -206,11 +214,11 @@ def from_fields(cls, array: ak.Array) -> LongitudinalAwkward:
"""
fields = ak.fields(array)
if "z" in fields:
return LongitudinalAwkwardZ(array["z"])
return LongitudinalAwkwardZ(_touch(array["z"]))
elif "theta" in fields:
return LongitudinalAwkwardTheta(array["theta"])
return LongitudinalAwkwardTheta(_touch(array["theta"]))
elif "eta" in fields:
return LongitudinalAwkwardEta(array["eta"])
return LongitudinalAwkwardEta(_touch(array["eta"]))
else:
raise ValueError(
"array does not have longitudinal coordinates (z or theta or eta): "
Expand All @@ -237,13 +245,13 @@ def from_momentum_fields(cls, array: ak.Array) -> LongitudinalAwkward:
"""
fields = ak.fields(array)
if "z" in fields:
return LongitudinalAwkwardZ(array["z"])
return LongitudinalAwkwardZ(_touch(array["z"]))
elif "pz" in fields:
return LongitudinalAwkwardZ(array["pz"])
return LongitudinalAwkwardZ(_touch(array["pz"]))
elif "theta" in fields:
return LongitudinalAwkwardTheta(array["theta"])
return LongitudinalAwkwardTheta(_touch(array["theta"]))
elif "eta" in fields:
return LongitudinalAwkwardEta(array["eta"])
return LongitudinalAwkwardEta(_touch(array["eta"]))
else:
raise ValueError(
"array does not have longitudinal coordinates (z/pz or theta or eta): "
Expand Down Expand Up @@ -284,9 +292,9 @@ def from_fields(cls, array: ak.Array) -> TemporalAwkward:
"""
fields = ak.fields(array)
if "t" in fields:
return TemporalAwkwardT(array["t"])
return TemporalAwkwardT(_touch(array["t"]))
elif "tau" in fields:
return TemporalAwkwardTau(array["tau"])
return TemporalAwkwardTau(_touch(array["tau"]))
else:
raise ValueError(
"array does not have temporal coordinates (t or tau): "
Expand All @@ -312,21 +320,21 @@ def from_momentum_fields(cls, array: ak.Array) -> TemporalAwkward:
"""
fields = ak.fields(array)
if "t" in fields:
return TemporalAwkwardT(array["t"])
return TemporalAwkwardT(_touch(array["t"]))
elif "E" in fields:
return TemporalAwkwardT(array["E"])
return TemporalAwkwardT(_touch(array["E"]))
elif "e" in fields:
return TemporalAwkwardT(array["e"])
return TemporalAwkwardT(_touch(array["e"]))
elif "energy" in fields:
return TemporalAwkwardT(array["energy"])
return TemporalAwkwardT(_touch(array["energy"]))
elif "tau" in fields:
return TemporalAwkwardTau(array["tau"])
return TemporalAwkwardTau(_touch(array["tau"]))
elif "M" in fields:
return TemporalAwkwardTau(array["M"])
return TemporalAwkwardTau(_touch(array["M"]))
elif "m" in fields:
return TemporalAwkwardTau(array["m"])
return TemporalAwkwardTau(_touch(array["m"]))
elif "mass" in fields:
return TemporalAwkwardTau(array["mass"])
return TemporalAwkwardTau(_touch(array["mass"]))
else:
raise ValueError(
"array does not have temporal coordinates (t/E/e/energy or tau/M/m/mass): "
Expand Down
14 changes: 14 additions & 0 deletions tests/backends/test_dask_awkward.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,17 @@ def test_constructor():
assert isinstance(vec.compute(), vector.backends.awkward.VectorAwkward2D)
assert ak.all(vec.x.compute() == ak.Array([1, 1.1]))
assert ak.all(vec.y.compute() == ak.Array([2, 2.2]))


def test_necessary_columns():
    """Check the columns dask-awkward reports as necessary for a vector array."""
    records = [[{"pt": 1, "phi": 2}], [], [{"pt": 3, "phi": 4}]]
    lazy_vec = dak.from_awkward(vector.Array(records), npartitions=1)

    report = dak.report_necessary_columns(lazy_vec)
    first_layer_columns = next(iter(report.values()))

    # This may seem odd at first: why would "phi" and "rho" be needed if one asked for "pt"?
    # The reason is that vector internally builds a class from "phi" and "rho", see:
    # https://github.com/scikit-hep/vector/blob/608da2d55a74eed25635fd408d1075b568773c99/src/vector/backends/awkward.py#L166-L167
    # So even if one asks for "pt", both "phi" and "rho" are needed as well, in order to
    # build the vector class in the first place.
    # (The same argument holds for all other vector classes.)
    assert first_layer_columns == frozenset({"phi", "rho"})
Loading