Coverage for src/abm_initialization_collection/image/plot_contact_sheet.py: 22%

65 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2024-07-26 20:12 +0000

1from math import ceil, sqrt 

2from typing import Optional, Union 

3 

4import matplotlib as mpl 

5import matplotlib.pyplot as plt 

6import numpy as np 

7import pandas as pd 

8 

# Select the Agg backend for matplotlib; this runs once, when the module is imported.
mpl.use("Agg")

# Module-wide matplotlib defaults applied at import time: figure DPI, base font
# size, and the size and weight of axes titles.
mpl.rc("figure", dpi=200)
mpl.rc("font", size=8)
mpl.rc("axes", titlesize=10, titleweight="bold")

13 

14 

def plot_contact_sheet(
    data: pd.DataFrame, reference: Optional[pd.DataFrame] = None
) -> mpl.figure.Figure:
    """
    Plot contact sheet of images for each z slice.

    Each z slice is plotted on a separate subplot in the contact sheet. If a
    reference is given, the reference is used to determine the z layers and the
    x and y bounds, and coordinates that only exist in the reference are shown
    in gray.

    Parameters
    ----------
    data
        Sampled data with id, x, y, and z columns.
    reference
        Reference data with x, y, and z coordinates.

    Returns
    -------
    :
        Contact sheet figure.
    """

    # The z layers and the x/y bounds come from the reference when one is given.
    layout = data if reference is None else reference
    z_layers = sorted(layout.z.unique())

    n_rows, n_cols, indices = separate_rows_cols(z_layers)
    fig, axs = make_subplots(n_rows, n_cols)

    # The same color limits are applied to every panel.
    max_id = int(data.id.max())
    min_id = int(data.id.min())
    min_x, max_x = layout.x.min(), layout.x.max()
    min_y, max_y = layout.y.min(), layout.y.max()

    for i, j, k in indices:
        ax = select_axes(axs, i, j, n_rows, n_cols)
        ax.set_xlim([min_x - 1, max_x + 1])
        # The y limits are passed high-to-low, so the y axis is inverted.
        ax.set_ylim([max_y + 1, min_y - 1])

        # Grid cells without a matching z layer are left blank.
        if k is None:
            ax.axis("off")
            continue

        z = z_layers[k]
        z_slice = data[data.z == z]

        # One circle per sampled coordinate, colored by id.
        patches = [plt.Circle((x, y), radius=0.5) for x, y in zip(z_slice.x, z_slice.y)]
        collection = mpl.collections.PatchCollection(patches, cmap="jet")
        collection.set_array(z_slice.id)
        collection.set_clim([min_id, max_id])
        ax.add_collection(collection)

        if reference is not None:
            # The outer merge joins on every column the two frames share; rows
            # flagged "right_only" are present in the reference slice only.
            z_slice_reference = reference[reference.z == z]
            filtered = pd.merge(z_slice, z_slice_reference, how="outer", indicator=True)
            removed = filtered[filtered["_merge"] == "right_only"]

            # Reference-only coordinates are drawn as smaller gray circles.
            patches = [plt.Circle((x, y), radius=0.3) for x, y in zip(removed.x, removed.y)]
            collection = mpl.collections.PatchCollection(patches, facecolor="#ccc")
            ax.add_collection(collection)

        # Check the value's type directly: np.issubdtype expects a dtype and
        # raises TypeError when handed a plain Python int.
        if isinstance(z, (int, np.integer)):
            ax.set_title(f"z = {z}")
        else:
            ax.set_title(f"z = {z:.2f}")

        ax.set_aspect("equal", adjustable="box")
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])

    fig.tight_layout()

    return fig

88 

89 

def separate_rows_cols(items: list) -> tuple[int, int, list[tuple[int, int, Optional[int]]]]:
    """
    Separate list of items into approximately equal number of indexed rows and columns.

    The grid is indexed row by row. Grid cells that fall beyond the last item
    are still listed, with None in place of the item index. An empty list
    yields zero rows, zero columns, and no indices.

    Parameters
    ----------
    items
        List of items.

    Returns
    -------
    :
        Number of rows, number of columns, and indices for each item.
    """

    n_items = len(items)

    # Without this guard the column count would be zero and the row
    # calculation below would divide by zero.
    if n_items == 0:
        return 0, 0, []

    n_cols = ceil(sqrt(n_items))
    n_rows = ceil(n_items / n_cols)

    # divmod turns the flat cell index back into its (row, column) position.
    indices = [
        (*divmod(k, n_cols), k if k < n_items else None) for k in range(n_rows * n_cols)
    ]

    return n_rows, n_cols, indices

113 

114 

def make_subplots(
    n_rows: int, n_cols: int
) -> tuple[mpl.figure.Figure, Union[mpl.axes.Axes, np.ndarray]]:
    """
    Build a new figure with a grid of subplots that share x and y axes.

    Every figure currently open in pyplot is closed before the new one is made.

    Parameters
    ----------
    n_rows
        Number of rows in contact sheet plot.
    n_cols
        Number of columns in contact sheet plot.

    Returns
    -------
    :
        Figure and axes objects.
    """

    # Close whatever pyplot still has open before creating the new figure.
    plt.close("all")
    return plt.subplots(nrows=n_rows, ncols=n_cols, sharex="all", sharey="all")

137 

138 

def select_axes(
    axs: Union[mpl.axes.Axes, np.ndarray], i: int, j: int, n_rows: int, n_cols: int
) -> mpl.axes.Axes:
    """
    Select the axes object for the given indexed location.

    A 1 x 1 grid is expected to be passed as the bare axes object and a
    single-row grid as a 1D array indexed by column. A single-column grid may
    be either a 1D array indexed by row or a 2D array; any other grid is
    indexed as a 2D array.

    Parameters
    ----------
    axs
        Axes object.
    i
        Row index.
    j
        Column index.
    n_rows
        Number of rows in contact sheet plot.
    n_cols
        Number of columns in contact sheet plot.

    Returns
    -------
    :
        Axes object at indexed location.
    """

    if n_rows == 1 and n_cols == 1:
        return axs
    if n_rows == 1:
        return axs[j]
    # A single column collapsed to one dimension can only be indexed by row;
    # a 2D single-column array falls through to the regular lookup below.
    if n_cols == 1 and axs.ndim == 1:
        return axs[i]

    return axs[i, j]

168 

169 return axs[i, j]