Skip to content

Commit

Permalink
beignet.func.space
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed Apr 22, 2024
1 parent 0c5b18d commit ba76031
Showing 1 changed file with 48 additions and 4 deletions.
52 changes: 48 additions & 4 deletions src/beignet/operators/_transform.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Tuple

import torch
from torch import Tensor
from torch.autograd import Function

Expand Down Expand Up @@ -47,7 +48,27 @@ def forward(transformation: Tensor, position: Tensor) -> Tensor:
Tensor
Affine transformed position of shape `(..., dimension)`.
"""
return _transform(transformation, position)
indexes = "".join(
[chr(ord("a") + index) for index in range(position.ndim - 1)]
)

match transformation.ndim:
case 0:
return position * transformation
case 1:
return torch.einsum(
f"i,{indexes}i->{indexes}i",
transformation,
position,
)
case 2:
return torch.einsum(
f"ij,{indexes}j->{indexes}i",
transformation,
position,
)
case _:
raise ValueError

@staticmethod
def setup_context(ctx, inputs, output):
Expand All @@ -65,11 +86,34 @@ def jvp(

output = _transform(transformation, position)

grad_output = grad_position + _transform(
grad_transformation,
position,
indexes = "".join(
[chr(ord("a") + index) for index in range(position.ndim - 1)]
)

match grad_transformation.ndim:
case 0:
transformed_grad = position * grad_transformation

grad_output = grad_position + transformed_grad
case 1:
transformed_grad = torch.einsum(
f"i,{indexes}i->{indexes}i",
grad_transformation,
position,
)

grad_output = grad_position + transformed_grad
case 2:
transformed_grad = torch.einsum(
f"ij,{indexes}j->{indexes}i",
grad_transformation,
position,
)

grad_output = grad_position + transformed_grad
case _:
raise ValueError

return output, grad_output

@staticmethod
Expand Down

0 comments on commit ba76031

Please sign in to comment.