Skip to content

Commit

Permalink
Test arithmetic fallbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
schmidtjonathan committed Aug 12, 2021
1 parent 220f60b commit 25a1686
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 19 deletions.
20 changes: 1 addition & 19 deletions src/probnum/linops/_arithmetic_fallbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,6 @@ def _inv(self) -> "ScaledLinearOperator":

return ScaledLinearOperator(self._linop.inv(), 1.0 / self._scalar)

def __mul__(self, other: BinaryOperandType) -> "LinearOperator":
if np.isscalar(other):
return ScaledLinearOperator(linop=self._linop, scalar=self._scalar * other)

return super().__mul__(other)

def __repr__(self) -> str:
return f"{self._scalar} * {self._linop}"

Expand All @@ -72,11 +66,6 @@ class SumLinearOperator(LinearOperator):
"""Sum of two linear operators."""

def __init__(self, *summands: LinearOperator):
if not all(isinstance(summand, LinearOperator) for summand in summands):
raise TypeError("All summands must be `LinearOperator`s")

if len(summands) < 2:
raise ValueError("There must be at least two summands")

if not all(summand.shape == summands[0].shape for summand in summands):
raise ValueError("All summands must have the same shape")
Expand Down Expand Up @@ -136,9 +125,7 @@ def _mul_fallback(
) -> Union[LinearOperator, NotImplementedType]:
res = NotImplemented

if isinstance(op1, LinearOperator) and isinstance(op2, LinearOperator):
pass # TODO: Implement generic Hadamard product
elif isinstance(op1, LinearOperator):
if isinstance(op1, LinearOperator):
if np.ndim(op2) == 0:
res = ScaledLinearOperator(op1, op2)
elif isinstance(op2, LinearOperator):
Expand All @@ -154,11 +141,6 @@ class ProductLinearOperator(LinearOperator):
"""(Operator) Product of two linear operators."""

def __init__(self, *factors: LinearOperator):
if not all(isinstance(factor, LinearOperator) for factor in factors):
raise TypeError("All factors must be `LinearOperator`s")

if len(factors) < 2:
raise ValueError("There must be at least two factors")

if not all(
lfactor.shape[1] == rfactor.shape[0]
Expand Down
45 changes: 45 additions & 0 deletions tests/test_linops/test_arithmetics_fallbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Tests for linear operator arithmetics fallbacks."""

import numpy as np
import pytest

from probnum.linops._arithmetic_fallbacks import (
NegatedLinearOperator,
ProductLinearOperator,
ScaledLinearOperator,
SumLinearOperator,
)
from probnum.linops._linear_operator import Matrix
from probnum.problems.zoo.linalg import random_spd_matrix


@pytest.fixture
def rng():
return np.random.default_rng(123)


@pytest.fixture
def scalar():
return 3.14


@pytest.fixture
def rand_spd_mat(rng):
return Matrix(random_spd_matrix(rng, dim=4))


def test_scaled_linop(rand_spd_mat, scalar):
with pytest.raises(TypeError):
ScaledLinearOperator(np.random.rand(4, 4), scalar=scalar)
with pytest.raises(TypeError):
ScaledLinearOperator(rand_spd_mat, scalar=np.ones(4))

scaled1 = ScaledLinearOperator(rand_spd_mat, scalar=0.0)
scaled2 = ScaledLinearOperator(rand_spd_mat, scalar=scalar)

with pytest.raises(np.linalg.LinAlgError):
scaled1.inv()

assert np.allclose(
scaled2.inv().todense(), (1.0 / scalar) * scaled2._linop.inv().todense()
)

0 comments on commit 25a1686

Please sign in to comment.