Skip to content

Commit

Permalink
Refactor color utilities
Browse files Browse the repository at this point in the history
  • Loading branch information
tpvasconcelos committed Oct 18, 2024
1 parent ceb1800 commit 61a5202
Show file tree
Hide file tree
Showing 19 changed files with 313 additions and 275 deletions.
4 changes: 2 additions & 2 deletions cicd_utils/ridgeplot_examples/_lincoln_weather.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import plotly.graph_objects as go

from ridgeplot._coloring.colormodes import Colormode
from ridgeplot._coloring.colors import Color, ColorScale
from ridgeplot._color.interpolation import Colormode
from ridgeplot._types import Color, ColorScale


def main(
Expand Down
2 changes: 1 addition & 1 deletion src/ridgeplot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from __future__ import annotations

from ridgeplot._coloring.colors import list_all_colorscale_names
from ridgeplot._color.colorscale import list_all_colorscale_names
from ridgeplot._ridgeplot import ridgeplot
from ridgeplot._version import __version__

Expand Down
File renamed without changes.
52 changes: 52 additions & 0 deletions src/ridgeplot/_color/colorscale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, cast

import plotly.express as px
from _plotly_utils.basevalidators import ColorscaleValidator as _ColorscaleValidator

from ridgeplot._types import Color, ColorScale

if TYPE_CHECKING:
from collections.abc import Collection


def list_all_colorscale_names() -> list[str]:
"""Get a list with all available colorscale names.
.. versionadded:: 0.1.21
Replaced the deprecated function ``get_all_colorscale_names()``.
Returns
-------
list[str]
A list with all available colorscale names.
"""
# Add 'default' for backwards compatibility
px_colorscales = px.colors.named_colorscales()
return sorted({"default", *px_colorscales, *(f"{name}_r" for name in px_colorscales)})


class ColorscaleValidator(_ColorscaleValidator): # type: ignore[misc]
def __init__(self) -> None:
super().__init__("colorscale", "ridgeplot")

@property
def named_colorscales(self) -> dict[str, list[str]]:
named_colorscales = cast(dict[str, list[str]], super().named_colorscales)
if "default" not in named_colorscales:
# Add 'default' for backwards compatibility
named_colorscales["default"] = px.colors.DEFAULT_PLOTLY_COLORS
return named_colorscales

def validate_coerce(self, v: Any) -> ColorScale:
coerced = super().validate_coerce(v)
if coerced is None:
self.raise_invalid_val(coerced)
return cast(ColorScale, [tuple(c) for c in coerced])


def validate_and_coerce_colorscale(colorscale: ColorScale | Collection[Color] | str) -> ColorScale:
"""Convert mixed colorscale representations to the canonical
:data:`ColorScale` format."""
return ColorscaleValidator().validate_coerce(colorscale)
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, Protocol
from typing import TYPE_CHECKING, Literal, Protocol, cast

from ridgeplot._coloring.colors import (
ColorScale,
apply_alpha,
interpolate_color,
round_color,
from plotly import express as px

from ridgeplot._color.colorscale import (
validate_and_coerce_colorscale,
)
from ridgeplot._types import CollectionL2
from ridgeplot._color.utils import apply_alpha, round_color, to_rgb
from ridgeplot._types import CollectionL2, Color, ColorScale
from ridgeplot._utils import get_xy_extrema, normalise_min_max

if TYPE_CHECKING:
from collections.abc import Collection

from ridgeplot._coloring.colors import Color
from ridgeplot._types import Densities, Numeric

Colormode = Literal["row-index", "trace-index", "trace-index-row-wise", "mean-minmax", "mean-means"]
Expand Down Expand Up @@ -119,6 +117,32 @@ def _interpolate_mean_means(ctx: InterpolationContext) -> ColorscaleInterpolants
]


def interpolate_color(colorscale: ColorScale, p: float) -> Color:
"""Get a color from a colorscale at a given interpolation point ``p``."""
if not (0 <= p <= 1):
raise ValueError(
f"The interpolation point 'p' should be a float value between 0 and 1, not {p}."
)
scale = [s for s, _ in colorscale]
colors = [c for _, c in colorscale]
del colorscale
if p in scale:
return colors[scale.index(p)]
colors = [to_rgb(c) for c in colors]
ceil = min(filter(lambda s: s > p, scale))
floor = max(filter(lambda s: s < p, scale))
p_normalised = normalise_min_max(p, min_=floor, max_=ceil)
return cast(
str,
px.colors.find_intermediate_color(
lowcolor=colors[scale.index(floor)],
highcolor=colors[scale.index(ceil)],
intermed=p_normalised,
colortype="rgb",
),
)


def compute_trace_colors(
colorscale: ColorScale | Collection[Color] | str,
colormode: Colormode,
Expand Down
51 changes: 51 additions & 0 deletions src/ridgeplot/_color/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations

from typing import TYPE_CHECKING, cast

from plotly import express as px

from ridgeplot._color.css_colors import CSS_NAMED_COLORS, CssNamedColor

if TYPE_CHECKING:
from ridgeplot._types import Color


def to_rgb(color: Color) -> str:
if not isinstance(color, (str, tuple)):
raise TypeError(f"Expected str or tuple for color, got {type(color)} instead.")
if isinstance(color, tuple):
r, g, b = color
rgb = f"rgb({r}, {g}, {b})"
elif color.startswith("#"):
return to_rgb(cast(str, px.colors.hex_to_rgb(color)))
elif color.startswith(("rgb(", "rgba(")):
rgb = color
elif color in CSS_NAMED_COLORS:
color = cast(CssNamedColor, color)
return to_rgb(CSS_NAMED_COLORS[color])
else:
raise ValueError(
f"color should be a tuple or a str representation "
f"of a hex or rgb color, got {color!r} instead."
)
px.colors.validate_colors(rgb)
return rgb


def unpack_rgb(rgb: str) -> tuple[float, float, float, float] | tuple[float, float, float]:
prefix = rgb.split("(")[0] + "("
values_str = map(str.strip, rgb.removeprefix(prefix).removesuffix(")").split(","))
values_num = tuple(int(v) if v.isdecimal() else float(v) for v in values_str)
return values_num # type: ignore[return-value]


def apply_alpha(color: Color, alpha: float) -> str:
values = unpack_rgb(to_rgb(color))
return f"rgba({', '.join(map(str, values[:3]))}, {alpha})"


def round_color(color: Color, ndigits: int) -> str:
color = to_rgb(color)
prefix = color.split("(")[0] + "("
values_round = tuple(v if isinstance(v, int) else round(v, ndigits) for v in unpack_rgb(color))
return f"{prefix}{', '.join(map(str, values_round))})"
146 changes: 0 additions & 146 deletions src/ridgeplot/_coloring/colors.py

This file was deleted.

5 changes: 3 additions & 2 deletions src/ridgeplot/_figure_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@

from plotly import graph_objects as go

from ridgeplot._coloring.colormodes import (
from ridgeplot._color.interpolation import (
Colormode,
InterpolationContext,
compute_trace_colors,
)
from ridgeplot._types import (
CollectionL1,
CollectionL2,
Color,
ColorScale,
DensityTrace,
is_flat_str_collection,
nest_shallow_collection,
Expand All @@ -28,7 +30,6 @@
if TYPE_CHECKING:
from collections.abc import Collection

from ridgeplot._coloring.colors import Color, ColorScale
from ridgeplot._types import Densities, Numeric


Expand Down
5 changes: 3 additions & 2 deletions src/ridgeplot/_ridgeplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import warnings
from typing import TYPE_CHECKING, cast

from ridgeplot._coloring.colormodes import Colormode
from ridgeplot._color.interpolation import Colormode
from ridgeplot._figure_factory import (
LabelsArray,
ShallowLabelsArray,
Expand All @@ -12,6 +12,8 @@
from ridgeplot._kde import estimate_densities
from ridgeplot._missing import MISSING, MissingType
from ridgeplot._types import (
Color,
ColorScale,
Densities,
Samples,
ShallowDensities,
Expand All @@ -26,7 +28,6 @@

import plotly.graph_objects as go

from ridgeplot._coloring.colors import Color, ColorScale
from ridgeplot._kde import KDEBandwidth, KDEPoints


Expand Down
Loading

0 comments on commit 61a5202

Please sign in to comment.