
feat: add a PyTorch backend #541

Draft · wants to merge 2 commits into main
Conversation

jpivarski

I've become familiar with PyTorch recently because of writing https://github.com/hsf-training/deep-learning-intro-for-hep/

I've also been looking at the Vector documentation because I think it needs an overhaul to be more physicist-friendly. Along the way, I noticed that there's no PyTorch backend yet, but it would be really useful to have one. Vector's approach to NumPy arrays is to expect them to be structured arrays, but feature vectors in an ML model are always unstructured. (Note: there's a conversion function: np.lib.recfunctions.structured_to_unstructured.)
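For reference, here's a small example (with made-up values) of that conversion: np.lib.recfunctions.structured_to_unstructured turns a Vector-style structured array into the flat feature layout that ML models expect.

import numpy as np
from numpy.lib import recfunctions as rf

# two particles with (pt, eta, phi, mass) fields, as Vector's NumPy backend expects
structured = np.array(
    [(10.0, 1.1, 0.5, 0.105), (20.0, -0.3, 2.1, 0.105)],
    dtype=[("pt", "f4"), ("eta", "f4"), ("phi", "f4"), ("mass", "f4")],
)

# flat (2, 4) float32 array: the layout an ML feature tensor would use
unstructured = rf.structured_to_unstructured(structured)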

Generally, feature vectors in an ML model will have a few indexes corresponding to vector coordinates and many others that don't. If the first 4 features are $p_T$, $\eta$, $\phi$, and mass, we might want to denote that with pt_index=0, eta_index=1, phi_index=2, mass_index=3 in such a way that they can be picked out of a tensor named features like

features[..., pt_index]
features[..., eta_index]
features[..., phi_index]
features[..., mass_index]

It would be nice if the features tensor were a subclass of torch.Tensor that produces the above via

features.pt
features.eta
features.phi
features.mass

And then if someone asks for

features.pz

it would compute $p_z$ using the appropriate compute function. With torch as the lib argument of the vector._compute functions, they would all be autodiffed and could be used in an optimization procedure with backpropagation. The library functions that vector._compute needs,

allowed_lib_functions = [
"absolute",
"sign",
"copysign",
"maximum",
"minimum",
"sqrt",
"exp",
"log",
"sin",
"cos",
"tan",
"arcsin",
"arccos",
"arctan",
"arctan2",
"sinh",
"cosh",
"tanh",
"arcsinh",
"arccosh",
"arctanh",
"isclose",
]

are all defined in the torch module:

>>> torch.absolute
<built-in method absolute of type object at 0x7596c071cde0>
>>> torch.sign
<built-in method sign of type object at 0x7596c071cde0>
>>> torch.copysign
<built-in method copysign of type object at 0x7596c071cde0>
>>> torch.maximum
<built-in method maximum of type object at 0x7596c071cde0>
>>> torch.minimum
<built-in method minimum of type object at 0x7596c071cde0>
>>> torch.sqrt
<built-in method sqrt of type object at 0x7596c071cde0>
>>> torch.exp
<built-in method exp of type object at 0x7596c071cde0>
>>> torch.log
<built-in method log of type object at 0x7596c071cde0>
>>> torch.sin
<built-in method sin of type object at 0x7596c071cde0>
>>> torch.cos
<built-in method cos of type object at 0x7596c071cde0>
>>> torch.tan
<built-in method tan of type object at 0x7596c071cde0>
>>> torch.arcsin
<built-in method arcsin of type object at 0x7596c071cde0>
>>> torch.arccos
<built-in method arccos of type object at 0x7596c071cde0>
>>> torch.arctan
<built-in method arctan of type object at 0x7596c071cde0>
>>> torch.arctan2
<built-in method arctan2 of type object at 0x7596c071cde0>
>>> torch.sinh
<built-in method sinh of type object at 0x7596c071cde0>
>>> torch.cosh
<built-in method cosh of type object at 0x7596c071cde0>
>>> torch.tanh
<built-in method tanh of type object at 0x7596c071cde0>
>>> torch.arcsinh
<built-in method arcsinh of type object at 0x7596c071cde0>
>>> torch.arccosh
<built-in method arccosh of type object at 0x7596c071cde0>
>>> torch.arctanh
<built-in method arctanh of type object at 0x7596c071cde0>
>>> torch.isclose
<built-in method isclose of type object at 0x7596c071cde0>

so they probably don't even need a shim (which SymPy needed).
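As a quick, self-contained check of that claim (just a verification snippet, not something the backend necessarily needs), every name in the list above resolves on the torch module:

import torch

allowed_lib_functions = [
    "absolute", "sign", "copysign", "maximum", "minimum", "sqrt", "exp", "log",
    "sin", "cos", "tan", "arcsin", "arccos", "arctan", "arctan2",
    "sinh", "cosh", "tanh", "arcsinh", "arccosh", "arctanh", "isclose",
]

# collects any names that torch does not provide; empty in the REPL session shown above
missing = [name for name in allowed_lib_functions if not hasattr(torch, name)]
assert not missing, missing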

Below is the start of an implementation, using https://pytorch.org/docs/stable/notes/extending.html#extending-torch-python-api as a guide. PyTorch supports a __torch_function__ protocol (see this investigation), making it possible to overload its functions without even creating real subclasses of torch.Tensor, but I think it's a good idea to make subclasses of torch.Tensor because these are mostly-normal feature tensors: they just have a few extra properties and methods.
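To make the idea concrete, here is a minimal sketch of that subclass approach (not the PR's actual code; MomentumTensor, the *_index attributes, and the default column order are assumptions for illustration):

import torch

class MomentumTensor(torch.Tensor):
    """A torch.Tensor whose last axis holds feature columns, a few of which
    are interpreted as momentum-vector components."""

    @staticmethod
    def __new__(cls, data, pt_index=0, eta_index=1, phi_index=2, mass_index=3):
        self = torch.as_tensor(data).as_subclass(cls)
        self.pt_index, self.eta_index = pt_index, eta_index
        self.phi_index, self.mass_index = phi_index, mass_index
        return self

    def _component(self, index):
        # pick out one feature column; a real implementation would have to decide
        # whether this stays a MomentumTensor or becomes a plain torch.Tensor
        return self[..., index]

    @property
    def pt(self):
        return self._component(self.pt_index)

    @property
    def eta(self):
        return self._component(self.eta_index)

    @property
    def phi(self):
        return self._component(self.phi_index)

    @property
    def mass(self):
        return self._component(self.mass_index)

    @property
    def pz(self):
        # p_z = p_T sinh(eta), written with torch functions so it stays autodiff-friendly;
        # a real backend would route this through vector._compute with lib=torch instead
        return self.pt * torch.sinh(self.eta)

features = MomentumTensor(torch.randn(100, 10, requires_grad=True))
features.pt.shape             # torch.Size([100])
features.pz.sum().backward()  # gradients flow back through sinh and the indexing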

But then I got to the point where I'd have to wrap all of the functions and remembered that that's where all of the complexity is. Some functions (possibly methods or properties) take 1 input vector and return a non-vector, others return a vector, and some take 2 input vectors with either kind of output. I don't think there are any functions that take more than 2, but there are also functions that don't do anything to the vector properties at all, like the PyTorch functions that move data to and from the GPU or change its dtype. (Possible simplification: maybe all vector components can be forced to be float32?)

Some of the functions will have to shuffle the indexes to make them line up. Say, for instance, that you have featuresA with x_index=0, y_index=1 and featuresB with x_index=4, y_index=2. When you add featuresA + featuresB, you'll need to pass

featuresA[..., [0, 1]], featuresB[..., [4, 2]]

into the vector._compute.planar.add.dispatch function.
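A hedged sketch of that shuffling step (the tensors and index positions are made up, and the dispatch call is only indicated in a comment because its exact signature isn't reproduced here):

import torch

featuresA = torch.randn(100, 8)   # x at column 0, y at column 1
featuresB = torch.randn(100, 8)   # x at column 4, y at column 2

# fancy indexing on the last axis puts both sets of components into (x, y) order
coordsA = featuresA[..., [0, 1]]
coordsB = featuresB[..., [4, 2]]

# the aligned columns would then go into vector._compute.planar.add.dispatch;
# as a stand-in here, the component-wise sum is just
summed = coordsA + coordsB        # shape (100, 2), columns in (x, y) order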

So that's where I left the implementation, as a sketch of the idea of interpreting the axis=-1 dimension of feature arrays as vector components and passing torch as the compute functions' lib. Considering that each of the different types of functions has to be handled differently before calling the compute functions, this is not as easy as I thought it would be (a one-day project), but it's still not a huge project. I'd also like to find out if there's a "market" for this backend: I had assumed that spatial and momentum vector calculations would be useful as (the first) part of an ML model, but I wonder whether anyone has concrete use-cases.

Also, I have to say that the ML "vector" and "tensor" terminology is incredibly confusing in this context. When we say that a feature-set has 2D, 3D, or 4D spatial or momentum vector components, we have to be careful not to call that feature-set a "feature vector," since that's a different thing.


jpivarski commented Dec 12, 2024

And there could also be some normal PyTorch functions that shouldn't be allowed at all. For instance, elementwise multiplication of two feature-sets that contain vectors (as opposed to cross-products, which would be allowed... but what does the cross-product do to the features that are not vector components?!?).

By the way, the use-case for features that are not vector components is for all of the other variables that describe a particle, such as isolation, charge, mass, etc. Since this is an input to an ML model, you probably want to throw the whole kitchen sink into it.

Here's another normal PyTorch function that shouldn't be allowed, or should be restricted in complex ways: reshape. If you reshape the last axis, the pointers to vector component indexes won't mean what they used to mean.
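For example (an illustration, not code from the PR), this is how reshape silently breaks the component indexes:

import torch

features = torch.arange(24.0).reshape(2, 3, 4)   # last axis is (pt, eta, phi, mass)
pt = features[..., 0]                            # pt of each of the 2x3 particles

flattened = features.reshape(2, 12)
# flattened[..., 0] is now just the first number in each event, not a pt column,
# so a stored pt_index=0 would silently select the wrong values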
