Coverage for src/abm_shape_collection/extract_mesh_projections.py: 100%

44 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2024-09-25 19:34 +0000

from __future__ import annotations

import os
import tempfile
from enum import Enum

import numpy as np
import trimesh
from vtk import vtkPLYWriter, vtkPolyData  # pylint: disable=no-name-in-module

9 

PROJECTIONS: list[tuple[str, tuple[int, int, int], int]] = [
    # Each entry is (projection name, plane normal, index of the normal's axis).
    ("side1", (0, 1, 0), 1),  # planes normal to the y axis
    ("side2", (1, 0, 0), 0),  # planes normal to the x axis
    ("top", (0, 0, 1), 2),  # planes normal to the z axis
]
"""Mesh projection names, plane normals, and the axis index used for extents."""

16 

17 

class ProjectionType(Enum):
    """Kinds of planar projection that can be extracted from a mesh."""

    SLICE = 1
    """Single cross section through the origin for each projection plane."""

    EXTENT = 2
    """Stack of parallel cross sections for each projection plane."""

26 

27 

def extract_mesh_projections(
    mesh: vtkPolyData | trimesh.Trimesh,
    projection_types: list[ProjectionType] | None = None,
    offset: tuple[float, float, float] | None = None,
) -> dict:
    """
    Extract slices and/or extents from mesh.

    Slice projections are taken as the cross section of the mesh with planes
    normal to the x, y, and z axes passing through the origin (0, 0, 0).
    Extent projections are taken as cross sections of the mesh with parallel
    planes normal to the x, y, and z axes, spaced at increments of 0.5.

    If ``offset`` is given, the translation is applied to a copy of the mesh;
    a trimesh object passed in by the caller is not modified.

    Parameters
    ----------
    mesh
        Mesh object. VTK polydata is converted to a trimesh object first.
    projection_types
        Mesh projection types. Defaults to both slice and extent projections.
    offset
        Mesh translation applied before extracting slices and/or extents.

    Returns
    -------
    :
        Map of mesh projection path points, keyed by ``"{name}_slice"`` and/or
        ``"{name}_extent"`` for each projection name in ``PROJECTIONS``.
    """

    if isinstance(mesh, vtkPolyData):
        # Conversion already produces a new trimesh object owned by this call.
        mesh = convert_vtk_to_trimesh(mesh)
    elif offset is not None:
        # apply_translation modifies the mesh in place; copy first so the
        # caller's mesh is untouched and repeated calls do not compound.
        mesh = mesh.copy()

    if projection_types is None:
        projection_types = [ProjectionType.SLICE, ProjectionType.EXTENT]

    if offset is not None:
        mesh.apply_translation(offset)

    projections: dict[str, list[list[list[float]]] | dict[float, list[list[list[float]]]]] = {}

    if ProjectionType.SLICE in projection_types:
        for projection, normal, _ in PROJECTIONS:
            projections[f"{projection}_slice"] = get_mesh_slice(mesh, normal)

    if ProjectionType.EXTENT in projection_types:
        for projection, normal, index in PROJECTIONS:
            projections[f"{projection}_extent"] = get_mesh_extent(mesh, normal, index)

    return projections

76 

77 

def convert_vtk_to_trimesh(mesh: vtkPolyData) -> trimesh.Trimesh:
    """
    Convert VTK polydata to trimesh object.

    The polydata is written as an ASCII PLY file inside a private temporary
    directory and loaded back with trimesh. The directory and the file are
    removed when the function returns.

    Parameters
    ----------
    mesh
        VTK mesh object.

    Returns
    -------
    :
        Trimesh mesh object.
    """
    # A temporary directory (rather than NamedTemporaryFile plus a ".ply"
    # suffix on its name) guarantees the written PLY file is cleaned up, and
    # avoids reopening an already-open temporary file on Windows.
    with tempfile.TemporaryDirectory() as temp_dir:
        ply_path = os.path.join(temp_dir, "mesh.ply")
        writer = vtkPLYWriter()
        writer.SetInputData(mesh)
        writer.SetFileTypeToASCII()
        writer.SetFileName(ply_path)
        _ = writer.Write()
        # Load fully before the directory is removed on leaving the block.
        return trimesh.load(ply_path)

99 

100 

def get_mesh_slice(mesh: trimesh.Trimesh, normal: tuple[int, int, int]) -> list[list[list[float]]]:
    """
    Get slice of mesh along plane for given normal as path points.

    The slice is the cross section of the mesh with the plane through the
    origin (0, 0, 0) perpendicular to ``normal``. Per the trimesh
    ``section_multiplane`` API, sections are 2D paths in the plane's local
    frame, so the returned points are 2D coordinates.

    Parameters
    ----------
    mesh
        Mesh object.
    normal
        Vector normal to slice plane.

    Returns
    -------
    :
        List of connected vertices (one list of points per path entity)
        specifying the slice, or an empty list if the plane misses the mesh.
    """

    mesh_slice = mesh.section_multiplane((0, 0, 0), normal, [0])[0]

    # section_multiplane yields None for a plane that does not intersect the
    # mesh; treat that as an empty slice instead of failing on ``.discrete``.
    if mesh_slice is None:
        return []

    # tolist() yields plain Python floats rather than numpy scalars.
    return [np.asarray(entity, dtype=float).tolist() for entity in mesh_slice.discrete]

120 

121 

def get_mesh_extent(
    mesh: trimesh.Trimesh, normal: tuple[int, int, int], index: int
) -> dict[float, list[list[list[float]]]]:
    """
    Get extent of mesh along plane for given normal as path points.

    Cross sections are taken with planes perpendicular to ``normal`` at
    heights from ``-layers`` to ``layers + 0.5`` in steps of 0.5, measured
    from the origin, where ``layers = int(mesh.extents[index] + 2)``. The
    sampled heights are centred on the origin, so parts of the mesh lying
    outside that range are not sampled.

    Parameters
    ----------
    mesh
        Mesh object.
    normal
        Vector normal to slice plane.
    index
        Index of normal axis.

    Returns
    -------
    :
        Map from plane height to the list of connected vertices (one list of
        points per path entity) specifying the section at that height.
        Heights whose plane does not intersect the mesh are omitted.
    """

    # Full extent plus a margin is used as the half-width of the sampled range.
    layers = int(mesh.extents[index] + 2)
    # arange excludes its stop value, so the last height is layers + 0.5.
    heights = list(np.arange(-layers, layers + 1, 0.5))
    sections = mesh.section_multiplane((0, 0, 0), normal, heights)

    # float()/tolist() return plain Python numbers instead of numpy scalars.
    return {
        float(height): [np.asarray(entity, dtype=float).tolist() for entity in section.discrete]
        for section, height in zip(sections, heights)
        if section is not None
    }