Skip to content

Commit

Permalink
beignet.func.space
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed Apr 22, 2024
1 parent 2a8f5ed commit 32bf5ec
Showing 1 changed file with 55 additions and 49 deletions.
104 changes: 55 additions & 49 deletions src/beignet/func/_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


def space(
dimensions: Optional[Tensor] = None,
dimensions: Tensor | None = None,
*,
normalized: bool = True,
parallelepiped: bool = True,
Expand Down Expand Up @@ -111,25 +111,27 @@ def space(
if dimensions is None:

def displacement_fn(
a: Tensor,
b: Tensor,
input: Tensor,
other: Tensor,
*,
perturbation: Tensor | None = None,
**_,
) -> Tensor:
if len(a.shape) != 1:
if len(input.shape) != 1:
raise ValueError

if a.shape != b.shape:
if input.shape != other.shape:
raise ValueError

if perturbation is not None:
return beignet.operators.__transform._transform(a - b, perturbation)
return beignet.operators.__transform._transform(
input - other, perturbation
)

return a - b
return input - other

def shift_fn(a: Tensor, b: Tensor, **_) -> Tensor:
return a + b
def shift_fn(input: Tensor, other: Tensor, **_) -> Tensor:
return input + other

return displacement_fn, shift_fn

Expand All @@ -139,8 +141,8 @@ def shift_fn(a: Tensor, b: Tensor, **_) -> Tensor:
if normalized:

def displacement_fn(
a: Tensor,
b: Tensor,
input: Tensor,
other: Tensor,
*,
perturbation: Optional[Tensor] = None,
**kwargs,
Expand All @@ -155,17 +157,15 @@ def displacement_fn(
if "updated_transformation" in kwargs:
_transformation = kwargs["updated_transformation"]

if len(a.shape) != 1:
if len(input.shape) != 1:
raise ValueError

if a.shape != b.shape:
if input.shape != other.shape:
raise ValueError

displacement = a - b

displacement = torch.remainder(displacement + 1.0 * 0.5, 1.0)

displacement = displacement - 1.0 * 0.5
displacement = (
torch.remainder(input - other + 1.0 * 0.5, 1.0) - 1.0 * 0.5
)

displacement = beignet.operators.transform(
_transformation, displacement
Expand All @@ -180,10 +180,10 @@ def displacement_fn(

if remapped:

def _u(a: Tensor, b: Tensor) -> Tensor:
return torch.remainder(a + b, 1.0)
def u(input: Tensor, other: Tensor) -> Tensor:
return torch.remainder(input + other, 1.0)

def shift_fn(a: Tensor, b: Tensor, **kwargs) -> Tensor:
def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor:
_transformation = dimensions

_inverse_transformation = inverse_transformation
Expand All @@ -198,11 +198,14 @@ def shift_fn(a: Tensor, b: Tensor, **kwargs) -> Tensor:
if "updated_transformation" in kwargs:
_transformation = kwargs["updated_transformation"]

return u(a, beignet.operators.transform(_inverse_transformation, b))
return u(
input,
beignet.operators.transform(_inverse_transformation, other),
)

return displacement_fn, shift_fn

def shift_fn(a: Tensor, b: Tensor, **kwargs) -> Tensor:
def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor:
_transformation = dimensions

_inverse_transformation = inverse_transformation
Expand All @@ -217,15 +220,17 @@ def shift_fn(a: Tensor, b: Tensor, **kwargs) -> Tensor:
if "updated_transformation" in kwargs:
_transformation = kwargs["updated_transformation"]

return a + beignet.operators.transform(_inverse_transformation, b)
return input + beignet.operators.transform(
_inverse_transformation, other
)

return displacement_fn, shift_fn

def displacement_fn(
a: Tensor,
b: Tensor,
input: Tensor,
other: Tensor,
*,
perturbation: Optional[Tensor] = None,
perturbation: Tensor | None = None,
**kwargs,
) -> Tensor:
_transformation = dimensions
Expand All @@ -242,16 +247,16 @@ def displacement_fn(
if "updated_transformation" in kwargs:
_transformation = kwargs["updated_transformation"]

a = beignet.operators.transform(_inverse_transformation, a)
b = beignet.operators.transform(_inverse_transformation, b)
input = beignet.operators.transform(_inverse_transformation, input)
other = beignet.operators.transform(_inverse_transformation, other)

if len(a.shape) != 1:
if len(input.shape) != 1:
raise ValueError

if a.shape != b.shape:
if input.shape != other.shape:
raise ValueError

displacement = torch.remainder(a - b + 1.0 * 0.5, 1.0) - 1.0 * 0.5
displacement = torch.remainder(input - other + 1.0 * 0.5, 1.0) - 1.0 * 0.5

displacement = beignet.operators.transform(_transformation, displacement)

Expand All @@ -264,10 +269,10 @@ def displacement_fn(

if remapped:

def u(a: Tensor, b: Tensor) -> Tensor:
return torch.remainder(a + b, 1.0)
def u(input: Tensor, other: Tensor) -> Tensor:
return torch.remainder(input + other, 1.0)

def _shift_fn(a: Tensor, b: Tensor, **kwargs) -> Tensor:
def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor:
_transformation = dimensions

_inverse_transformation = inverse_transformation
Expand All @@ -285,33 +290,34 @@ def _shift_fn(a: Tensor, b: Tensor, **kwargs) -> Tensor:
return beignet.operators.transform(
_transformation,
u(
beignet.operators.transform(_inverse_transformation, a),
beignet.operators.transform(_inverse_transformation, b),
beignet.operators.transform(_inverse_transformation, input),
beignet.operators.transform(_inverse_transformation, other),
),
)

return displacement_fn, _shift_fn
return displacement_fn, shift_fn

def shift_fn(a: Tensor, b: Tensor, **_) -> Tensor:
return a + b
def shift_fn(input: Tensor, other: Tensor, **_) -> Tensor:
return input + other

return displacement_fn, shift_fn

def displacement_fn(
a: Tensor,
b: Tensor,
input: Tensor,
other: Tensor,
*,
perturbation: Tensor | None = None,
**_,
) -> Tensor:
if len(a.shape) != 1:
if len(input.shape) != 1:
raise ValueError

if a.shape != b.shape:
if input.shape != other.shape:
raise ValueError

displacement = (
torch.remainder(a - b + dimensions * 0.5, dimensions) - dimensions * 0.5
torch.remainder(input - other + dimensions * 0.5, dimensions)
- dimensions * 0.5
)

if perturbation is not None:
Expand All @@ -321,11 +327,11 @@ def displacement_fn(

if remapped:

def shift_fn(a: Tensor, b: Tensor, **_) -> Tensor:
return torch.remainder(a + b, dimensions)
def shift_fn(input: Tensor, other: Tensor, **_) -> Tensor:
return torch.remainder(input + other, dimensions)
else:

def shift_fn(a: Tensor, b: Tensor, **_) -> Tensor:
return a + b
def shift_fn(input: Tensor, other: Tensor, **_) -> Tensor:
return input + other

return displacement_fn, shift_fn

0 comments on commit 32bf5ec

Please sign in to comment.