Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Enable str transformations to Color/Interval scales #3688

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions seaborn/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ def __call__(self, value, clip=None):
return new_norm


def is_registered_colormap(name):
"""Handle changes to matplotlib colormap interface in 3.5."""
if _version_predates(mpl, "3.5"):
try:
mpl.cm.get_cmap(name)
return True
except ValueError:
return False
else:
return name in mpl.colormaps


def get_colormap(name):
"""Handle changes to matplotlib colormap interface in 3.6."""
try:
Expand Down
18 changes: 13 additions & 5 deletions seaborn/_core/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from seaborn._core.rules import categorical_order, variable_type
from seaborn.palettes import QUAL_PALETTES, color_palette, blend_palette
from seaborn.utils import get_color_cycle
from seaborn._compat import is_registered_colormap

from typing import Any, Callable, Tuple, List, Union, Optional

Expand Down Expand Up @@ -43,6 +44,7 @@

class Property:
"""Base class for visual properties that can be set directly or be data scaling."""
_TRANS_ARGS = ["log", "symlog", "logit", "pow", "sqrt"]

# When True, scales for this property will populate the legend by default
legend = False
Expand Down Expand Up @@ -76,9 +78,8 @@ def infer_scale(self, arg: Any, data: Series) -> Scale:
# (e.g. color). How best to handle that? One option is to call super after
# handling property-specific possibilities (e.g. for color check that the
# arg is not a valid palette name) but that could get tricky.
trans_args = ["log", "symlog", "logit", "pow", "sqrt"]
if isinstance(arg, str):
if any(arg.startswith(k) for k in trans_args):
if any(arg.startswith(k) for k in self._TRANS_ARGS):
# TODO validate numeric type? That should happen centrally somewhere
return Continuous(trans=arg)
else:
Expand Down Expand Up @@ -183,6 +184,8 @@ def infer_scale(self, arg: Any, data: Series) -> Scale:
return Nominal(arg)
elif var_type == "datetime":
return Temporal(arg)
elif isinstance(arg, str) and any(arg.startswith(k) for k in self._TRANS_ARGS):
return Continuous(trans=arg)
# TODO other variable types
else:
return Continuous(arg)
Expand Down Expand Up @@ -607,8 +610,6 @@ def infer_scale(self, arg: Any, data: Series) -> Scale:
if callable(arg):
return Continuous(arg)

# TODO Do we accept str like "log", "pow", etc. for semantics?

if not isinstance(arg, str):
msg = " ".join([
f"A single scale argument for {self.variable} variables must be",
Expand All @@ -619,7 +620,14 @@ def infer_scale(self, arg: Any, data: Series) -> Scale:
if arg in QUAL_PALETTES:
return Nominal(arg)
elif var_type == "numeric":
return Continuous(arg)
# Prioritize actual colormaps, e.g. if a colormap named "pow" exists
if is_registered_colormap(arg):
return Continuous(arg)
elif any(arg.startswith(k) for k in self._TRANS_ARGS):
Copy link
Contributor

@thuiop thuiop May 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be startswith or == ? With the current code, passing "logfoo" for instance would redirect it to trans; is that what we want ? Same comment for other parts where this happens.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the comment. I didn't want to change this since it is similar to how it's performed elsewhere in this module and in other parts of the codebase (_core.scales for example). I reckon the original intention in using startswith was to support "log2", "symlog100", etc. Anyway, passing "logfoo" would fail even prior to this PR (with ValueError: 'logfoo' is not a valid palette name).

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, the prefix checks allow the "string parameterized scale transform" spelling but maybe that needs better error checking.

return Continuous(trans=arg)
else:
return Continuous(arg)

# TODO implement scales for date variables and any others.
else:
return Nominal(arg)
Expand Down
22 changes: 22 additions & 0 deletions tests/_core/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,17 @@ def test_inference(self, values, data_type, scale_class, vectors):
assert isinstance(scale, scale_class)
assert scale.values == values

@pytest.mark.parametrize(
"trans",
["pow", "sqrt", "log", "symlog", "logit", "log2", "symlog100"]
)
def test_inference_magic_args(self, trans, num_vector):

scale = Color().infer_scale(trans, num_vector)
assert isinstance(scale, Continuous)
assert scale.trans == trans
assert scale.values is None

def test_standardization(self):

f = Color().standardize
Expand Down Expand Up @@ -531,6 +542,17 @@ def test_mapped_interval_categorical(self, cat_vector):
n = cat_vector.nunique()
assert_array_equal(mapping([n - 1, 0]), self.prop().default_range)

@pytest.mark.parametrize(
"trans",
["pow", "sqrt", "log", "symlog", "log13", "logit", "symlog37"]
)
def test_inference_magic_args(self, trans, num_vector):

scale = self.prop().infer_scale(trans, num_vector)
assert isinstance(scale, Continuous)
assert scale.trans == trans
assert scale.values is None

def test_bad_scale_values_numeric_data(self, num_vector):

prop_name = self.prop.__name__.lower()
Expand Down
Loading