Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed Jun 10, 2024
1 parent 08199e9 commit fe35ab8
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 53 deletions.
31 changes: 9 additions & 22 deletions src/beignet/_apply_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,15 @@ def _apply_transform(input: Tensor, transform: Tensor) -> Tensor:
Affine transformed position vector, has the same shape as the
position vector.
"""
if transform.ndim == 0:
return input * transform

indices = [chr(ord("a") + index) for index in range(input.ndim - 1)]

indices = "".join(indices)

if transform.ndim == 1:
return torch.einsum(
"i,...i->...i",
transform,
input,
)

if transform.ndim == 2:
return torch.einsum(
f"ij,...j->...i",
transform,
input,
)

raise ValueError
match transform.ndim:
case 0:
return input * transform
case 1:
return torch.einsum("i,...i->...i", transform, input)
case 2:
return torch.einsum("ij,...j->...i", transform, input)
case _:
raise ValueError


class _ApplyTransform(Function):
Expand Down
132 changes: 101 additions & 31 deletions src/beignet/func/_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import torch
from torch import Tensor

from beignet._apply_transform import _apply_transform, apply_transform
from beignet._invert_transform import invert_transform
import beignet

T = TypeVar("T")

Expand Down Expand Up @@ -126,7 +125,25 @@ def displacement_fn(
raise ValueError

if perturbation is not None:
return _apply_transform(perturbation, input - other)
transform = input - other

match transform.ndim:
case 0:
return perturbation * transform
case 1:
return torch.einsum(
"i,...i->...i",
transform,
perturbation,
)
case 2:
return torch.einsum(
"ij,...j->...i",
transform,
perturbation,
)
case _:
raise ValueError

return input - other

Expand All @@ -136,7 +153,7 @@ def shift_fn(input: Tensor, other: Tensor, **_) -> Tensor:
return displacement_fn, shift_fn

if parallelepiped:
inverted_transform = invert_transform(dimensions)
inverted_transform = beignet.invert_transform(dimensions)

if normalized:

Expand All @@ -154,22 +171,38 @@ def displacement_fn(
if "transform" in kwargs:
_transform = kwargs["transform"]

if "updated_transformation" in kwargs:
_transform = kwargs["updated_transformation"]
if "updated_transform" in kwargs:
_transform = kwargs["updated_transform"]

if len(input.shape) != 1:
raise ValueError

if input.shape != other.shape:
raise ValueError

displacement = apply_transform(
displacement = beignet.apply_transform(
torch.remainder(input - other + 1.0 * 0.5, 1.0) - 1.0 * 0.5,
_transform,
)

if perturbation is not None:
return _apply_transform(perturbation, displacement)
match displacement.ndim:
case 0:
return perturbation * displacement
case 1:
return torch.einsum(
"i,...i->...i",
displacement,
perturbation,
)
case 2:
return torch.einsum(
"ij,...j->...i",
displacement,
perturbation,
)
case _:
raise ValueError

return displacement

Expand All @@ -186,12 +219,12 @@ def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor:
if "transform" in kwargs:
_transform = kwargs["transform"]

_inverted_transform = invert_transform(_transform)
_inverted_transform = beignet.invert_transform(_transform)

if "updated_transformation" in kwargs:
_transform = kwargs["updated_transformation"]
if "updated_transform" in kwargs:
_transform = kwargs["updated_transform"]

return u(input, apply_transform(other, _inverted_transform))
return u(input, beignet.apply_transform(other, _inverted_transform))

return displacement_fn, shift_fn

Expand All @@ -203,12 +236,12 @@ def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor:
if "transform" in kwargs:
_transform = kwargs["transform"]

_inverted_transform = invert_transform(_transform)
_inverted_transform = beignet.invert_transform(_transform)

if "updated_transformation" in kwargs:
_transform = kwargs["updated_transformation"]
if "updated_transform" in kwargs:
_transform = kwargs["updated_transform"]

return input + apply_transform(other, _inverted_transform)
return input + beignet.apply_transform(other, _inverted_transform)

return displacement_fn, shift_fn

Expand All @@ -226,27 +259,43 @@ def displacement_fn(
if "transform" in kwargs:
_transform = kwargs["transform"]

_inverted_transform = invert_transform(_transform)
_inverted_transform = beignet.invert_transform(_transform)

if "updated_transformation" in kwargs:
_transform = kwargs["updated_transformation"]
if "updated_transform" in kwargs:
_transform = kwargs["updated_transform"]

input = apply_transform(input, _inverted_transform)
other = apply_transform(other, _inverted_transform)
input = beignet.apply_transform(input, _inverted_transform)
other = beignet.apply_transform(other, _inverted_transform)

if len(input.shape) != 1:
raise ValueError

if input.shape != other.shape:
raise ValueError

displacement = apply_transform(
displacement = beignet.apply_transform(
torch.remainder(input - other + 1.0 * 0.5, 1.0) - 1.0 * 0.5,
_transform,
)

if perturbation is not None:
return _apply_transform(perturbation, displacement)
match displacement.ndim:
case 0:
return perturbation * displacement
case 1:
return torch.einsum(
"i,...i->...i",
displacement,
perturbation,
)
case 2:
return torch.einsum(
"ij,...j->...i",
displacement,
perturbation,
)
case _:
raise ValueError

return displacement

Expand All @@ -263,17 +312,17 @@ def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor:
if "transform" in kwargs:
_transform = kwargs["transform"]

_inverted_transform = invert_transform(
_inverted_transform = beignet.invert_transform(
_transform,
)

if "updated_transformation" in kwargs:
_transform = kwargs["updated_transformation"]
if "updated_transform" in kwargs:
_transform = kwargs["updated_transform"]

return apply_transform(
return beignet.apply_transform(
u(
apply_transform(_inverted_transform, input),
apply_transform(_inverted_transform, other),
beignet.apply_transform(_inverted_transform, input),
beignet.apply_transform(_inverted_transform, other),
),
_transform,
)
Expand All @@ -298,10 +347,31 @@ def displacement_fn(
if input.shape != other.shape:
raise ValueError

displacement = torch.remainder(input - other + dimensions * 0.5, dimensions)
displacement = torch.remainder(
input - other + dimensions * 0.5,
dimensions,
)

if perturbation is not None:
return _apply_transform(perturbation, displacement - dimensions * 0.5)
transform = displacement - dimensions * 0.5

match transform.ndim:
case 0:
return perturbation * transform
case 1:
return torch.einsum(
"i,...i->...i",
transform,
perturbation,
)
case 2:
return torch.einsum(
"ij,...j->...i",
transform,
perturbation,
)
case _:
raise ValueError

return displacement - dimensions * 0.5

Expand Down

0 comments on commit fe35ab8

Please sign in to comment.