Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(Advanced) Indexing, Slicing, Masking, Reshaping and Transposition for Random Variables and Distributions #134

Merged
34 changes: 33 additions & 1 deletion src/probnum/prob/distributions/dirac.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,49 @@ def sample(self, size=(), seed=None):
A=self.parameters["support"], reps=tuple([*size, *np.repeat(1, ndims)])
)

def __getitem__(self, key):
    """
    Marginalize a multivariate Dirac distribution via array indexing.

    The support array is indexed with ``key`` and a new Dirac distribution over
    the selected entries is returned. All modes of (advanced) indexing, masking
    and slicing presented in

    https://numpy.org/doc/1.19/reference/arrays.indexing.html

    are supported.

    Parameters
    ----------
    key : int or slice or ndarray or tuple of None, int, slice, or ndarray
        Indices, slice objects and/or boolean masks specifying which entries to
        keep while marginalizing over all other entries.

    Returns
    -------
    Dirac
        Dirac distribution whose support is ``support[key]``; it is constructed
        with the ``random_state`` of this distribution.
    """
    selected_support = self.parameters["support"][key]
    return Dirac(support=selected_support, random_state=self.random_state)

def reshape(self, newshape):
    """
    Return a new Dirac distribution whose support has shape ``newshape``.

    Parameters
    ----------
    newshape : int or tuple of ints
        Shape to give to the support array.

    Returns
    -------
    Dirac
        Distribution over the reshaped support, constructed with the
        ``random_state`` of this distribution.

    Raises
    ------
    ValueError
        If the support cannot be reshaped to ``newshape``.
    """
    # Only the array reshape is guarded; ``ndarray.reshape`` takes the shape
    # positionally (it has no ``newshape`` keyword).
    try:
        reshaped_support = self.parameters["support"].reshape(newshape)
    except ValueError as err:
        raise ValueError(
            "Cannot reshape this Dirac distribution to the given shape: {}".format(
                str(newshape)
            )
        ) from err

    return Dirac(support=reshaped_support, random_state=self.random_state)

def _reshape_inplace(self, newshape):
    """Reshape the support array in place by assigning to its ``shape``."""
    support = self.parameters["support"]
    support.shape = newshape

def transpose(self, *axes):
    """
    Return a new Dirac distribution with transposed support.

    Parameters
    ----------
    axes : None, tuple of ints, or n ints
        Forwarded positionally to ``numpy.ndarray.transpose``.

    Returns
    -------
    Dirac
        Distribution over the transposed support, constructed with the
        ``random_state`` of this distribution.
    """
    # Axes must be positional: ``ndarray.transpose`` accepts no keyword
    # arguments, and callers pass ``*axes``.
    transposed_support = self.parameters["support"].transpose(*axes)
    return Dirac(support=transposed_support, random_state=self.random_state)

# Binary arithmetic operations
def __add__(self, other):
if isinstance(other, Dirac):
Expand Down
39 changes: 39 additions & 0 deletions src/probnum/prob/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@ def shape(self):
"""Shape of samples from this distribution."""
return self._shape

@shape.setter
def shape(self, newshape):
    """Set a new shape: reshape in place first, then record the shape."""
    # ``_reshape_inplace`` runs first, so if it raises, ``_set_shape`` is
    # never reached.
    self._reshape_inplace(newshape)
    self._set_shape(newshape)

def _reshape_inplace(self, newshape):
    """
    Reshape the distribution in place; called by the ``shape`` setter.

    This default implementation always raises.

    Raises
    ------
    NotImplementedError
        Always; the message names the concrete class.
    """
    raise NotImplementedError(
        "In-place reshaping not implemented for distribution of type: {}.".format(
            self.__class__.__name__
        )
    )

@property
def dtype(self):
"""``Dtype`` of elements of samples from this distribution."""
Expand Down Expand Up @@ -416,3 +425,33 @@ def reshape(self, newshape):
self.__class__.__name__
)
)

def transpose(self, *axes):
    """
    Transpose the distribution.

    This default implementation always raises.

    Raises
    ------
    NotImplementedError
        Always; the message names the concrete class.
    """
    type_name = self.__class__.__name__
    message = "Transposition not implemented for distribution of type: {}.".format(
        type_name
    )
    raise NotImplementedError(message)

def __getitem__(self, key):
    """
    (Advanced) indexing, masking and slicing into (realizations of) this
    distribution.

    For multivariate distributions this is essentially marginalization. The
    indexing modes are those presented in

    https://numpy.org/doc/1.19/reference/arrays.indexing.html.

    Which of these modes are available varies with the concrete distribution.
    This default implementation always raises.

    Parameters
    ----------
    key : int or slice or ndarray or tuple of None, int, slice, or ndarray
        Indices, slice objects and/or boolean masks specifying which entries to
        keep while marginalizing over all other entries.

    Raises
    ------
    NotImplementedError
        Always; the message names the concrete class.
    """
    type_name = self.__class__.__name__
    message = (
        "(Advanced) indexing and slicing is not implemented for distribution of "
        "type: {}.".format(type_name)
    )
    raise NotImplementedError(message)
110 changes: 89 additions & 21 deletions src/probnum/prob/distributions/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,79 @@ def cov(self):
def var(self):
    """Variance; not implemented here, always raises NotImplementedError."""
    raise NotImplementedError

def __getitem__(self, key):
    """
    Marginalization in multi- and matrixvariate normal distributions, expressed
    by means of (advanced) indexing, masking and slicing.

    All modes of array indexing presented in

    https://numpy.org/doc/1.19/reference/arrays.indexing.html

    are supported. Note that, currently, this method does not work for normal
    distributions other than the multi- and matrixvariate versions.

    Parameters
    ----------
    key : int or slice or ndarray or tuple of None, int, slice, or ndarray
        Indices, slice objects and/or boolean masks specifying which entries to
        keep while marginalizing over all other entries.

    Returns
    -------
    Normal
        Normal distribution built from the selected mean entries and the
        matching covariance entries, with this distribution's ``random_state``.
    """
    index = key if isinstance(key, tuple) else (key,)

    full_mean = self.mean()

    # Select entries from the mean
    mean = full_mean[index]

    # View the covariance with one group of axes per copy of the mean, then
    # apply the index first to the leading group and then to the trailing one.
    cov = self.cov().reshape(full_mean.shape + full_mean.shape)
    keep_selected_axes = (slice(None),) * mean.ndim
    cov = cov[index][keep_selected_axes + index]

    # Flatten back to a square matrix if more than two axes remain
    if cov.ndim > 2:
        cov = cov.reshape(mean.size, mean.size)

    return Normal(mean=mean, cov=cov, random_state=self.random_state)

def reshape(self, newshape):
    """
    Return a new normal distribution whose mean has shape ``newshape``.

    The covariance is passed on unchanged, except that a zero-dimensional
    covariance is turned into a ``(1, 1)`` array when the reshaped mean has at
    least one dimension.

    Parameters
    ----------
    newshape : int or tuple of ints
        Shape to give to the mean.

    Raises
    ------
    ValueError
        If the mean cannot be reshaped to ``newshape``.
    """
    try:
        mean = self.mean().reshape(newshape)
    except ValueError:
        raise ValueError(
            f"Cannot reshape this normal distribution to the given shape: "
            f"{newshape}"
        )

    cov = self.cov()

    if mean.ndim > 0 and cov.ndim == 0:
        cov = cov.reshape(1, 1)

    return Normal(mean=mean, cov=cov, random_state=self.random_state)

def _reshape_inplace(self, newshape):
    """
    Reshape the mean in place by assigning to its ``shape`` attribute.

    A zero-dimensional covariance is given shape ``(1, 1)`` when the reshaped
    mean has at least one dimension.

    NOTE(review): this assumes ``self.mean()`` and ``self.cov()`` return the
    stored arrays rather than copies -- confirm against their definitions.

    Parameters
    ----------
    newshape : int or tuple of ints
        New shape of the mean.
    """
    # ``mean`` is a method, so the shape must be set on the array it returns,
    # not on the method object itself.
    mean = self.mean()
    mean.shape = newshape

    cov = self.cov()
    if mean.ndim > 0 and cov.ndim == 0:
        cov.shape = (1, 1)

def transpose(self, *axes):
    """
    Return a new normal distribution with transposed mean and matching
    covariance.

    Parameters
    ----------
    axes : None, tuple or list of ints, or n ints
        Axis permutation as accepted by ``numpy.ndarray.transpose``. With no
        argument or ``None`` the order of the axes is reversed. Negative axes
        are allowed.

    Returns
    -------
    Normal
        Distribution with transposed mean and a covariance whose entries are
        permuted accordingly, with this distribution's ``random_state``.
    """
    mean = self.mean()
    ndim = mean.ndim

    if len(axes) == 1 and isinstance(axes[0], (tuple, list)):
        axes = tuple(axes[0])
    elif len(axes) == 0 or (len(axes) == 1 and axes[0] is None):
        axes = tuple(reversed(range(ndim)))

    # numpy validates the permutation here (range, repeats, count)
    mean_t = mean.transpose(*axes).copy()

    # Negative axes must be made non-negative before offsetting them by
    # ``ndim``; otherwise ``ndim + axis`` points into the first group of
    # covariance axes instead of the second.
    axes = tuple(axis % ndim for axis in axes)

    # Transpose covariance: view it with one group of axes per copy of the
    # mean and apply the same permutation to both groups.
    cov_axes = axes + tuple(ndim + axis for axis in axes)
    cov_t = self.cov().reshape(mean.shape + mean.shape)
    cov_t = cov_t.transpose(*cov_axes).copy()
    cov_t = cov_t.reshape(mean_t.size, mean_t.size)

    return Normal(mean=mean_t, cov=cov_t, random_state=self.random_state)

# Binary arithmetic

def __add__(self, other):
Expand Down Expand Up @@ -375,14 +448,15 @@ def sample(self, size=()):
loc=self.mean(), scale=self.std(), size=size, random_state=self.random_state
)

def reshape(self, newshape):
    """
    Reshape the ``mean`` and ``cov`` parameters of this distribution in place.

    Only shapes with exactly one element are accepted.

    Parameters
    ----------
    newshape : int or tuple of ints
        New shape; its product must be 1.

    Raises
    ------
    ValueError
        If ``newshape`` does not describe exactly one element.
    """
    if np.prod(newshape) != 1:
        raise ValueError(
            f"Cannot reshape distribution with shape {self.shape} into shape {newshape}."
        )
    # The shape is passed positionally: the ``newshape`` keyword of
    # ``np.reshape`` is deprecated in recent NumPy releases.
    self.parameters["mean"] = np.reshape(self.parameters["mean"], newshape)
    self.parameters["cov"] = np.reshape(self.parameters["cov"], newshape)
    self._shape = newshape
def transpose(self, *axes):
    """
    Return a new normal distribution built from copies of the mean and
    covariance; ``axes`` is accepted but not used.
    """
    mean_copy = self.mean().copy()
    cov_copy = self.cov().copy()
    return Normal(mean=mean_copy, cov=cov_copy, random_state=self.random_state)

def _reshape_inplace(self, newshape):
    """In-place reshaping is not available here; always raises."""
    # NOTE(review): per the PR discussion, the attributes of numpy array
    # scalars are immutable, so an in-place reshape is not possible.
    raise NotImplementedError
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In-place reshaping for univariate normals is actually impossible because the attributes of numpy's array scalars are immutable. See also https://numpy.org/doc/stable/reference/arrays.scalars.html. We should rethink #130. It would be best if scalar quantities behaved exactly the same as arrays. Maybe we should use arrays of shape () internally?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think representing everything internally as arrays makes a lot of sense. This removes quite a few annoying conditionals in the code as well.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, I'll rephrase #130 then?



class _MultivariateNormal(Normal):
Expand Down Expand Up @@ -430,9 +504,6 @@ def sample(self, size=()):
mean=self.mean(), cov=self.cov(), size=size, random_state=self.random_state
)

def reshape(self, newshape):
    """Reshaping is not implemented here; always raises NotImplementedError."""
    raise NotImplementedError

# Arithmetic Operations

def __matmul__(self, other):
Expand Down Expand Up @@ -522,16 +593,7 @@ def sample(self, size=()):
size=size,
random_state=self.random_state,
)
return ravelled.reshape(self.shape)

def reshape(self, newshape):
    """
    Reshape the ``mean`` parameter of this distribution in place.

    The covariance is left untouched.

    NOTE(review): this assumes the covariance has ``size ** 2`` entries
    (covariance of the flattened variable), so it can never take the mean's
    shape -- confirm for this class.

    Parameters
    ----------
    newshape : int or tuple of ints
        New shape; must have as many elements as the current shape.

    Raises
    ------
    ValueError
        If the number of elements of ``newshape`` differs from that of the
        current shape.
    """
    if np.prod(newshape) != np.prod(self.shape):
        raise ValueError(
            f"Cannot reshape distribution with shape {self.shape} into shape {newshape}."
        )
    # The shape is passed positionally: the ``newshape`` keyword of
    # ``np.reshape`` is deprecated in recent NumPy releases.
    self.parameters["mean"] = np.reshape(self.parameters["mean"], newshape)
    self._shape = newshape
return ravelled.reshape(ravelled.shape[:-1] + self.shape)

# Arithmetic Operations
# TODO: implement special rules for matrix-variate RVs and Kronecker
Expand Down Expand Up @@ -601,6 +663,12 @@ def sample(self, size=()):
def reshape(self, newshape):
    """Reshaping is not implemented here; always raises NotImplementedError."""
    raise NotImplementedError

def _reshape_inplace(self, newshape):
    """In-place reshaping is not implemented here; always raises."""
    raise NotImplementedError

def transpose(self, *axes):
    """Transposition is not implemented here; always raises."""
    raise NotImplementedError

# Arithmetic Operations

# TODO: implement special rules for matrix-variate RVs and Kronecker structured covariances
Expand Down
30 changes: 27 additions & 3 deletions src/probnum/prob/randomvariable.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,12 @@ def shape(self):
"""Shape of realizations of the random variable."""
return self._shape

@shape.setter
def shape(self, newshape):
    """Set a new shape on the underlying distribution, then update this object."""
    distribution = self._distribution

    # Assigning to the distribution's ``shape`` happens first, so if that
    # raises, ``_set_shape`` is never reached.
    distribution.shape = newshape

    self._set_shape(distribution, newshape)

@property
def dtype(self):
"""Data type of (elements of) a realization of this random variable."""
Expand Down Expand Up @@ -203,9 +209,24 @@ def reshape(self, newshape):
-------
reshaped_rv : ``self`` with the new dimensions of ``shape``.
"""
self._shape = newshape
self._distribution.reshape(newshape=newshape)
return self
return RandomVariable(distribution=self._distribution.reshape(newshape))

def transpose(self, *axes):
    """
    Transpose the random variable.

    The axes are forwarded to the ``transpose`` method of the underlying
    distribution and the result is wrapped in a new random variable.

    Parameters
    ----------
    axes : None, tuple of ints, or n ints
        See documentation of numpy.ndarray.transpose.

    Returns
    -------
    transposed_rv : The transposed random variable.
    """
    transposed_distribution = self._distribution.transpose(*axes)
    return RandomVariable(distribution=transposed_distribution)

T = property(transpose)

# Binary arithmetic operations

Expand Down Expand Up @@ -238,6 +259,9 @@ def _rv_from_binary_operation(self, other, op):
)
return combined_rv

def __getitem__(self, key):
    """
    Index into the random variable.

    ``key`` is applied to the underlying distribution and the result is wrapped
    in a new random variable.
    """
    indexed_distribution = self.distribution[key]
    return RandomVariable(distribution=indexed_distribution)

def __add__(self, other):
    """Return ``self + other``, built via ``_rv_from_binary_operation``."""
    add_op = operator.add
    return self._rv_from_binary_operation(other=other, op=add_op)

Expand Down
2 changes: 1 addition & 1 deletion src/probnum/utils/scalarutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ def as_numpy_scalar(x):
if not np.isscalar(x):
raise ValueError("The given input is not a scalar")

return np.asarray(x).item()
return np.array([x])[0]
Loading