Skip to content

Commit

Permalink
beignet.func.space
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed Apr 22, 2024
1 parent 6be20c1 commit d05eda3
Showing 1 changed file with 34 additions and 21 deletions.
55 changes: 34 additions & 21 deletions src/beignet/func/_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
import torch
from torch import Tensor

from beignet.operators.__transform import _transform
from beignet.operators._inverse_transform import inverse_transform
from beignet.operators._transform import transform
import beignet.operators
import beignet.operators.__transform
import beignet.operators._inverse_transform
import beignet.operators._transform

T = TypeVar("T")

Expand Down Expand Up @@ -125,7 +126,7 @@ def _displacement_fn(
raise ValueError

if perturbation is not None:
return _transform(a - b, perturbation)
return beignet.operators.__transform._transform(a - b, perturbation)

return a - b

Expand All @@ -135,7 +136,7 @@ def _shift_fn(a: Tensor, b: Tensor, **_) -> Tensor:
return _displacement_fn, _shift_fn

if parallelepiped:
inverse_transformation = inverse_transform(dimensions)
inverse_transformation = beignet.operators.inverse_transform(dimensions)

if normalized:

Expand Down Expand Up @@ -168,10 +169,14 @@ def _displacement_fn(

displacement = displacement - 1.0 * 0.5

displacement = transform(_transformation, displacement)
displacement = beignet.operators.transform(
_transformation, displacement
)

if perturbation is not None:
return _transform(displacement, perturbation)
return beignet.operators.__transform._transform(
displacement, perturbation
)

return displacement

Expand All @@ -188,12 +193,16 @@ def _shift_fn(a: Tensor, b: Tensor, **kwargs) -> Tensor:
if "transformation" in kwargs:
_transformation = kwargs["transformation"]

_inverse_transformation = inverse_transform(_transformation)
_inverse_transformation = beignet.operators.inverse_transform(
_transformation
)

if "updated_transformation" in kwargs:
_transformation = kwargs["updated_transformation"]

return _u(a, transform(_inverse_transformation, b))
return _u(
a, beignet.operators.transform(_inverse_transformation, b)
)

return _displacement_fn, _shift_fn

Expand All @@ -205,14 +214,14 @@ def _shift_fn(a: Tensor, b: Tensor, **kwargs) -> Tensor:
if "transformation" in kwargs:
_transformation = kwargs["transformation"]

_inverse_transformation = inverse_transform(
_inverse_transformation = beignet.operators.inverse_transform(
_transformation,
)

if "updated_transformation" in kwargs:
_transformation = kwargs["updated_transformation"]

return a + transform(_inverse_transformation, b)
return a + beignet.operators.transform(_inverse_transformation, b)

return _displacement_fn, _shift_fn

Expand All @@ -230,13 +239,15 @@ def _displacement_fn(
if "transformation" in kwargs:
_transformation = kwargs["transformation"]

_inverse_transformation = inverse_transform(_transformation)
_inverse_transformation = beignet.operators.inverse_transform(
_transformation
)

if "updated_transformation" in kwargs:
_transformation = kwargs["updated_transformation"]

a = transform(_inverse_transformation, a)
b = transform(_inverse_transformation, b)
a = beignet.operators.transform(_inverse_transformation, a)
b = beignet.operators.transform(_inverse_transformation, b)

if len(a.shape) != 1:
raise ValueError
Expand All @@ -250,10 +261,12 @@ def _displacement_fn(

displacement = displacement - 1.0 * 0.5

displacement = transform(_transformation, displacement)
displacement = beignet.operators.transform(_transformation, displacement)

if perturbation is not None:
return _transform(displacement, perturbation)
return beignet.operators.__transform._transform(
displacement, perturbation
)

return displacement

Expand All @@ -270,18 +283,18 @@ def _shift_fn(a: Tensor, b: Tensor, **kwargs) -> Tensor:
if "transformation" in kwargs:
_transformation = kwargs["transformation"]

_inverse_transformation = inverse_transform(
_inverse_transformation = beignet.operators.inverse_transform(
_transformation,
)

if "updated_transformation" in kwargs:
_transformation = kwargs["updated_transformation"]

return transform(
return beignet.operators._transform.transform(
_transformation,
_u(
transform(_inverse_transformation, a),
transform(_inverse_transformation, b),
beignet.operators.transform(_inverse_transformation, a),
beignet.operators.transform(_inverse_transformation, b),
),
)

Expand Down Expand Up @@ -310,7 +323,7 @@ def _displacement_fn(
displacement = displacement - dimensions * 0.5

if perturbation is not None:
return _transform(displacement, perturbation)
return beignet.operators.__transform._transform(displacement, perturbation)

return displacement

Expand Down

0 comments on commit d05eda3

Please sign in to comment.