Coverage for src/cell_abm_pipeline/__main__.py: 0%

98 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2024-06-05 19:14 +0000

1import hashlib 

2import importlib 

3import os 

4import sys 

5from datetime import datetime 

6from types import ModuleType 

7from typing import Optional 

8 

9from omegaconf import DictConfig, OmegaConf 

10from prefect.blocks.system import Secret 

11from prefect.deployments import Deployment 

12 

13from cell_abm_pipeline.__config__ import ( 

14 display_config, 

15 make_config_from_dotlist, 

16 make_config_from_file, 

17 make_config_from_yaml, 

18) 

19 

20 

def main() -> None:
    """Command-line entry point: build a flow config, then run, dry-run, or deploy it.

    Expected argument shapes (``--dryrun`` / ``--deploy`` may appear anywhere)::

        <prog> <flow-name> :: <dotlist args...>   -> make_config_from_dotlist
        <prog> <flow-name> <single-arg>           -> make_config_from_file
        <prog> <flow-name> [<args...>]            -> make_config_from_yaml

    ``<flow-name>`` has dashes converted to underscores to locate the module.
    ``--dryrun`` stops after the config is displayed; ``--deploy`` creates or
    updates a Prefect deployment instead of running the flow directly.
    Returns silently when no flow name is given or the module is not found.
    """
    flag_names = {"--dryrun", "--deploy"}
    raw_args = sys.argv[1:]
    dryrun = "--dryrun" in raw_args
    deploy = "--deploy" in raw_args

    # Work on a local copy rather than mutating sys.argv. The flow-name check
    # happens only after the flags are stripped, so `<prog> --dryrun` returns
    # quietly instead of raising IndexError.
    args = [arg for arg in raw_args if arg not in flag_names]
    if not args:
        return

    # Config interpolation helpers:
    #   ${secret:<block>}  -> value of the named Prefect Secret block
    #   ${concat:<items>}  -> items sorted and joined with ":"
    #   ${home:<path>}     -> <path> joined onto the user's home directory
    # replace=True keeps repeated calls to main() in one process from raising.
    OmegaConf.register_new_resolver(
        "secret", lambda secret: Secret.load(secret).get(), replace=True
    )
    OmegaConf.register_new_resolver(
        "concat", lambda items: ":".join(sorted(items)), replace=True
    )
    OmegaConf.register_new_resolver(
        "home", lambda path: os.path.join(os.path.expanduser("~"), path), replace=True
    )

    module = get_module(args[0].replace("-", "_"))
    if module is None:
        return

    if len(args) > 1 and args[1] == "::":
        config = make_config_from_dotlist(module, args[2:])
    elif len(args) == 2:
        config = make_config_from_file(module, args[1])
    else:
        config = make_config_from_yaml(module, args[1:])

    display_config(config)

    if dryrun:
        return

    if deploy:
        create_deployment(module, config)
    else:
        run_flow(module, config)

65 

66 

def get_module(module_name: str) -> Optional[ModuleType]:
    """Import and return the flow module ``<package>.flows.<module_name>``.

    When no such module can be found, the user is asked whether to create one
    from the package template (via ``create_flow_template``); ``None`` is
    returned in that case whatever the answer.
    """
    # `import importlib` alone does not guarantee the `util` submodule is loaded.
    import importlib.util

    # Resolve relative to the package, not to this module's __name__: under
    # `python -m`, __name__ is "__main__" and a "..flows" lookup would fail.
    package = __package__ or __name__.rpartition(".")[0]
    relative_name = f".flows.{module_name}"

    if importlib.util.find_spec(relative_name, package=package) is not None:
        return importlib.import_module(relative_name, package=package)

    response = input(f"Module {module_name} does not exist. Create template for module [y/n]? ")
    # Blank input counts as "no" instead of raising IndexError on response[0].
    if response.strip().lower().startswith("y"):
        create_flow_template(module_name)

    return None

80 

81 

def create_flow_template(module_name: str) -> None:
    """Scaffold ``flows/<module_name>.py`` next to this file from ``__template__.py``.

    Every ``name-of-flow`` placeholder in the template is replaced with the
    module name written in dash-separated form.
    """
    package_dir = os.path.dirname(os.path.abspath(__file__))
    dashed_name = module_name.replace("_", "-")

    with open(f"{package_dir}/__template__.py", "r", encoding="utf-8") as source:
        rendered = source.read().replace("name-of-flow", dashed_name)

    with open(f"{package_dir}/flows/{module_name}.py", "w", encoding="utf-8") as target:
        target.write(rendered)

92 

93 

def run_flow(module: ModuleType, config: DictConfig) -> None:
    """Call ``module.run_flow`` with the config's sections as plain Python objects.

    The ``context``, ``series`` and ``parameters`` sections are each passed
    through ``OmegaConf.to_object`` and handed over positionally, in that order.
    """
    context, series, parameters = (
        OmegaConf.to_object(getattr(config, section))
        for section in ("context", "series", "parameters")
    )
    module.run_flow(context, series, parameters)

100 

101 

def _confirm(prompt: str) -> bool:
    """Ask a yes/no question; answers starting with "y" (any case) mean yes.

    Blank input is treated as "no" rather than raising IndexError.
    """
    return input(prompt).strip().lower().startswith("y")


def create_deployment(module: ModuleType, config: DictConfig) -> None:
    """Create or update a Prefect deployment for ``module.run_flow``.

    The user is prompted for a deployment name and a work queue (blank means
    "default"). In the name, ``{{timestamp}}`` becomes today's date
    (YYYY-MM-DD) and ``{{name}}`` becomes ``series.name`` when the series
    object has that attribute. An MD5 checksum of the fully resolved config
    is used as the deployment version to decide what to do:

    - an existing deployment with a different version: confirm, then update
      its parameters, queue and infra overrides and apply;
    - no existing deployment loaded: build and apply a new one;
    - same version but a different queue: confirm, then move it to the queue;
    - otherwise: report that the same configuration already exists.
    """
    context = OmegaConf.to_object(config.context)
    series = OmegaConf.to_object(config.series)
    parameters = OmegaConf.to_object(config.parameters)
    flow_parameters = {"context": context, "series": series, "parameters": parameters}

    flow_name = module.__name__.split(".")[-1].replace("_", "-")

    name = input("Deployment name: ")
    name = name.replace("{{timestamp}}", datetime.now().strftime("%Y-%m-%d"))

    if series is not None and hasattr(series, "name"):
        name = name.replace("{{name}}", series.name)

    # Blank or whitespace-only input selects the "default" queue.
    work_queue_name = input("Deployment queue (default if None): ").strip() or "default"

    infra_overrides = {}
    if context is not None and hasattr(context, "region"):
        infra_overrides = {"env": {"AWS_DEFAULT_REGION": context.region}}

    deployment = Deployment.build_from_flow(flow=module.run_flow, name=name)
    # The checksum only fingerprints the config for change detection; it is
    # not used for any security purpose.
    checksum = hashlib.md5(OmegaConf.to_yaml(config, resolve=True).encode("utf-8")).hexdigest()

    full_name = f"\033[1m{flow_name}/{name}\033[0m"
    queue_label = f"\033[92m{work_queue_name}\033[0m"

    if deployment.load() and deployment.version != checksum:
        if not _confirm(f"Deployment {full_name} already exists. Overwrite [y/n]? "):
            return

        deployment.update(
            version=checksum,
            parameters=flow_parameters,
            work_queue_name=work_queue_name,
            infra_overrides=infra_overrides,
        )
        deployment.apply()

        print(f"Deployment {full_name} updated.")
    elif deployment.version != checksum:
        # Reached only when load() was falsy: build a new deployment carrying
        # the full configuration and apply it immediately.
        deployment = Deployment.build_from_flow(
            flow=module.run_flow,
            name=name,
            version=checksum,
            parameters=flow_parameters,
            work_queue_name=work_queue_name,
            infra_overrides=infra_overrides,
            apply=True,
        )

        print(f"Deployment {full_name} created in queue {queue_label}.")
    elif deployment.work_queue_name != work_queue_name:
        if not _confirm(f"Update {full_name} queue to {queue_label} [y/n]? "):
            return

        deployment.update(work_queue_name=work_queue_name)
        deployment.apply()

        print(f"Deployment {full_name} queue updated.")
    else:
        print(f"Deployment {full_name} with same configuration already exists.")

173 

174 

if __name__ == "__main__":
    # Executed directly (e.g. `python -m <package>`): hand off to the CLI entry point.
    main()