Coverage for src/abm_initialization_collection/image/plot_contact_sheet.py: 22%
65 statements
« prev ^ index » next coverage.py v7.1.0, created at 2024-07-26 20:12 +0000
1from math import ceil, sqrt
2from typing import Optional, Union
4import matplotlib as mpl
5import matplotlib.pyplot as plt
6import numpy as np
7import pandas as pd
# Select the non-interactive Agg backend so figures render without a display.
mpl.use("Agg")

# Figure-wide defaults applied to every contact sheet created by this module.
mpl.rcParams.update(
    {
        "figure.dpi": 200,
        "font.size": 8,
        "axes.titlesize": 10,
        "axes.titleweight": "bold",
    }
)
def plot_contact_sheet(
    data: pd.DataFrame, reference: Optional[pd.DataFrame] = None
) -> mpl.figure.Figure:
    """
    Plot contact sheet of images for each z slice.

    Each z slice is plotted on a separate subplot in the contact sheet. If a
    reference is given, the reference is used to determine the z slices and
    the x and y bounds, and (x, y) coordinates that only exist in the
    reference are shown in gray.

    Parameters
    ----------
    data
        Sampled data with id, x, y, and z columns. Points are colored by id.
    reference
        Reference data with x, y, and z columns.

    Returns
    -------
    :
        Contact sheet figure.
    """

    # Slices and axis bounds come from the reference when one is given, so
    # layers and coordinates missing from the sampled data still get a panel.
    layout_source = data if reference is None else reference

    z_layers = sorted(layout_source.z.unique())

    n_rows, n_cols, indices = separate_rows_cols(z_layers)
    fig, axs = make_subplots(n_rows, n_cols)

    # Color limits are shared by all panels so an id maps to the same color everywhere.
    max_id = int(data.id.max())
    min_id = int(data.id.min())
    min_x, max_x = layout_source.x.min(), layout_source.x.max()
    min_y, max_y = layout_source.y.min(), layout_source.y.max()

    for i, j, k in indices:
        ax = select_axes(axs, i, j, n_rows, n_cols)
        ax.set_xlim([min_x - 1, max_x + 1])
        # Limits are given high-to-low so y increases downward (image orientation).
        ax.set_ylim([max_y + 1, min_y - 1])

        # Grid cells beyond the last z slice are left blank.
        if k is None:
            ax.axis("off")
            continue

        z_value = z_layers[k]
        z_slice = data[data.z == z_value]

        patches = [plt.Circle((x, y), radius=0.5) for x, y in zip(z_slice.x, z_slice.y)]
        collection = mpl.collections.PatchCollection(patches, cmap="jet")
        collection.set_array(z_slice.id)
        collection.set_clim([min_id, max_id])
        ax.add_collection(collection)

        if reference is not None:
            z_slice_reference = reference[reference.z == z_value]
            # Compare on (x, y) only: both slices share the same z, and matching
            # on every shared column would flag a coordinate as reference-only
            # whenever any other column (e.g. id) differs between the frames.
            merged = pd.merge(
                z_slice[["x", "y"]],
                z_slice_reference[["x", "y"]],
                on=["x", "y"],
                how="outer",
                indicator=True,
            )
            removed = merged[merged["_merge"] == "right_only"]

            patches = [plt.Circle((x, y), radius=0.3) for x, y in zip(removed.x, removed.y)]
            collection = mpl.collections.PatchCollection(patches, facecolor="#ccc")
            ax.add_collection(collection)

        # isinstance also accepts plain Python ints, which np.issubdtype rejects.
        if isinstance(z_value, (int, np.integer)):
            ax.set_title(f"z = {z_value}")
        else:
            ax.set_title(f"z = {z_value:.2f}")

        ax.set_aspect("equal", adjustable="box")
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])

    fig.tight_layout()

    return fig
def separate_rows_cols(items: list) -> tuple[int, int, list[tuple[int, int, Optional[int]]]]:
    """
    Separate list of items into approximately equal number of indexed rows and columns.

    The number of columns is the ceiling of the square root of the number of
    items, and the number of rows is the smallest count that holds all items.
    Items are assigned to grid positions in row-major order. Grid positions
    past the last item get an item index of None.

    Parameters
    ----------
    items
        List of items. Only the number of items is used.

    Returns
    -------
    :
        Number of rows, number of columns, and a (row, column, item index)
        tuple for every grid position. An empty list gives zero rows, zero
        columns, and no indices.
    """

    n_items = len(items)

    # An empty list has no grid; return early instead of dividing by zero columns.
    if n_items == 0:
        return 0, 0, []

    n_cols = ceil(sqrt(n_items))
    n_rows = ceil(n_items / n_cols)

    # Walk the grid in row-major order; flat position k maps to (k // n_cols, k % n_cols).
    indices = [
        (k // n_cols, k % n_cols, k if k < n_items else None)
        for k in range(n_rows * n_cols)
    ]

    return n_rows, n_cols, indices
def make_subplots(
    n_rows: int, n_cols: int
) -> tuple[mpl.figure.Figure, Union[mpl.axes.Axes, np.ndarray]]:
    """
    Create a grid of subplots whose x and y axes are shared across all panels.

    Every figure that is still open is closed before the new one is created.

    Parameters
    ----------
    n_rows
        Number of rows in contact sheet plot.
    n_cols
        Number of columns in contact sheet plot.

    Returns
    -------
    :
        Figure and axes objects, exactly as returned by ``plt.subplots``.
    """

    # Drop any previously opened figures before building a new one.
    plt.close("all")

    return plt.subplots(nrows=n_rows, ncols=n_cols, sharex="all", sharey="all")
def select_axes(
    axs: Union[mpl.axes.Axes, np.ndarray], i: int, j: int, n_rows: int, n_cols: int
) -> mpl.axes.Axes:
    """
    Select the axes object for the given indexed location.

    Accounts for the shapes ``plt.subplots`` returns with its default
    squeezing: a single axes object for a 1 x 1 grid, a 1D array for a single
    row or a single column, and a 2D array otherwise.

    Parameters
    ----------
    axs
        Axes object or array of axes objects.
    i
        Row index.
    j
        Column index.
    n_rows
        Number of rows in contact sheet plot.
    n_cols
        Number of columns in contact sheet plot.

    Returns
    -------
    mpl.axes.Axes
        Axes object at indexed location.
    """

    if n_rows == 1 and n_cols == 1:
        return axs
    if n_rows == 1:
        return axs[j]
    # A single column is also squeezed to 1D, so index it by row only.
    if n_cols == 1:
        return axs[i]

    return axs[i, j]