Coverage for src/abm_shape_collection/extract_shape_modes.py: 100%

43 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2024-09-25 19:34 +0000

1from typing import Callable 

2 

3import numpy as np 

4import pandas as pd 

5from sklearn.decomposition import PCA 

6 

7from abm_shape_collection.construct_mesh_from_points import construct_mesh_from_points 

8from abm_shape_collection.extract_mesh_projections import extract_mesh_projections 

9 

10 

def extract_shape_modes(
    pca: PCA,
    data: pd.DataFrame,
    components: int,
    regions: list[str],
    order: int,
    delta: float,
    _construct_mesh_from_points: Callable = construct_mesh_from_points,
    _extract_mesh_projections: Callable = extract_mesh_projections,
) -> dict:
    """
    Extract shape modes (latent walks in PC space) at the specified intervals.

    For each region and each component, a mesh is constructed at every map
    point along that component (all other components held at their means) and
    its projections are extracted. For non-default regions, the offset passed
    to the projection extraction is the mean region offset of the samples
    binned to that map point, or None if no samples fall in that bin.

    Parameters
    ----------
    pca
        Fit PCA object.
    data
        Sample data, with shape coefficients as columns.
    components
        Number of shape coefficients components.
    regions
        List of regions.
    order
        Order of the spherical harmonics coefficient parametrization.
    delta
        Interval for latent walk, bounded by -2 and +2 standard deviations.
        Must be positive.

    Returns
    -------
    :
        Map of regions to lists of shape modes at select points.

    Raises
    ------
    ValueError
        If delta is not positive.
    """

    # pylint: disable=too-many-locals

    # A non-positive interval cannot produce a walk from -2 towards +2.
    if delta <= 0:
        raise ValueError(f"delta must be positive, got {delta}")

    # Transform data into shape mode space.
    columns = data.filter(like="shcoeffs").columns
    transform = pca.transform(data[columns].values)

    # Calculate transformed means and standard deviations.
    means = transform.mean(axis=0)
    stds = transform.std(axis=0, ddof=1)

    # Create map points. The stop value sits half an interval past +2 so that
    # +2 is included when the walk lands on it, but no point beyond +2 is.
    map_points = np.arange(-2, 2 + delta / 2, delta)

    # Create bins, with edges halfway between neighboring map points.
    bin_edges = [-np.inf] + [point + delta / 2 for point in map_points[:-1]] + [np.inf]

    # Bin samples by their distance from the mean in units of standard
    # deviations, matching how the mesh vectors are placed below
    # (means + stds * point).
    transform_binned = np.digitize((transform - means) / stds, bin_edges)

    # The bin of each map point is the same for every region and component.
    point_bins = np.digitize(map_points, bin_edges)

    # Initialize output dictionary.
    shape_modes: dict[str, list] = {}

    for region in regions:
        region_shape_modes = []

        suffix = f".{region}" if region != "DEFAULT" else ""
        offsets = calculate_region_offsets(data, region)

        for component in range(components):
            # Only the entry for the current component is ever set, so all
            # other components stay at zero standard deviations from the mean.
            point_vector = np.zeros(components)

            for point, point_bin in zip(map_points, point_bins):
                point_vector[component] = point

                vector = means + np.multiply(stds, point_vector)
                indices = transform_binned[:, component] == point_bin

                mesh = _construct_mesh_from_points(pca, vector, columns, order, suffix=suffix)

                # The default region has no offsets, and an empty bin has no
                # samples to average over.
                if region == "DEFAULT" or not indices.any():
                    offset = None
                else:
                    offset = (
                        offsets["x"][indices].mean(),
                        offsets["y"][indices].mean(),
                        offsets["z"][indices].mean(),
                    )

                region_shape_modes.append(
                    {
                        "mode": component + 1,
                        "point": point,
                        "projections": _extract_mesh_projections(
                            mesh, extents=False, offset=offset
                        ),
                    }
                )

        shape_modes[region] = region_shape_modes

    return shape_modes

103 

104 

def calculate_region_offsets(data: pd.DataFrame, region: str) -> dict:
    """
    Calculate rotated centroid offsets of a region relative to the default centroid.

    The offset is the difference between the region centroid columns
    (``CENTER_X.{region}``, ``CENTER_Y.{region}``, ``CENTER_Z.{region}``) and
    the default centroid columns (``CENTER_X``, ``CENTER_Y``, ``CENTER_Z``).
    The x and y components are rotated in the xy plane by the ``angle`` column,
    which is converted from degrees to radians; the z component is unchanged.

    Parameters
    ----------
    data
        Centroid location data.
    region
        Name of region (skipped if region is DEFAULT).

    Returns
    -------
    :
        Map of offsets in the x, y, and z directions (empty for DEFAULT).
    """

    # The default region has no offsets to calculate.
    if region == "DEFAULT":
        return {}

    # Rotation terms, with angles converted from degrees to radians.
    radians = data["angle"].to_numpy() * np.pi / 180.0
    cosines = np.cos(radians)
    sines = np.sin(radians)

    # Per-sample displacement of the region centroid from the default centroid.
    shift_x = data[f"CENTER_X.{region}"].to_numpy() - data["CENTER_X"].to_numpy()
    shift_y = data[f"CENTER_Y.{region}"].to_numpy() - data["CENTER_Y"].to_numpy()
    shift_z = data[f"CENTER_Z.{region}"].to_numpy() - data["CENTER_Z"].to_numpy()

    rotated_x = shift_x * cosines - shift_y * sines
    rotated_y = shift_x * sines + shift_y * cosines

    return {"x": rotated_x, "y": rotated_y, "z": shift_z}