Skip to content

Commit

Permalink
Add a style_categories method (#261)
Browse files Browse the repository at this point in the history
  • Loading branch information
jnothman authored Dec 29, 2023
1 parent 19e891e commit da64c81
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 7 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ What's new in version 0.9
- Fixes a bug where ``show_percentages`` used the incorrect denominator if
filtering (e.g. ``min_subset_size``) was applied. This bug was a regression
introduced in version 0.7. (:issue:`248`)
- Add a ``style_categories`` method to customize category plot styles, including
shading of rows in the intersection matrix, and bars in the totals plot.
(:issue:`261` with thanks to :user:`Marcel Albus <maralbus>`).
- Ability to disable totals plot with `totals_plot_elements=0`. (:issue:`246`)
- Added ``max_subset_rank`` to get only n most populous subsets. (:issue:`253`)

Expand Down
4 changes: 3 additions & 1 deletion examples/plot_highlight.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@

upset = UpSet(example, facecolor="gray")
upset.style_subsets(present="cat0", label="Contains cat0", facecolor="blue")
upset.style_subsets(present="cat1", label="Contains cat1", hatch="xx")
upset.style_subsets(
present="cat1", label="Contains cat1", hatch="xx", edgecolor="black"
)
upset.style_subsets(present="cat2", label="Contains cat2", edgecolor="red")

# reduce legend size:
Expand Down
41 changes: 41 additions & 0 deletions examples/plot_highlight_categories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
================================
Highlighting selected categories
================================
Demonstrates use of the `style_categories` method to mark some
categories differently.
"""

from matplotlib import pyplot as plt

from upsetplot import UpSet, generate_counts

example = generate_counts()


##########################################################################
# Categories can be shaded by name with the ``shading_`` parameters.

upset = UpSet(example)
upset.style_categories("cat2", shading_edgecolor="darkgreen", shading_linewidth=1)
upset.style_categories(
"cat1",
shading_facecolor="lavender",
)
upset.plot()
plt.suptitle("Shade or edge a category with color")
plt.show()


##########################################################################
# Category total bars can be styled with the ``bar_`` parameters.
# You can also specify categories using a list of names.

upset = UpSet(example)
upset.style_categories(
["cat2", "cat1"], bar_facecolor="aqua", bar_hatch="xx", bar_edgecolor="black"
)
upset.plot()
plt.suptitle("")
plt.show()
103 changes: 97 additions & 6 deletions upsetplot/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ class UpSet:
"""

_default_figsize = (10, 6)
DPI = 100 # standard matplotlib value

def __init__(
self,
Expand Down Expand Up @@ -338,6 +339,7 @@ def __init__(
reverse=not self._horizontal,
include_empty_subsets=include_empty_subsets,
)
self.category_styles = {}
self.subset_styles = [
{"facecolor": facecolor} for i in range(len(self.intersections))
]
Expand Down Expand Up @@ -924,6 +926,16 @@ def plot_totals(self, ax):
)
self._label_sizes(ax, rects, "left" if self._horizontal else "top")

for category, rect in zip(self.totals.index.values, rects):
style = {
k[len("bar_") :]: v
for k, v in self.category_styles.get(category, {}).items()
if k.startswith("bar_")
}
style.setdefault("edgecolor", style.get("facecolor", self._facecolor))
for attr, val in style.items():
getattr(rect, "set_" + attr)(val)

max_total = self.totals.max()
if self._horizontal:
orig_ax.set_xlim(max_total, 0)
Expand All @@ -935,15 +947,34 @@ def plot_totals(self, ax):
ax.patch.set_visible(False)

def plot_shading(self, ax):
# alternating row shading (XXX: use add_patch(Rectangle)?)
for i in range(0, len(self.totals), 2):
# shade all rows, set every second row to zero visibility
for i, category in enumerate(self.totals.index):
default_shading = (
self._shading_color if i % 2 == 0 else (0.0, 0.0, 0.0, 0.0)
)
shading_style = {
k[len("shading_") :]: v
for k, v in self.category_styles.get(category, {}).items()
if k.startswith("shading_")
}

lw = shading_style.get(
"linewidth", 1 if shading_style.get("edgecolor") else 0
)
lw_padding = lw / (self._default_figsize[0] * self.DPI)
start_x = lw_padding
end_x = 1 - lw_padding * 3

rect = plt.Rectangle(
self._swapaxes(0, i - 0.4),
*self._swapaxes(*(1, 0.8)),
facecolor=self._shading_color,
lw=0,
self._swapaxes(start_x, i - 0.4),
*self._swapaxes(end_x, 0.8),
facecolor=shading_style.get("facecolor", default_shading),
edgecolor=shading_style.get("edgecolor", None),
ls=shading_style.get("linestyle", "-"),
lw=lw,
zorder=0,
)

ax.add_patch(rect)
ax.set_frame_on(False)
ax.tick_params(
Expand All @@ -962,6 +993,66 @@ def plot_shading(self, ax):
ax.set_xticklabels([])
ax.set_yticklabels([])

def style_categories(
self,
categories,
*,
bar_facecolor=None,
bar_hatch=None,
bar_edgecolor=None,
bar_linewidth=None,
bar_linestyle=None,
shading_facecolor=None,
shading_edgecolor=None,
shading_linewidth=None,
shading_linestyle=None,
):
"""Updates the style of the categories.
Select a category by name, and style either its total bar or its shading.
.. versionadded:: 0.9
Parameters
----------
categories : str or list[str]
Category names where the changed style applies.
bar_facecolor : str or RGBA matplotlib color tuple, optional.
Override the default facecolor in the totals plot.
bar_hatch : str, optional
Set a hatch for the totals plot.
bar_edgecolor : str or matplotlib color, optional
Set the edgecolor for total bars.
bar_linewidth : int, optional
Line width in points for total bar edges.
bar_linestyle : str, optional
Line style for edges.
shading_facecolor : str or RGBA matplotlib color tuple, optional.
Override the default alternating shading for specified categories.
shading_edgecolor : str or matplotlib color, optional
Set the edgecolor for bars, dots, and the line between dots.
shading_linewidth : int, optional
Line width in points for edges.
shading_linestyle : str, optional
Line style for edges.
"""
if isinstance(categories, str):
categories = [categories]
style = {
"bar_facecolor": bar_facecolor,
"bar_hatch": bar_hatch,
"bar_edgecolor": bar_edgecolor,
"bar_linewidth": bar_linewidth,
"bar_linestyle": bar_linestyle,
"shading_facecolor": shading_facecolor,
"shading_edgecolor": shading_edgecolor,
"shading_linewidth": shading_linewidth,
"shading_linestyle": shading_linestyle,
}
style = {k: v for k, v in style.items() if v is not None}
for category_name in categories:
self.category_styles.setdefault(category_name, {}).update(style)

def plot(self, fig=None):
"""Draw all parts of the plot onto fig or a new figure
Expand Down
45 changes: 45 additions & 0 deletions upsetplot/tests/test_upsetplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,51 @@ def test_style_subsets_artists(orientation):
# matrix_line_collection = upset_axes["matrix"].collections[1]


@pytest.mark.parametrize(
(
"kwarg_list",
"expected_category_styles",
),
[
# Different forms of including two categories
(
[{"categories": ["cat1", "cat2"], "shading_facecolor": "red"}],
{
"cat1": {"shading_facecolor": "red"},
"cat2": {"shading_facecolor": "red"},
},
),
(
[
{"categories": ["cat1", "cat2"], "shading_facecolor": "red"},
{"categories": "cat1", "shading_facecolor": "green"},
],
{
"cat1": {"shading_facecolor": "green"},
"cat2": {"shading_facecolor": "red"},
},
),
(
[
{"categories": ["cat1", "cat2"], "shading_facecolor": "red"},
{"categories": "cat1", "shading_edgecolor": "green"},
],
{
"cat1": {"shading_facecolor": "red", "shading_edgecolor": "green"},
"cat2": {"shading_facecolor": "red"},
},
),
],
)
def test_categories(kwarg_list, expected_category_styles):
data = generate_counts()
upset = UpSet(data, facecolor="blue")
for kw in kwarg_list:
upset.style_categories(**kw)
actual_category_styles = upset.category_styles
assert actual_category_styles == expected_category_styles


def test_many_categories():
# Tests regressions against GH#193
n_cats = 250
Expand Down

0 comments on commit da64c81

Please sign in to comment.