-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix nonstationary unit tests #230
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks pretty good, but I suggested a few changes. Mostly I think we want to disallow hierarchical parameters for Anisotropy
for now. We can discuss more with Amanda if you want. I also had some questions about design decisions.
MuyGPyS/gp/deformation/anisotropy.py
Outdated
elif all(isinstance(p, HierarchicalParam) for p in params): | ||
self.length_scale = NamedHierarchicalVectorParam(name, length_scale) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I actually don't think that we want the anisotropic deformation to allow hierarchical parameters for now. Amanda and I have talked about this and concluded that it is probably too complex for now. We might reintroduce this concept later. For now, please remove references to hierarchical parameters from the Anisotropy module.
MuyGPyS/gp/deformation/anisotropy.py
Outdated
length_scale = self.length_scale(**length_scales) | ||
# This is brittle and similar to what we do in Isotropy. | ||
if isinstance(length_scale, mm.ndarray) and len(length_scale.shape) > 0: | ||
shape = [None] * dists.ndim | ||
shape[0] = slice(None) | ||
shape[-1] = slice(None) | ||
length_scale = length_scale.T[tuple(shape)] | ||
return self.metric(dists / length_scale) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similarly, I don't think that we need this if we disallow hierarchical parameters in the Anisotropy deformation.
class NamedHierarchicalVectorParameter(NamedVectorParam): | ||
def __init__(self, name: str, param: VectorParam): | ||
self._params = [ | ||
NamedHierarchicalParameter(name + str(i), p) | ||
for i, p in enumerate(param._params) | ||
] | ||
self._name = name | ||
|
||
def filter_kwargs(self, **kwargs) -> Tuple[Dict, Dict]: | ||
params = { | ||
key: kwargs[key] for key in kwargs if key.startswith(self._name) | ||
} | ||
kwargs = { | ||
key: kwargs[key] for key in kwargs if not key.startswith(self._name) | ||
} | ||
if "batch_features" in kwargs: | ||
for p in self._params: | ||
params.setdefault( | ||
p.name(), p(kwargs["batch_features"], **params) | ||
) | ||
return params, kwargs | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think that we need hierarchical vector parameters at the moment if we remove the Anisotropy use case, but it doesn't hurt to leave this here for now because we might want this in the future.
high_level_kernel, | ||
deformation, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the high_level_kernel
argument does not actually get used in the function body, and only seems to need to exist in the scope of the @parameterized.parameters
decorator, so you should be able to safely remove this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just to differentiate the tests in the output. It's useful when they fail. Without it, it looks like this:
[ RUN ] HierarchicalNonstationaryHyperparameterTest.test_hierarchical_nonstationary_rbf0 (2, <MuyGPyS.gp.deformation.isotropy.Isotropy>)
[ OK ] HierarchicalNonstationaryHyperparameterTest.test_hierarchical_nonstationary_rbf0 (2, <MuyGPyS.gp.deformation.isotropy.Isotropy>)
[ RUN ] HierarchicalNonstationaryHyperparameterTest.test_hierarchical_nonstationary_rbf1 (2, <MuyGPyS.gp.deformation.isotropy.Isotropy>)
[ OK ] HierarchicalNonstationaryHyperparameterTest.test_hierarchical_nonstationary_rbf1 (2, <MuyGPyS.gp.deformation.isotropy.Isotropy>)
...
Whereas with it, it's a bit nicer:
[ RUN ] HierarchicalNonstationaryHyperparameterTest.test_hierarchical_nonstationary_rbf0 (2, 'RBF', <MuyGPyS.gp.deformation.isotropy.Isotropy>)
[ OK ] HierarchicalNonstationaryHyperparameterTest.test_hierarchical_nonstationary_rbf0 (2, 'RBF', <MuyGPyS.gp.deformation.isotropy.Isotropy>)
[ RUN ] HierarchicalNonstationaryHyperparameterTest.test_hierarchical_nonstationary_rbf1 (2, 'Matern', <MuyGPyS.gp.deformation.isotropy.Isotropy>)
[ OK ] HierarchicalNonstationaryHyperparameterTest.test_hierarchical_nonstationary_rbf1 (2, 'Matern', <MuyGPyS.gp.deformation.isotropy.Isotropy>)
...
tests/experimental/nonstationary.py
Outdated
Anisotropy( | ||
l2, | ||
**{ | ||
f"length_scale{i}": HierarchicalNonstationaryHyperparameter( | ||
knot_features, | ||
knot_values, | ||
high_level_kernel, | ||
) | ||
for i in range(feature_count) | ||
}, | ||
VectorParameter( | ||
*[ | ||
HierarchicalParameter( | ||
knot_features, | ||
knot_values, | ||
high_level_kernel, | ||
) | ||
for _ in range(feature_count) | ||
] | ||
), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should comment this out for now, given what I mentioned earlier in the thread. It is good to know that it works for Anisotropy
, though.
if batch_features is None: | ||
raise TypeError("batch_features keyword argument is required") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If batch_features
is required, why not keep it as a positional argument? I don't think that this change is necessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
The tests now pass but I had to replicate several brittle pieces of logic. Open to suggestions on how to make this better.