Coverage for src/abm_initialization_collection/image/plot_contact_sheet.py: 22%

65 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2024-09-25 19:25 +0000

1from __future__ import annotations 

2 

3from math import ceil, sqrt 

4from typing import TYPE_CHECKING 

5 

6import matplotlib as mpl 

7import matplotlib.pyplot as plt 

8import numpy as np 

9 

10if TYPE_CHECKING: 

11 import pandas as pd 

12 

# Select the non-interactive Agg backend so figures render without a GUI.
mpl.use("Agg")

# Process-wide matplotlib rcParams overrides applied when this module is imported.
_RC_OVERRIDES = {
    "figure": {"dpi": 200},
    "font": {"size": 8},
    "axes": {"titlesize": 10, "titleweight": "bold"},
}
for _rc_group, _rc_params in _RC_OVERRIDES.items():
    mpl.rc(_rc_group, **_rc_params)
del _rc_group, _rc_params

17 

18 

def plot_contact_sheet(
    data: pd.DataFrame, reference: pd.DataFrame | None = None
) -> mpl.figure.Figure:
    """
    Plot contact sheet of images for each z slice.

    Each z slice is drawn on its own subplot of the contact sheet, with one
    circle per (x, y) point colored by its ``id``. When a reference is given,
    it (not ``data``) decides which z layers are drawn and the x and y bounds
    of every subplot. Any rows of the reference slice that have no match in
    the data slice (merged on all shared columns) are drawn as smaller gray
    circles.

    Parameters
    ----------
    data
        Sampled data with x, y, z, and id columns.
    reference
        Reference data with x, y, and z coordinates.

    Returns
    -------
    :
        Contact sheet figure.
    """

    # The reference, when present, defines both the layers and the plot extent.
    extent_source = data if reference is None else reference
    z_layers = sorted(extent_source.z.unique())

    n_rows, n_cols, indices = separate_rows_cols(z_layers)
    fig, axs = make_subplots(n_rows, n_cols)

    # A shared color range keeps id colors comparable across every subplot.
    id_hi, id_lo = int(data.id.max()), int(data.id.min())
    x_lo, x_hi = extent_source.x.min(), extent_source.x.max()
    y_lo, y_hi = extent_source.y.min(), extent_source.y.max()

    for row, col, layer_index in indices:
        ax = select_axes(axs, row, col, n_rows, n_cols)
        ax.set_xlim([x_lo - 1, x_hi + 1])
        # Limits are given high-to-low, which inverts the y axis.
        ax.set_ylim([y_hi + 1, y_lo - 1])

        if layer_index is None:
            # Unused grid cell past the last layer.
            ax.axis("off")
            continue

        z_value = z_layers[layer_index]
        z_slice = data[data.z == z_value]

        sampled = mpl.collections.PatchCollection(
            [plt.Circle(center, radius=0.5) for center in zip(z_slice.x, z_slice.y)],
            cmap="jet",
        )
        sampled.set_array(z_slice.id)
        sampled.set_clim([id_lo, id_hi])
        ax.add_collection(sampled)

        if reference is not None:
            layer_reference = reference[reference.z == z_value]
            merged = z_slice.merge(layer_reference, how="outer", indicator=True)
            reference_only = merged[merged["_merge"] == "right_only"]

            missing = mpl.collections.PatchCollection(
                [
                    plt.Circle(center, radius=0.3)
                    for center in zip(reference_only.x, reference_only.y)
                ],
                facecolor="#ccc",
            )
            ax.add_collection(missing)

        # Integer layers print as-is; other layers print with two decimals.
        title_spec = "" if np.issubdtype(z_value, np.integer) else ".2f"
        ax.set_title(f"z = {z_value:{title_spec}}")

        ax.set_aspect("equal", adjustable="box")
        ax.xaxis.set_ticks([])
        ax.yaxis.set_ticks([])

    fig.tight_layout()

    return fig

92 

93 

def separate_rows_cols(items: list) -> tuple[int, int, list[tuple[int, int, int | None]]]:
    """
    Separate list of items into approximately equal number of indexed rows and columns.

    Items are laid out row by row on a grid with ``ceil(sqrt(n))`` columns
    and just enough rows to hold all ``n`` items. Grid cells past the last
    item get an item index of ``None``.

    Parameters
    ----------
    items
        Sequence of items; only its length is used.

    Returns
    -------
    :
        Number of rows, number of columns, and ``(row, column, item index)``
        tuples for every grid cell in row-major order. An empty sequence
        yields ``(0, 0, [])``.
    """

    n_items = len(items)

    # With no items the column count would be 0 and the row computation
    # below would divide by zero.
    if n_items == 0:
        return 0, 0, []

    n_cols = ceil(sqrt(n_items))
    n_rows = ceil(n_items / n_cols)

    all_indices = [(i, j, i * n_cols + j) for i in range(n_rows) for j in range(n_cols)]
    indices = [(i, j, k if k < n_items else None) for i, j, k in all_indices]

    return n_rows, n_cols, indices

117 

118 

def make_subplots(n_rows: int, n_cols: int) -> tuple[mpl.figure.Figure, mpl.axes.Axes | np.ndarray]:
    """
    Create a grid of subplots that share their x and y axes.

    Every currently open pyplot figure is closed before the new figure is
    created.

    Parameters
    ----------
    n_rows
        Number of rows in contact sheet plot.
    n_cols
        Number of columns in contact sheet plot.

    Returns
    -------
    :
        Figure and axes objects, exactly as returned by ``plt.subplots``.
    """

    # Closing all figures first presumably keeps repeated calls from
    # accumulating open figures.
    plt.close("all")
    return plt.subplots(n_rows, n_cols, sharex="all", sharey="all")

139 

140 

def select_axes(
    axs: mpl.axes.Axes | np.ndarray, i: int, j: int, n_rows: int, n_cols: int
) -> mpl.axes.Axes:
    """
    Select the axes object for the given indexed location.

    Handles the shapes ``plt.subplots`` returns with default squeezing:

    - a single Axes for a 1x1 grid;
    - a 1-D array when there is only one row or only one column;
    - a 2-D array otherwise.

    Parameters
    ----------
    axs
        Axes object, or array of axes objects.
    i
        Row index.
    j
        Column index.
    n_rows
        Number of rows in contact sheet plot.
    n_cols
        Number of columns in contact sheet plot.

    Returns
    -------
    mpl.axes.Axes
        Axes object at indexed location.
    """

    if n_rows == 1 and n_cols == 1:
        return axs
    if n_rows == 1:
        return axs[j]
    if n_cols == 1:
        # A single-column grid is also squeezed to 1-D, so index by row only.
        return axs[i]

    return axs[i, j]