Skip to content

Commit

Permalink
Improve is_canonical_colorscale and validate_canonical_colorscale
Browse files Browse the repository at this point in the history
  • Loading branch information
tpvasconcelos committed Oct 17, 2024
1 parent 8ad21d4 commit 4d520e9
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 50 deletions.
31 changes: 21 additions & 10 deletions src/ridgeplot/_colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,6 @@
"""


def _is_canonical_colorscale(
colorscale: ColorScale | Collection[Color] | str,
) -> TypeIs[ColorScale]:
if isinstance(colorscale, str):
return False
shape = get_collection_array_shape(colorscale)
return len(shape) == 2 and shape[1] == 2


def _colormap_loader() -> dict[str, ColorScale]:
colors: dict[str, ColorScale] = json.loads(_PATH_TO_COLORS_JSON.read_text())
for name, colorscale in colors.items():
Expand All @@ -67,9 +58,29 @@ def _colormap_loader() -> dict[str, ColorScale]:
_COLORSCALE_MAPPING: LazyMapping[str, ColorScale] = LazyMapping(loader=_colormap_loader)


def is_canonical_colorscale(
colorscale: ColorScale | Collection[Color] | str,
) -> TypeIs[ColorScale]:
if isinstance(colorscale, str) or not isinstance(colorscale, Collection):
return False
shape = get_collection_array_shape(colorscale)
if not (len(shape) == 2 and shape[1] == 2):
return False
scale, colors = zip(*colorscale)
return (
all(isinstance(s, (int, float)) for s in scale) and
all(isinstance(c, (str, tuple)) for c in colors)
) # fmt: skip


def validate_canonical_colorscale(colorscale: ColorScale) -> None:
"""Validate the structure, scale values, and colors of a colorscale in the
canonical format."""
if not is_canonical_colorscale(colorscale):
raise TypeError(
"The colorscale should be a collection of tuples of two elements: "
"a scale value and a color."
)
scale, colors = zip(*colorscale)
validate_scale_values(scale=scale)
validate_colors(colors=colors)
Expand Down Expand Up @@ -186,7 +197,7 @@ def normalise_colorscale(colorscale: ColorScale | Collection[Color] | str) -> Co
:data:`ColorScale` format."""
if isinstance(colorscale, str):
return get_colorscale(name=colorscale)
if _is_canonical_colorscale(colorscale):
if is_canonical_colorscale(colorscale):
validate_canonical_colorscale(colorscale)
return colorscale
# There is a bug in mypy that results in the type narrowing not working
Expand Down
89 changes: 49 additions & 40 deletions tests/unit/test_colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
ColorScale,
_any_to_rgb,
_colormap_loader,
_is_canonical_colorscale,
apply_alpha,
canonical_colorscale_from_list,
get_colorscale,
interpolate_color,
is_canonical_colorscale,
list_all_colorscale_names,
normalise_colorscale,
validate_canonical_colorscale,
Expand All @@ -38,23 +38,37 @@
(1.0, "rgb(253, 231, 37)"),
)

# ==============================================================
# --- _is_canonical_colorscale()
# ==============================================================


@pytest.mark.parametrize(
("colorscale", "expected"),
VALID_COLORSCALES = [
VIRIDIS,
# tuple of tuples of rgb colors
(
(0.0, "rgb(68, 1, 84)"),
(0.4444444444444444, "rgb(38, 130, 142)"),
(1.0, "rgb(253, 231, 37)"),
),
# list of lists of hex colors
[
(VIRIDIS, True),
(VIRIDIS[0], False),
("viridis", False),
(["red", "blue", "green"], False),
(((0, "red"), (1, "blue")), True),
[0, "#440154"],
[0.5019607843137255, "#21918c"],
[1, "#fde725"],
],
)
def test_is_canonical_colorscale(colorscale: ColorScale | Any, expected: bool) -> None:
assert _is_canonical_colorscale(colorscale) == expected
# Another simple example
((0, "red"), (1, "blue")),
]

INVALID_COLORSCALES = [
# is not collection
(1, TypeError),
("viridis", TypeError),
# is not collection of tuples
((1, 2, 3), TypeError),
(["red", "blue", "green"], TypeError),
(VIRIDIS[0], TypeError),
# inner tuples should have length 2 (for the scale and color values)
(((1, 2, 3), (4, 5, 6)), TypeError),
# Wrong order of scale and color values for the inner tuples
((("a", 1), ("b", 2)), TypeError),
]


# ==============================================================
Expand All @@ -80,52 +94,47 @@ def test_plotly_colorscale_mapping() -> None:


# ==============================================================
# --- validate_colorscale()
# --- is_canonical_colorscale()
# ==============================================================


@pytest.mark.parametrize(
"colorscale",
("colorscale", "expected"),
[
# tuple of tuples of rgb colors
(
(0.0, "rgb(68, 1, 84)"),
(0.4444444444444444, "rgb(38, 130, 142)"),
(1.0, "rgb(253, 231, 37)"),
),
# list of lists of hex colors
[
[0, "#440154"],
[0.5019607843137255, "#21918c"],
[1, "#fde725"],
],
*[(cs, True) for cs in VALID_COLORSCALES],
*[(cs[0], False) for cs in INVALID_COLORSCALES],
],
)
def test_validate_colorscale(colorscale: ColorScale) -> None:
def test_is_canonical_colorscale(colorscale: ColorScale | Any, expected: bool) -> None:
assert is_canonical_colorscale(colorscale) == expected


# ==============================================================
# --- validate_canonical_colorscale()
# ==============================================================


@pytest.mark.parametrize("colorscale", VALID_COLORSCALES)
def test_validate_canonical_colorscale(colorscale: ColorScale) -> None:
validate_canonical_colorscale(colorscale=colorscale)


@pytest.mark.parametrize(
("colorscale", "expected_exception"),
[
# is not collection
(1, TypeError),
# is not collection of tuples
((1, 2, 3), TypeError),
# inner tuples should have length 2 (for the scale and color values)
(((1, 2, 3), (4, 5, 6)), ValueError),
*INVALID_COLORSCALES,
# Invalid scale values: first and last numbers
# in the scale must be 0.0 and 1.0 respectively
((("a", 1), ("b", 2)), PlotlyError),
(((1, "a"), (2, "b")), PlotlyError),
(((1, "a"), (0, "a")), PlotlyError),
],
)
def test_validate_colorscale_fails_for_invalid_colorscale(
def test_validate_canonical_colorscale_fails_for_invalid_colorscale(
colorscale: Any,
expected_exception: type[Exception],
) -> None:
pytest.raises(expected_exception, validate_canonical_colorscale, colorscale=colorscale)
with pytest.raises(expected_exception):
validate_canonical_colorscale(colorscale)


# ==============================================================
Expand Down

0 comments on commit 4d520e9

Please sign in to comment.