-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix nonstationary unit tests #230
Changes from 2 commits
3d61023
e0e9b47
c41eca6
ba98e4c
ca62dfd
4d22f40
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,11 @@ | |
from MuyGPyS._src.util import auto_str | ||
from MuyGPyS.gp.deformation.deformation_fn import DeformationFn | ||
from MuyGPyS.gp.deformation.metric import MetricFn | ||
from MuyGPyS.gp.hyperparameter import VectorParam, NamedVectorParam | ||
from MuyGPyS.gp.hyperparameter import ScalarParam, VectorParam, NamedVectorParam | ||
from MuyGPyS.gp.hyperparameter.experimental import ( | ||
HierarchicalParam, | ||
NamedHierarchicalVectorParam, | ||
) | ||
|
||
|
||
@auto_str | ||
|
@@ -37,8 +41,18 @@ def __init__( | |
metric: MetricFn, | ||
length_scale: VectorParam, | ||
): | ||
name = "length_scale" | ||
params = length_scale._params | ||
# This is brittle and should be refactored | ||
if all(isinstance(p, ScalarParam) for p in params): | ||
self.length_scale = NamedVectorParam(name, length_scale) | ||
elif all(isinstance(p, HierarchicalParam) for p in params): | ||
self.length_scale = NamedHierarchicalVectorParam(name, length_scale) | ||
else: | ||
raise ValueError( | ||
"Expected uniform vector of ScalarParam or HierarchicalParam type for length_scale" | ||
) | ||
self.metric = metric | ||
self.length_scale = NamedVectorParam("length_scale", length_scale) | ||
|
||
def __call__(self, dists: mm.ndarray, **length_scales) -> mm.ndarray: | ||
""" | ||
|
@@ -70,7 +84,14 @@ def __call__(self, dists: mm.ndarray, **length_scales) -> mm.ndarray: | |
f"Difference tensor of shape {dists.shape} must have final " | ||
f"dimension size of {len(self.length_scale)}" | ||
) | ||
return self.metric(dists / self.length_scale(**length_scales)) | ||
length_scale = self.length_scale(**length_scales) | ||
# This is brittle and similar to what we do in Isotropy. | ||
if isinstance(length_scale, mm.ndarray) and len(length_scale.shape) > 0: | ||
shape = [None] * dists.ndim | ||
shape[0] = slice(None) | ||
shape[-1] = slice(None) | ||
length_scale = length_scale.T[tuple(shape)] | ||
return self.metric(dists / length_scale) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similarly, I don't think that we need this if we disallow hierarchical parameters in the Anisotropy deformation. |
||
|
||
@mpi_chunk(return_count=1) | ||
def pairwise_tensor( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,10 +24,9 @@ class HierarchicalParameter: | |
knot_features: | ||
Tensor of floats of shape `(knot_count, feature_count)` | ||
containing the feature vectors for each knot. | ||
knot_values: | ||
knot_params: | ||
List of scalar hyperparameters of length `knot_count` | ||
containing the initial values and optimization bounds for each knot. | ||
Float values will be converted to fixed scalar hyperparameters. | ||
kernel: | ||
Initialized higher-level GP kernel. | ||
""" | ||
|
@@ -108,7 +107,9 @@ def name(self) -> str: | |
def knot_values(self) -> mm.ndarray: | ||
return self._params() | ||
|
||
def __call__(self, batch_features, **kwargs) -> float: | ||
def __call__(self, batch_features=None, **kwargs) -> float: | ||
if batch_features is None: | ||
raise TypeError("batch_features keyword argument is required") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If |
||
params, kwargs = self._params.filter_kwargs(**kwargs) | ||
solve = mm.linalg.solve( | ||
self._Kin_higher + self._noise() * mm.eye(self._knot_count), | ||
|
@@ -159,6 +160,29 @@ def populate(self, hyperparameters: Dict) -> None: | |
self._params.populate(hyperparameters) | ||
|
||
|
||
class NamedHierarchicalVectorParameter(NamedVectorParam): | ||
def __init__(self, name: str, param: VectorParam): | ||
self._params = [ | ||
NamedHierarchicalParameter(name + str(i), p) | ||
for i, p in enumerate(param._params) | ||
] | ||
self._name = name | ||
|
||
def filter_kwargs(self, **kwargs) -> Tuple[Dict, Dict]: | ||
params = { | ||
key: kwargs[key] for key in kwargs if key.startswith(self._name) | ||
} | ||
kwargs = { | ||
key: kwargs[key] for key in kwargs if not key.startswith(self._name) | ||
} | ||
if "batch_features" in kwargs: | ||
for p in self._params: | ||
params.setdefault( | ||
p.name(), p(kwargs["batch_features"], **params) | ||
) | ||
return params, kwargs | ||
|
||
Comment on lines
+161
to
+182
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think that we need hierarchical vector parameters at the moment if we remove the Anisotropy use case, but it doesn't hurt to leave this here for now because we might want this in the future. |
||
|
||
def sample_knots(feature_count: int, knot_count: int) -> mm.ndarray: | ||
""" | ||
Samples knots from feature matrix. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,15 +12,16 @@ | |
from MuyGPyS.gp import MuyGPS | ||
from MuyGPyS.gp.kernels import Matern, RBF | ||
from MuyGPyS.gp.deformation import l2, Isotropy, Anisotropy | ||
from MuyGPyS.gp.hyperparameter import ScalarParam | ||
from MuyGPyS.gp.hyperparameter import ( | ||
Parameter, | ||
VectorParameter, | ||
) | ||
from MuyGPyS.gp.hyperparameter.experimental import ( | ||
HierarchicalNonstationaryHyperparameter, | ||
HierarchicalParameter, | ||
NamedHierarchicalParam, | ||
sample_knots, | ||
) | ||
from MuyGPyS.gp.tensors import ( | ||
make_train_tensors, | ||
batch_features_tensor, | ||
) | ||
from MuyGPyS.gp.tensors import batch_features_tensor | ||
from MuyGPyS.neighbors import NN_Wrapper | ||
from MuyGPyS.optimize.batch import sample_batch | ||
|
||
|
@@ -54,24 +55,28 @@ def test_hierarchical_nonstationary_hyperparameter( | |
response_count=1, | ||
) | ||
knot_features = train["input"] | ||
knot_values = train["output"] | ||
knot_values = VectorParameter( | ||
*[Parameter(x) for x in np.squeeze(train["output"])] | ||
) | ||
batch_features = test["input"] | ||
hyp = HierarchicalNonstationaryHyperparameter( | ||
knot_features, | ||
knot_values, | ||
kernel, | ||
hyp = NamedHierarchicalParam( | ||
"custom_param_name", | ||
HierarchicalParameter( | ||
knot_features, | ||
knot_values, | ||
kernel, | ||
), | ||
) | ||
hyperparameters = hyp(batch_features) | ||
_check_ndarray( | ||
self.assertEqual, hyperparameters, mm.ftype, shape=(batch_count, 1) | ||
self.assertEqual, hyperparameters, mm.ftype, shape=(batch_count,) | ||
) | ||
|
||
@parameterized.parameters( | ||
( | ||
( | ||
feature_count, | ||
type(knot_values[0]), | ||
high_level_kernel, | ||
type(high_level_kernel).__name__, | ||
deformation, | ||
) | ||
for feature_count in [2, 17] | ||
|
@@ -80,35 +85,35 @@ def test_hierarchical_nonstationary_hyperparameter( | |
sample_knots(feature_count=feature_count, knot_count=knot_count) | ||
] | ||
for knot_values in [ | ||
np.random.uniform(size=knot_count), | ||
[ScalarParam(i) for i in range(knot_count)], | ||
VectorParameter(*[Parameter(i) for i in range(knot_count)]), | ||
] | ||
for high_level_kernel in [RBF(), Matern()] | ||
for deformation in [ | ||
Isotropy( | ||
l2, | ||
length_scale=HierarchicalNonstationaryHyperparameter( | ||
length_scale=HierarchicalParameter( | ||
knot_features, knot_values, high_level_kernel | ||
), | ||
), | ||
Anisotropy( | ||
l2, | ||
**{ | ||
f"length_scale{i}": HierarchicalNonstationaryHyperparameter( | ||
knot_features, | ||
knot_values, | ||
high_level_kernel, | ||
) | ||
for i in range(feature_count) | ||
}, | ||
VectorParameter( | ||
*[ | ||
HierarchicalParameter( | ||
knot_features, | ||
knot_values, | ||
high_level_kernel, | ||
) | ||
for _ in range(feature_count) | ||
] | ||
), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should comment this out for now, given what I mentioned earlier in the thread. It is good to know that it works for |
||
), | ||
] | ||
) | ||
) | ||
def test_hierarchical_nonstationary_rbf( | ||
self, | ||
feature_count, | ||
knot_values_type, | ||
high_level_kernel, | ||
deformation, | ||
Comment on lines
117
to
118
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is just to differentiate the tests in the output. It's useful when they fail. Without it, it looks like this:
Whereas with it, it's a bit nicer:
|
||
): | ||
|
@@ -133,7 +138,7 @@ def test_hierarchical_nonstationary_rbf( | |
batch_indices, batch_nn_indices = sample_batch( | ||
nbrs_lookup, batch_count, data_count | ||
) | ||
(_, pairwise_diffs, _, _) = make_train_tensors( | ||
(_, pairwise_diffs, _, _) = muygps.make_train_tensors( | ||
batch_indices, | ||
batch_nn_indices, | ||
data["input"], | ||
|
@@ -142,7 +147,7 @@ def test_hierarchical_nonstationary_rbf( | |
|
||
batch_features = batch_features_tensor(data["input"], batch_indices) | ||
|
||
Kin = muygps.kernel(pairwise_diffs, batch_features) | ||
Kin = muygps.kernel(pairwise_diffs, batch_features=batch_features) | ||
|
||
_check_ndarray( | ||
self.assertEqual, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I actually don't think that we want the anisotropic deformation to allow hierarchical parameters for now. Amanda and I have talked about this and concluded that it is probably too complex for now. We might reintroduce this concept later. For now, please remove references to hierarchical parameters from the Anisotropy module.