From 5e73f820bce66bf7fb712dcece6d9ef0d13bab0e Mon Sep 17 00:00:00 2001 From: colganwi Date: Fri, 27 Mar 2026 09:32:16 -0400 Subject: [PATCH 1/2] feat(pl): add hex color passthrough for branches and nodes When a color attribute's values all start with '#' and are valid hex color codes, _get_colors now passes them through directly to matplotlib instead of routing through the categorical palette or numeric colormap. This lets users pre-assign per-cell/per-branch colors and have them rendered verbatim. Co-Authored-By: Claude Sonnet 4.6 --- src/pycea/pl/_utils.py | 7 ++++++- tests/test_plot_tree.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/pycea/pl/_utils.py b/src/pycea/pl/_utils.py index 885253d..30a7a46 100755 --- a/src/pycea/pl/_utils.py +++ b/src/pycea/pl/_utils.py @@ -368,7 +368,12 @@ def _get_colors( """Get colors for plotting.""" if len(data) == 0: raise ValueError(f"Key {key!r} is not present in any edge.") - if data.dtype.kind in ["i", "f"]: # Numeric + non_na = data.dropna() + if len(non_na) > 0 and non_na.apply(lambda v: isinstance(v, str) and v.startswith("#") and mcolors.is_color_like(v)).all(): # Hex passthrough + colors = [data[i] if (i in data.index and pd.notna(data.at[i])) else na_color for i in indicies] + legend = {} + n_categories = 0 + elif data.dtype.kind in ["i", "f"]: # Numeric norm = _get_norm(vmin=vmin, vmax=vmax, data=data) color_map = plt.get_cmap(cmap) # Vectorized: reindex to align with indicies (NaN for missing), then apply colormap in bulk diff --git a/tests/test_plot_tree.py b/tests/test_plot_tree.py index fb8f251..3ccce39 100755 --- a/tests/test_plot_tree.py +++ b/tests/test_plot_tree.py @@ -165,5 +165,36 @@ def test_annotation_bad_input(tdata): plt.close() +def test_hex_color_branches(tdata): + """Branches colored by a per-edge hex attribute use the raw hex values directly.""" + import matplotlib.colors as mcolors + hex_colors = {"1": "#e41a1c", "2": "#377eb8"} + for tree_key, tree in tdata.obst.items(): + for u, v, data in tree.edges(data=True): + data["hex_color"] = hex_colors[tree_key] + fig, ax = plt.subplots() + pycea.pl.branches(tdata, color="hex_color", depth_key="time", ax=ax) + edge_colors = ax.collections[0].get_colors() + expected = {mcolors.to_rgba(c) for c in hex_colors.values()} + actual = {tuple(row) for row in edge_colors} + assert actual == expected + plt.close() + + +def test_hex_color_nodes(tdata): + """Nodes colored by a per-node hex attribute use the raw hex values directly.""" + import matplotlib.colors as mcolors + hex_color = "#4daf4a" + for node, data in tdata.obst["1"].nodes(data=True): + data["hex_color"] = hex_color + fig, ax = plt.subplots() + pycea.pl.branches(tdata, tree="1", depth_key="time", ax=ax) + pycea.pl.nodes(tdata, nodes="leaves", color="hex_color", ax=ax) + node_colors = ax.collections[1].get_facecolors() + expected = mcolors.to_rgba(hex_color) + assert all(tuple(row) == expected for row in node_colors) + plt.close() + + if __name__ == "__main__": pytest.main(["-v", __file__]) From 35c776a39bce33ce339d809bf3911ff8d38eced0 Mon Sep 17 00:00:00 2001 From: colganwi Date: Fri, 27 Mar 2026 12:54:51 -0400 Subject: [PATCH 2/2] feat(pl): add outline_width to nodes and fix annotation vmin/vmax with label=False - pl.nodes: add outline_width parameter (default None) that draws a black border around visible nodes. Uses per-node edgecolors so transparent na nodes never receive a visible outline. - pl.annotation: fix IndexError when label=False (labels=[]) is combined with vmin+vmax (which sets share_cmap=True) or an is_array key. The guards `if labels is not None` now use `if labels:` so an empty list correctly falls through to the auto-label logic. Co-Authored-By: Claude Sonnet 4.6 --- src/pycea/pl/plot_tree.py | 21 ++++++++++++++++-- tests/test_plot_tree.py | 45 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/src/pycea/pl/plot_tree.py b/src/pycea/pl/plot_tree.py index d5d2926..7be2ee2 100644 --- a/src/pycea/pl/plot_tree.py +++ b/src/pycea/pl/plot_tree.py @@ -237,6 +237,7 @@ def nodes( na_color: str = "#FFFFFF00", na_style: str = "none", na_size: float = 0, + outline_width: float | None = None, slot: Literal["obst", "obs", "X"] | None = None, ax: Axes | None = None, legend_kwargs: dict[str, Any] | None = None, @@ -279,6 +280,9 @@ def nodes( The marker to use for annotations with missing data. na_size The size to use for annotations with missing data. + outline_width + Width of a black outline drawn around each node marker. ``None`` (default) + draws no outline. ax A matplotlib axes object. If `None`, a new figure and axes will be created. slot @@ -416,6 +420,19 @@ def nodes( legends.append(_categorical_legend(style, marker_map=marker_map, type="marker")) else: raise ValueError("Invalid style value. Must be a marker name, or an str specifying an attribute of the nodes.") + # Apply outline + if outline_width is not None: + def _outline_edgecolors(face_colors): + if isinstance(face_colors, str): + return "black" + rgba = mcolors.to_rgba_array(face_colors) + return ["black" if a > 0 else "none" for a in rgba[:, 3]] + + kwargs.setdefault("edgecolors", _outline_edgecolors(kwargs.get("color"))) + kwargs.setdefault("linewidths", outline_width) + for kw in kwargs_list: + kw.setdefault("edgecolors", _outline_edgecolors(kw.get("color"))) + kw.setdefault("linewidths", outline_width) # Plot if len(kwargs_list) > 0: for kwargs in kwargs_list: @@ -576,7 +593,7 @@ def annotation( legends = [] max_categories = 0 if is_array: # single cmap for all columns - label = labels[0] if labels is not None else keys[0] + label = labels[0] if labels else keys[0] if is_square: data = data.loc[leaves, list(reversed(leaves))] end_lat = start_lat + attrs["depth"] * arc_span_rad * width / 0.05 @@ -609,7 +626,7 @@ def annotation( legends.append(_cbar_legend(label, color_map, norm)) # Add shared cmap if share_cmap and norm is not None: - if labels is not None: + if labels: label = labels[0] elif is_array: label = keys[0] diff --git a/tests/test_plot_tree.py b/tests/test_plot_tree.py index 3ccce39..7a1e6e6 100755 --- a/tests/test_plot_tree.py +++ b/tests/test_plot_tree.py @@ -165,6 +165,17 @@ def test_annotation_bad_input(tdata): plt.close() +def test_annotation_vmin_vmax_label_false(tdata): + """annotation must not raise when label=False and vmax/vmin are given.""" + fig, ax = plt.subplots() + pycea.pl.branches(tdata, depth_key="time", ax=ax) + # vmax only + pycea.pl.annotation(tdata, keys="x", label=False, legend=False, vmax=1.0, ax=ax) + # vmin + vmax (triggers share_cmap=True, which previously hit labels[0] on empty list) + pycea.pl.annotation(tdata, keys="x", label=False, legend=False, vmin=0.0, vmax=1.0, ax=ax) + plt.close() + + def test_hex_color_branches(tdata): """Branches colored by a per-edge hex attribute use the raw hex values directly.""" import matplotlib.colors as mcolors @@ -196,5 +207,39 @@ def test_hex_color_nodes(tdata): plt.close() +def test_nodes_outline_width(tdata): + """outline_width draws a black edge only around visible (non-na) nodes.""" + import matplotlib.colors as mcolors + # Default (no outline): does not error + fig, ax = plt.subplots() + pycea.pl.branches(tdata, depth_key="time", ax=ax) + pycea.pl.nodes(tdata, nodes="internal", ax=ax) + plt.close() + + # With outline: visible nodes get black edge at given width + fig, ax = plt.subplots() + pycea.pl.branches(tdata, depth_key="time", ax=ax) + pycea.pl.nodes(tdata, nodes="internal", outline_width=1.5, ax=ax) + sc = ax.collections[1] + assert all(mcolors.to_rgba(c) == (0.0, 0.0, 0.0, 1.0) for c in sc.get_edgecolors()) + assert all(w == pytest.approx(1.5) for w in sc.get_linewidths()) + plt.close() + + # na nodes (alpha=0) must not get a black outline + fig, ax = plt.subplots() + pycea.pl.branches(tdata, depth_key="time", ax=ax) + # color="clade" → some nodes will be missing data and get na_color="#FFFFFF00" + pycea.pl.nodes(tdata, nodes="all", color="clade", outline_width=1.0, ax=ax) + sc = ax.collections[1] + face_rgba = mcolors.to_rgba_array(sc.get_facecolors()) + edge_rgba = mcolors.to_rgba_array(sc.get_edgecolors()) + for face, edge in zip(face_rgba, edge_rgba): + if face[3] == 0: # transparent (na) node + assert edge[3] == 0, "na nodes must not have a visible outline" + else: + assert tuple(edge) == (0.0, 0.0, 0.0, 1.0), "visible nodes must have a black outline" + plt.close() + + if __name__ == "__main__": pytest.main(["-v", __file__])