Skip to content

Commit

Permalink
Clean up typing
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinpfoertner committed Aug 28, 2020
1 parent 534f8e6 commit f00c817
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 46 deletions.
Empty file removed src/probnum/_lib/__init__.py
Empty file.
23 changes: 0 additions & 23 deletions src/probnum/_lib/argtypes.py

This file was deleted.

5 changes: 3 additions & 2 deletions src/probnum/random_variables/_dirac.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
import numpy as np

from probnum import utils as _utils
from probnum._lib.argtypes import (
from probnum.typing import (
ShapeType,
# Argument Types
ArrayLikeGetitemArgType,
RandomStateArgType,
ShapeArgType,
)
from probnum.typing import ShapeType

from . import _random_variable

Expand Down
7 changes: 4 additions & 3 deletions src/probnum/random_variables/_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import scipy.stats

from probnum import utils as _utils
from probnum._lib.argtypes import (
from probnum.linalg import linops
from probnum.typing import (
ShapeType,
# Argument Types
ArrayLikeGetitemArgType,
FloatArgType,
RandomStateArgType,
ShapeArgType,
)
from probnum.linalg import linops
from probnum.typing import ShapeType

from . import _random_variable

Expand Down
6 changes: 4 additions & 2 deletions src/probnum/random_variables/_random_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
import numpy as np

from probnum import utils as _utils
from probnum._lib.argtypes import (
from probnum.typing import (
RandomStateType,
ShapeType,
# Argument Types
DTypeArgType,
FloatArgType,
RandomStateArgType,
ShapeArgType,
)
from probnum.typing import RandomStateType, ShapeType

try:
# functools.cached_property is only available in Python >=3.8
Expand Down
31 changes: 30 additions & 1 deletion src/probnum/typing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,36 @@
from typing import Tuple, Union
import numbers
from typing import Iterable, Tuple, Union

import numpy as np

########################################################################################
# API Types
########################################################################################

ShapeType = Tuple[int, ...]

RandomStateType = Union[np.random.RandomState, np.random.Generator]

########################################################################################
# Argument Types
########################################################################################

IntArgType = Union[int, numbers.Integral, np.integer]
FloatArgType = Union[float, numbers.Real, np.floating]

ShapeArgType = Union[IntArgType, Iterable[IntArgType]]
DTypeArgType = Union[np.dtype, str]

ScalarArgType = Union[int, float, complex, numbers.Number, np.float_]

ArrayLikeGetitemArgType = Union[
int,
slice,
np.ndarray,
np.newaxis,
None,
type(Ellipsis),
Tuple[Union[int, slice, np.ndarray, np.newaxis, None, type(Ellipsis)], ...],
]

RandomStateArgType = Union[None, int, np.random.RandomState, np.random.Generator]
1 change: 0 additions & 1 deletion src/probnum/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from .arrayutils import *
from .fctutils import *
from .randomutils import *
from .scalarutils import *

# Public classes and functions. Order is reflected in documentation.
__all__ = [
Expand Down
26 changes: 22 additions & 4 deletions src/probnum/utils/argutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,20 @@
import numpy as np
import scipy._lib._util

from probnum.typing import ShapeType, RandomStateType
from probnum._lib.argtypes import ShapeArgType, RandomStateArgType
from probnum.typing import (
DTypeArgType,
RandomStateArgType,
RandomStateType,
ScalarArgType,
ShapeArgType,
ShapeType,
)

__all__ = ["as_shape", "as_random_state", "as_numpy_scalar"]


def as_random_state(x: RandomStateArgType) -> RandomStateType:
return scipy._lib._util.check_random_state(x)


def as_shape(x: ShapeArgType) -> ShapeType:
Expand All @@ -26,5 +38,11 @@ def as_shape(x: ShapeArgType) -> ShapeType:
return tuple(int(item) for item in x)


def as_random_state(x: RandomStateArgType) -> RandomStateType:
return scipy._lib._util.check_random_state(x)
def as_numpy_scalar(x: ScalarArgType, dtype: DTypeArgType = None) -> np.generic:
is_scalar = np.isscalar(x)
is_scalar_array = isinstance(x, np.ndarray) and x.ndim == 0

if not (is_scalar or is_scalar_array):
raise ValueError("The given input is not a scalar.")

return np.asarray(x, dtype=dtype)[()]
10 changes: 0 additions & 10 deletions src/probnum/utils/scalarutils.py

This file was deleted.

0 comments on commit f00c817

Please sign in to comment.