Coverage for src/abm_initialization_collection/sample/remove_unconnected_regions.py: 100%

61 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2024-07-26 20:12 +0000

1import numpy as np 

2import pandas as pd 

3from scipy.spatial import distance 

4from skimage import measure 

5 

6 

def remove_unconnected_regions(
    samples: pd.DataFrame, unconnected_threshold: float, unconnected_filter: str
) -> pd.DataFrame:
    """
    Removes unconnected regions.

    Parameters
    ----------
    samples
        Sample cell ids and coordinates.
    unconnected_threshold
        Distance for removing unconnected regions. Only used when
        ``unconnected_filter`` is ``"distance"``.
    unconnected_filter
        Filter type for assigning unconnected coordinates, either
        ``"connectivity"`` or ``"distance"``.

    Returns
    -------
    :
        Samples with unconnected regions removed.

    Raises
    ------
    ValueError
        If ``unconnected_filter`` is not a supported filter type.
    """

    if unconnected_filter == "connectivity":
        return remove_unconnected_by_connectivity(samples)

    if unconnected_filter == "distance":
        return remove_unconnected_by_distance(samples, unconnected_threshold)

    raise ValueError(f"invalid filter type {unconnected_filter}")

35 

36 

def remove_unconnected_by_connectivity(samples: pd.DataFrame) -> pd.DataFrame:
    """
    Removes unconnected regions based on simple connectivity.

    Samples are rasterized into a 3D integer array of ids, which is labeled
    into connected regions with ``measure.label(..., connectivity=1)``. For
    each id only the largest labeled region is kept; any smaller region with
    the same id is dropped and a message is printed.

    Parameters
    ----------
    samples
        Sample cell ids and coordinates.

    Returns
    -------
    :
        Samples with unconnected regions removed, sorted by id, x, y, and z.
    """

    minimums = get_sample_minimums(samples)
    maximums = get_sample_maximums(samples)

    array = convert_to_integer_array(samples, minimums, maximums)

    array_connected = np.zeros(array.shape, dtype="int")
    labels = measure.label(array, connectivity=1)

    # Voxel count per label; index 0 (background) is dropped so position i
    # holds the size of label i + 1.
    region_sizes = np.bincount(labels.ravel())[1:]
    regions_sorted = sorted(
        [(i + 1, n) for i, n in enumerate(region_sizes)],
        key=lambda tup: tup[1],
        reverse=True,
    )

    # Visit regions from largest to smallest, so the first region seen for an
    # id is its largest one; later regions with the same id are discarded.
    ids_added = set()
    for index, _ in regions_sorted:
        # Build the region mask once and reuse it for both the read and write.
        mask = labels == index
        # measure.label only joins neighbouring voxels of equal value, so every
        # voxel in the region carries the same id; reading the first suffices.
        cell_id = array[mask][0]

        if cell_id not in ids_added:
            array_connected[mask] = cell_id
            ids_added.add(cell_id)
        else:
            print(f"Skipping unconnected region for cell id {cell_id}")

    # Convert back to dataframe.
    samples_connected = convert_to_dataframe(array_connected, minimums)
    return samples_connected.sort_values(by=["id", "x", "y", "z"]).reset_index(drop=True)

82 

83 

def remove_unconnected_by_distance(samples: pd.DataFrame, threshold: float) -> pd.DataFrame:
    """
    Removes unconnected regions based on distance.

    Samples are grouped by id. Within each group, a sample is kept only if
    ``get_minimum_distance`` from it to the group's coordinates is strictly
    less than ``threshold``.

    Parameters
    ----------
    samples
        Sample cell ids and coordinates.
    threshold
        Distance for removing unconnected regions.

    Returns
    -------
    :
        Samples with unconnected regions removed, sorted by id, x, y, and z.
    """

    all_connected: list = []

    # Iterate through each id and filter out samples above the distance threshold.
    for cell_id, group in samples.groupby("id"):
        coordinates = group[["x", "y", "z"]].to_numpy()
        min_distances = [
            get_minimum_distance(np.array([coordinate]), coordinates)
            for coordinate in coordinates
        ]
        # Extend in place: `all_connected = all_connected + connected` copied
        # the whole accumulator again for every id.
        all_connected.extend(
            (cell_id, x, y, z)
            for min_distance, (x, y, z) in zip(min_distances, coordinates)
            if min_distance < threshold
        )

    # Convert back to dataframe.
    samples_connected = pd.DataFrame(all_connected, columns=["id", "x", "y", "z"])
    return samples_connected.sort_values(by=["id", "x", "y", "z"]).reset_index(drop=True)

119 

120 

def get_sample_minimums(samples: pd.DataFrame) -> tuple[int, int, int]:
    """
    Gets minimums in x, y, and z directions for samples.

    Parameters
    ----------
    samples
        Sample cell ids and coordinates.

    Returns
    -------
    :
        Tuple of minimums, ordered as (x, y, z).
    """

    return (min(samples["x"]), min(samples["y"]), min(samples["z"]))

140 

141 

def get_sample_maximums(samples: pd.DataFrame) -> tuple[int, int, int]:
    """
    Gets maximums in x, y, and z directions for samples.

    Parameters
    ----------
    samples
        Sample cell ids and coordinates.

    Returns
    -------
    :
        Tuple of maximums, ordered as (x, y, z).
    """

    return (max(samples["x"]), max(samples["y"]), max(samples["z"]))

161 

162 

def convert_to_integer_array(
    samples: pd.DataFrame,
    minimums: tuple[int, int, int],
    maximums: tuple[int, int, int],
) -> np.ndarray:
    """
    Converts ids and coordinate samples to integer array.

    Parameters
    ----------
    samples
        Sample cell ids and coordinates.
    minimums
        Minimums in x, y, and z directions.
    maximums
        Maximums in x, y, and z directions.

    Returns
    -------
    :
        Array of ids indexed as [z, y, x] relative to ``minimums``; positions
        with no sample hold 0.
    """

    # Number of cells along each axis; +1 makes both bounds inclusive.
    size_x, size_y, size_z = np.subtract(maximums, minimums).astype("int32") + 1
    array = np.zeros((size_z, size_y, size_x), dtype="int32")

    # Shift coordinates so the minimum corner lands at index 0 on every axis.
    offsets = samples[["x", "y", "z"]].values - minimums
    x_index = offsets[:, 0]
    y_index = offsets[:, 1]
    z_index = offsets[:, 2]
    array[z_index, y_index, x_index] = samples.id

    return array

193 

194 

def convert_to_dataframe(array: np.ndarray, minimums: tuple[int, int, int]) -> pd.DataFrame:
    """
    Converts integer array to ids and coordinate samples.

    Parameters
    ----------
    array
        Integer array of ids, indexed as [z, y, x].
    minimums
        Minimums in x, y, and z directions, added back to each index.

    Returns
    -------
    :
        Dataframe with one (id, x, y, z) row per nonzero array entry.
    """

    offset_x, offset_y, offset_z = minimums
    occupied = zip(*np.where(array != 0))

    records = [
        (array[k, j, i], i + offset_x, j + offset_y, k + offset_z)
        for k, j, i in occupied
    ]

    return pd.DataFrame(records, columns=["id", "x", "y", "z"])

219 

220 

def get_minimum_distance(source: np.ndarray, targets: np.ndarray) -> float:
    """
    Get the minimum nonzero distance from point to array of points.

    Zero distances are ignored, so a point compared against a set that
    contains itself does not report a distance to itself (targets at exactly
    the same coordinate are skipped as well).

    Parameters
    ----------
    source
        Coordinates of source point with shape (1, 3).
    targets
        Coordinates for N target points with shape (N, 3).

    Returns
    -------
    :
        Minimum nonzero distance between source and targets, or ``inf`` when
        no target lies at a nonzero distance (e.g. ``targets`` holds only the
        source point).
    """

    distances = distance.cdist(source, targets)
    nonzero = distances[distances != 0]

    # np.min raises on an empty array; the minimum over no points is +inf,
    # which callers comparing `< threshold` treat as "not connected".
    if nonzero.size == 0:
        return np.inf

    return np.min(nonzero)