Coverage for src/abm_initialization_collection/sample/remove_unconnected_regions.py: 100%

59 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2024-09-25 19:25 +0000

1import numpy as np 

2import pandas as pd 

3from scipy.spatial import distance 

4from skimage import measure 

5 

6 

def remove_unconnected_regions(
    samples: pd.DataFrame, unconnected_threshold: float, unconnected_filter: str
) -> pd.DataFrame:
    """
    Remove unconnected regions using the selected filter.

    Parameters
    ----------
    samples
        Sample cell ids and coordinates.
    unconnected_threshold
        Distance for removing unconnected regions (only used by the
        "distance" filter).
    unconnected_filter
        Filter type for assigning unconnected coordinates, either
        "connectivity" or "distance".

    Returns
    -------
    :
        Samples with unconnected regions removed.

    Raises
    ------
    ValueError
        If the filter type is not one of the supported values.
    """

    supported_filters = ("connectivity", "distance")

    # Reject unknown filters up front so the dispatch below only sees valid ones.
    if unconnected_filter not in supported_filters:
        message = f"invalid filter type {unconnected_filter}"
        raise ValueError(message)

    if unconnected_filter == "connectivity":
        return remove_unconnected_by_connectivity(samples)

    return remove_unconnected_by_distance(samples, unconnected_threshold)

36 

37 

def remove_unconnected_by_connectivity(samples: pd.DataFrame) -> pd.DataFrame:
    """
    Remove unconnected regions based on simple connectivity.

    The samples are rasterized into an integer array of ids, connected
    regions are labeled, and for each id only the largest labeled region is
    kept.

    Parameters
    ----------
    samples
        Sample cell ids and coordinates.

    Returns
    -------
    :
        Samples with unconnected regions removed, sorted by id and
        coordinates.
    """

    lower_bounds = get_sample_minimums(samples)
    upper_bounds = get_sample_maximums(samples)
    id_array = convert_to_integer_array(samples, lower_bounds, upper_bounds)

    # Label regions of equal id; connectivity=1 restricts neighbors to a
    # single orthogonal hop (see the scikit-image documentation).
    labeled = measure.label(id_array, connectivity=1)
    kept = np.zeros(id_array.shape, dtype="int")

    # Region sizes indexed by (label - 1); the count for label 0 is dropped.
    region_sizes = np.bincount(labeled.flatten())[1:]

    # Rank labels from largest to smallest region. The sort is stable, so
    # regions of equal size stay in ascending label order.
    ranked_labels = sorted(
        range(1, len(region_sizes) + 1),
        key=lambda label: region_sizes[label - 1],
        reverse=True,
    )

    # Walk regions from largest to smallest; the first region seen for an id
    # is its largest one, and any later region with that id is skipped.
    seen_ids = set()
    for label in ranked_labels:
        region_mask = labeled == label
        cell_id = next(iter(set(id_array[region_mask])))

        if cell_id in seen_ids:
            continue

        kept[region_mask] = cell_id
        seen_ids.add(cell_id)

    # Convert back to dataframe.
    connected = convert_to_dataframe(kept, lower_bounds)
    ordered = connected.sort_values(by=["id", "x", "y", "z"])
    return ordered.reset_index(drop=True)

81 

82 

def remove_unconnected_by_distance(samples: pd.DataFrame, threshold: float) -> pd.DataFrame:
    """
    Remove unconnected regions based on distance.

    For each id, a sample is kept only when the minimum distance reported by
    ``get_minimum_distance`` between that sample and the samples sharing its
    id is strictly less than the threshold.

    Parameters
    ----------
    samples
        Sample cell ids and coordinates.
    threshold
        Distance for removing unconnected regions.

    Returns
    -------
    :
        Samples with unconnected regions removed, sorted by id and
        coordinates.
    """

    all_connected: list = []

    # Iterate through each id and filter out samples above the distance threshold.
    for cell_id, group in samples.groupby("id"):
        coordinates = group[["x", "y", "z"]].to_numpy()
        nearest_distances = [
            get_minimum_distance(np.array([coordinate]), coordinates)
            for coordinate in coordinates
        ]

        # Extend in place rather than rebuilding the accumulated list for
        # every id, which would copy all previous results each iteration.
        all_connected.extend(
            (cell_id, x, y, z)
            for nearest, (x, y, z) in zip(nearest_distances, coordinates)
            if nearest < threshold
        )

    # Convert back to dataframe.
    samples_connected = pd.DataFrame(all_connected, columns=["id", "x", "y", "z"])
    return samples_connected.sort_values(by=["id", "x", "y", "z"]).reset_index(drop=True)

118 

119 

def get_sample_minimums(samples: pd.DataFrame) -> tuple[int, int, int]:
    """
    Get minimums in x, y, and z directions for samples.

    Parameters
    ----------
    samples
        Sample cell ids and coordinates.

    Returns
    -------
    :
        Tuple of minimums, ordered as (x, y, z).
    """

    # Evaluate one axis at a time, in x, y, z order.
    return tuple(min(samples[axis]) for axis in ("x", "y", "z"))

138 

139 

def get_sample_maximums(samples: pd.DataFrame) -> tuple[int, int, int]:
    """
    Get maximums in x, y, and z directions for samples.

    Parameters
    ----------
    samples
        Sample cell ids and coordinates.

    Returns
    -------
    :
        Tuple of maximums, ordered as (x, y, z).
    """

    # Evaluate one axis at a time, in x, y, z order.
    return tuple(max(samples[axis]) for axis in ("x", "y", "z"))

158 

159 

def convert_to_integer_array(
    samples: pd.DataFrame,
    minimums: tuple[int, int, int],
    maximums: tuple[int, int, int],
) -> np.ndarray:
    """
    Convert ids and coordinate samples to integer array.

    The array is indexed as [z, y, x], with coordinates shifted so that the
    minimums map to index 0. Positions without a sample are left as 0.

    Parameters
    ----------
    samples
        Sample cell ids and coordinates.
    minimums
        Minimums in x, y, and z directions.
    maximums
        Maximums in x, y, and z directions.

    Returns
    -------
    :
        Array of ids.
    """

    # Extent along each axis; the array needs one more cell than the extent
    # so that both the minimum and maximum coordinates fit.
    extent_x, extent_y, extent_z = np.subtract(maximums, minimums).astype("int32")
    grid = np.zeros((extent_z + 1, extent_y + 1, extent_x + 1), dtype="int32")

    # Shift coordinates so the minimums sit at the origin.
    offsets = samples[["x", "y", "z"]].to_numpy() - minimums

    # Offset columns are (x, y, z) while the grid axes are (z, y, x).
    grid[offsets[:, 2], offsets[:, 1], offsets[:, 0]] = samples.id

    return grid

190 

191 

def convert_to_dataframe(array: np.ndarray, minimums: tuple[int, int, int]) -> pd.DataFrame:
    """
    Convert integer array to ids and coordinate samples.

    Every non-zero entry of the array becomes one row; array indices are
    read as [z, y, x] and shifted by the minimums to recover coordinates.

    Parameters
    ----------
    array
        Integer array of ids.
    minimums
        Minimums in x, y, and z directions.

    Returns
    -------
    :
        Dataframe of ids and coordinates.
    """

    offset_x, offset_y, offset_z = minimums

    rows = []
    # Entries equal to 0 mark positions without a sample and are skipped.
    for z, y, x in zip(*np.nonzero(array != 0)):
        cell_id = array[z, y, x]
        rows.append((cell_id, x + offset_x, y + offset_y, z + offset_z))

    return pd.DataFrame(rows, columns=["id", "x", "y", "z"])

216 

217 

def get_minimum_distance(source: np.ndarray, targets: np.ndarray) -> float:
    """
    Get the minimum non-zero distance from point to array of points.

    Distances of exactly zero (the source itself, or targets coincident with
    it) are ignored.

    Parameters
    ----------
    source
        Coordinates of source point with shape (1, 3)
    targets
        Coordinates for N target points with shape (N, 3)

    Returns
    -------
    :
        Minimum non-zero distance between source and targets, or infinity
        if no target lies at a non-zero distance from the source.
    """

    distances = distance.cdist(source, targets)
    nonzero_distances = distances[distances != 0]

    # With no distinct target (e.g. the source is the only point), there is
    # no neighbor to measure against; report an infinite distance instead of
    # letting np.min fail on an empty selection.
    if nonzero_distances.size == 0:
        return float("inf")

    return np.min(nonzero_distances)

236 return np.min(distances[distances != 0])