Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix nonstationary unit tests #230

Merged
merged 6 commits into from
May 23, 2024
Merged

Conversation

igoumiri
Copy link
Contributor

The tests now pass but I had to replicate several brittle pieces of logic. Open to suggestions on how to make this better.

Copy link
Member

@bwpriest bwpriest left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks pretty good, but I suggested a few changes. Mostly I think we want to disallow hierarchical parameters for Anisotropy for now. We can discuss more with Amanda if you want. I also had some questions about design decisions.

Comment on lines 49 to 50
elif all(isinstance(p, HierarchicalParam) for p in params):
self.length_scale = NamedHierarchicalVectorParam(name, length_scale)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually don't think that we want the anisotropic deformation to allow hierarchical parameters for now. Amanda and I have talked about this and concluded that it is probably too complex for now. We might reintroduce this concept later. For now, please remove references to hierarchical parameters from the Anisotropy module.

Comment on lines 87 to 94
length_scale = self.length_scale(**length_scales)
# This is brittle and similar to what we do in Isotropy.
if isinstance(length_scale, mm.ndarray) and len(length_scale.shape) > 0:
shape = [None] * dists.ndim
shape[0] = slice(None)
shape[-1] = slice(None)
length_scale = length_scale.T[tuple(shape)]
return self.metric(dists / length_scale)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, I don't think that we need this if we disallow hierarchical parameters in the Anisotropy deformation.

Comment on lines +163 to +184
class NamedHierarchicalVectorParameter(NamedVectorParam):
def __init__(self, name: str, param: VectorParam):
self._params = [
NamedHierarchicalParameter(name + str(i), p)
for i, p in enumerate(param._params)
]
self._name = name

def filter_kwargs(self, **kwargs) -> Tuple[Dict, Dict]:
params = {
key: kwargs[key] for key in kwargs if key.startswith(self._name)
}
kwargs = {
key: kwargs[key] for key in kwargs if not key.startswith(self._name)
}
if "batch_features" in kwargs:
for p in self._params:
params.setdefault(
p.name(), p(kwargs["batch_features"], **params)
)
return params, kwargs

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that we need hierarchical vector parameters at the moment if we remove the Anisotropy use case, but it doesn't hurt to leave this here for now because we might want this in the future.

Comment on lines 117 to 118
high_level_kernel,
deformation,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the high_level_kernel argument does not actually get used in the function body, and only seems to need to exist in the scope of the @parameterized.parameters decorator, so you should be able to safely remove this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just to differentiate the tests in the output. It's useful when they fail. Without it, it looks like this:

[ RUN      ] HierarchicalNonstationaryHyperparameterTest.test_hierarchical_nonstationary_rbf0 (2, <MuyGPyS.gp.deformation.isotropy.Isotropy>)
[       OK ] HierarchicalNonstationaryHyperparameterTest.test_hierarchical_nonstationary_rbf0 (2, <MuyGPyS.gp.deformation.isotropy.Isotropy>)
[ RUN      ] HierarchicalNonstationaryHyperparameterTest.test_hierarchical_nonstationary_rbf1 (2, <MuyGPyS.gp.deformation.isotropy.Isotropy>)
[       OK ] HierarchicalNonstationaryHyperparameterTest.test_hierarchical_nonstationary_rbf1 (2, <MuyGPyS.gp.deformation.isotropy.Isotropy>)
...

Whereas with it, it's a bit nicer:

[ RUN      ] HierarchicalNonstationaryHyperparameterTest.test_hierarchical_nonstationary_rbf0 (2, 'RBF', <MuyGPyS.gp.deformation.isotropy.Isotropy>)
[       OK ] HierarchicalNonstationaryHyperparameterTest.test_hierarchical_nonstationary_rbf0 (2, 'RBF', <MuyGPyS.gp.deformation.isotropy.Isotropy>)
[ RUN      ] HierarchicalNonstationaryHyperparameterTest.test_hierarchical_nonstationary_rbf1 (2, 'Matern', <MuyGPyS.gp.deformation.isotropy.Isotropy>)
[       OK ] HierarchicalNonstationaryHyperparameterTest.test_hierarchical_nonstationary_rbf1 (2, 'Matern', <MuyGPyS.gp.deformation.isotropy.Isotropy>)
...

Comment on lines 98 to 109
Anisotropy(
l2,
**{
f"length_scale{i}": HierarchicalNonstationaryHyperparameter(
knot_features,
knot_values,
high_level_kernel,
)
for i in range(feature_count)
},
VectorParameter(
*[
HierarchicalParameter(
knot_features,
knot_values,
high_level_kernel,
)
for _ in range(feature_count)
]
),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should comment this out for now, given what I mentioned earlier in the thread. It is good to know that it works for Anisotropy, though.

Comment on lines 111 to 112
if batch_features is None:
raise TypeError("batch_features keyword argument is required")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If batch_features is required, why not keep it as a positional argument? I don't think that this change is necessary.

@igoumiri igoumiri requested a review from bwpriest May 23, 2024 03:58
Copy link
Member

@bwpriest bwpriest left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@bwpriest bwpriest merged commit 1156bb3 into LLNL:develop May 23, 2024
18 of 21 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants