Skip to content

Commit

Permalink
wip: add trace_types
Browse files Browse the repository at this point in the history
  • Loading branch information
tpvasconcelos committed Oct 16, 2024
1 parent 0044988 commit 048727f
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 13 deletions.
94 changes: 81 additions & 13 deletions src/ridgeplot/_figure_factory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, cast

from plotly import graph_objects as go

Expand Down Expand Up @@ -50,6 +50,25 @@
>>> labels_array: ShallowLabelsArray = ["trace 1", "trace 2", "trace 3"]
"""

TraceType = Literal["area", "bar"]
"""The type of trace to draw in a ridgeplot."""

TraceTypesArray = CollectionL2[TraceType]
"""A :data:`TraceTypesArray` represents the types of traces in a ridgeplot.
Example
-------
>>> trace_types_array: TraceTypesArray = [
... ["area", "bar", "area"],
... ["bar", "area"],
... ]
"""

ShallowTraceTypesArray = CollectionL1[TraceType]
"""Shallow type for :data:`TraceTypesArray`.
Example
>>> trace_types_array: ShallowTraceTypesArray = ["area", "bar", "area"]
"""

ColorsArray = CollectionL2[str]
"""A :data:`ColorsArray` represents the colors of traces in a ridgeplot.
Expand Down Expand Up @@ -161,6 +180,7 @@ def _mul(a: tuple[Numeric, ...], b: tuple[Numeric, ...]) -> tuple[Numeric, ...]:
@dataclass
class RidgeplotTrace:
trace: DensityTrace
type: TraceType
label: str
color: str

Expand All @@ -181,6 +201,7 @@ def __init__(
coloralpha: float | None,
colormode: Colormode,
trace_labels: LabelsArray | ShallowLabelsArray | None,
trace_types: TraceTypesArray | ShallowTraceTypesArray | TraceType,
linewidth: float,
spacing: float,
show_yticklabels: bool,
Expand Down Expand Up @@ -215,15 +236,26 @@ def __init__(
else:
if is_flat_str_collection(trace_labels):
trace_labels = cast(ShallowLabelsArray, trace_labels)
trace_labels = cast(LabelsArray, nest_shallow_collection(trace_labels))
trace_labels = nest_shallow_collection(trace_labels)
trace_labels = cast(LabelsArray, trace_labels)
trace_labels = normalise_row_attrs(trace_labels, densities=densities)

if isinstance(trace_types, str):
trace_types = [[trace_types] * len(row) for row in densities]
else:
if is_flat_str_collection(trace_types):
trace_types = cast(ShallowTraceTypesArray, trace_types)
trace_types = nest_shallow_collection(trace_types)
trace_types = cast(TraceTypesArray, trace_types)
trace_types = normalise_row_attrs(trace_types, densities=densities)

self.densities = densities
self.colorscale = colorscale
self.coloralpha = float(coloralpha) if coloralpha is not None else None
self.colormode = colormode
self.trace_labels: LabelsArray = trace_labels
self.y_labels: LabelsArray = [ordered_dedup(row) for row in trace_labels]
self.trace_types: TraceTypesArray = trace_types
self.linewidth = float(linewidth)
self.spacing = float(spacing)
self.show_yticklabels = bool(show_yticklabels)
Expand All @@ -244,13 +276,15 @@ def __init__(
self.rows: list[RidgeplotRow] = [
RidgeplotRow(
traces=[
RidgeplotTrace(trace=trace, label=label, color=color)
for trace, label, color in zip_strict(traces, labels, colors)
RidgeplotTrace(trace=trace, type=trace_type, label=label, color=color)
for trace, label, trace_type, color in zip_strict(
traces, labels, tr_types, colors
)
],
y_shifted=float(-ith_row * self.y_max * self.spacing),
)
for ith_row, (traces, labels, colors) in enumerate(
zip_strict(self.densities, self.trace_labels, self.colors)
for ith_row, (traces, labels, tr_types, colors) in enumerate(
zip_strict(self.densities, self.trace_labels, self.trace_types, self.colors)
)
]

Expand Down Expand Up @@ -288,6 +322,7 @@ def draw_density_trace(
y: Collection[Numeric],
y_shifted: float,
label: str,
trace_type: TraceType,
color: str,
) -> None:
"""Draw a density trace.
Expand All @@ -296,23 +331,49 @@ def draw_density_trace(
fills the trace until the previously drawn trace (see
:meth:`draw_base`). This is why the base trace must be drawn first.
"""
self.draw_base(x=x, y_shifted=y_shifted)
self.fig.add_trace(
go.Scatter(
x=x,
TraceCls = go.Scatter if trace_type == "area" else go.Bar

if trace_type == "area":
self.draw_base(x=x, y_shifted=y_shifted)

kwargs: dict[str, Any]
if trace_type == "area":
kwargs = dict(
y=[y_i + y_shifted for y_i in y],
fillcolor=color,
name=label,
fill="tonexty",
mode="lines",
line=dict(
color="rgba(0,0,0,0.6)" if color is not None else None,
width=self.linewidth,
),
)
else:
kwargs = dict(
y=y,
base=y_shifted,
marker=dict(
color=color,
# TODO: Review these default values for marker_line
line=dict(
# color="rgba(0,0,0,0.6)" if color is not None else None,
# width=self.linewidth,
color="rgba(0,0,0,0.6)",
width=0.4,
),
),
# width=1, # TODO: how to handle this?
)

self.fig.add_trace(
TraceCls(
x=x,
name=label,
# Hover information
customdata=[[y_i] for y_i in y],
hovertemplate=_DEFAULT_HOVERTEMPLATE,
),
**kwargs,
)
)

def update_layout(self) -> None:
Expand All @@ -336,6 +397,8 @@ def update_layout(self) -> None:
showticklabels=True,
**axes_common,
)
# TODO: Review default layout for bar traces...
self.fig.update_layout(barmode="stack", bargap=0, bargroupgap=0)

def _compute_midpoints_row_index(self) -> MidpointsArray:
return [
Expand Down Expand Up @@ -401,7 +464,12 @@ def make_figure(self) -> go.Figure:
for trace in row.traces:
x, y = zip(*trace.trace)
self.draw_density_trace(
x=x, y=y, y_shifted=row.y_shifted, label=trace.label, color=trace.color
x=x,
y=y,
y_shifted=row.y_shifted,
label=trace.label,
trace_type=trace.type,
color=trace.color,
)
self.update_layout()
return self.fig
7 changes: 7 additions & 0 deletions src/ridgeplot/_ridgeplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
LabelsArray,
RidgeplotFigureFactory,
ShallowLabelsArray,
ShallowTraceTypesArray,
TraceType,
TraceTypesArray,
)
from ridgeplot._kde import estimate_densities
from ridgeplot._missing import MISSING, MissingType
Expand Down Expand Up @@ -72,6 +75,7 @@ def ridgeplot(
colormode: Colormode = "mean-minmax",
coloralpha: float | None = None,
labels: LabelsArray | ShallowLabelsArray | None = None,
trace_types: TraceTypesArray | ShallowTraceTypesArray | TraceType = "area",
linewidth: float = 1.0,
spacing: float = 0.5,
show_annotations: bool | MissingType = MISSING,
Expand Down Expand Up @@ -205,6 +209,8 @@ def ridgeplot(
instead, a list of labels is specified, it must be of the same
size/length as the number of traces.
# TODO: Add a note about the `trace_types` argument
linewidth : float
The traces' line width (in px).
Expand Down Expand Up @@ -283,6 +289,7 @@ def ridgeplot(
colorscale=colorscale,
coloralpha=coloralpha,
colormode=colormode,
trace_types=trace_types,
linewidth=linewidth,
spacing=spacing,
show_yticklabels=show_yticklabels,
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_figure_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def test_densities_must_be_4d(self, densities: Densities) -> None:
coloralpha=..., # type: ignore[arg-type]
colormode=..., # type: ignore[arg-type]
trace_labels=..., # type: ignore[arg-type]
trace_types=..., # type: ignore[arg-type]
linewidth=..., # type: ignore[arg-type]
spacing=..., # type: ignore[arg-type]
show_yticklabels=..., # type: ignore[arg-type]
Expand All @@ -122,6 +123,7 @@ def test_float_casting(self) -> None:
colorscale="YlOrRd",
colormode="trace-index",
trace_labels=[["A"], ["B"]],
trace_types="area",
show_yticklabels=True,
# Ensure that the following inputs are cast to float
coloralpha=1,
Expand Down

0 comments on commit 048727f

Please sign in to comment.