Coverage for src/cell_abm_pipeline/__main__.py: 0%
98 statements
« prev ^ index » next coverage.py v7.1.0, created at 2024-06-05 19:14 +0000
« prev ^ index » next coverage.py v7.1.0, created at 2024-06-05 19:14 +0000
1import hashlib
2import importlib
3import os
4import sys
5from datetime import datetime
6from types import ModuleType
7from typing import Optional
9from omegaconf import DictConfig, OmegaConf
10from prefect.blocks.system import Secret
11from prefect.deployments import Deployment
13from cell_abm_pipeline.__config__ import (
14 display_config,
15 make_config_from_dotlist,
16 make_config_from_file,
17 make_config_from_yaml,
18)
def main() -> None:
    """Command-line entry point: build a flow config, then run, deploy, or preview it.

    Arguments are read from ``sys.argv``, which is modified in place:

    - ``--dryrun`` and ``--deploy`` may appear anywhere and are removed first.
      With ``--dryrun`` the config is displayed but nothing is run or deployed;
      with ``--deploy`` a Prefect deployment is created instead of running the flow.
    - ``sys.argv[1]`` is the flow module name (hyphens become underscores).
    - The remaining arguments choose how the config is built: ``::`` followed by
      dotlist arguments uses ``make_config_from_dotlist``; exactly one argument uses
      ``make_config_from_file``; anything else (including none) uses
      ``make_config_from_yaml``.

    Returns without doing anything if no flow module name is given or the module
    cannot be found.
    """
    flags = ("--dryrun", "--deploy")
    dryrun = "--dryrun" in sys.argv
    deploy = "--deploy" in sys.argv

    # Strip every occurrence of the flags before any positional indexing, so that
    # e.g. ``<prog> --dryrun`` is treated as "no module given" instead of crashing
    # on ``sys.argv[1]``, and a repeated flag is not mistaken for a positional arg.
    sys.argv[:] = [arg for arg in sys.argv if arg not in flags]

    if len(sys.argv) < 2:
        return

    # ``replace=True`` keeps a second call to main() in the same process from
    # failing because the resolver is already registered.
    OmegaConf.register_new_resolver(
        "secret", lambda secret: Secret.load(secret).get(), replace=True
    )
    OmegaConf.register_new_resolver(
        "concat", lambda items: ":".join(sorted(items)), replace=True
    )
    OmegaConf.register_new_resolver(
        "home", lambda path: os.path.join(os.path.expanduser("~"), path), replace=True
    )

    module_name = sys.argv[1].replace("-", "_")
    module = get_module(module_name)

    if module is None:
        return

    if len(sys.argv) > 2 and sys.argv[2] == "::":
        config = make_config_from_dotlist(module, sys.argv[3:])
    elif len(sys.argv) == 3:
        config = make_config_from_file(module, sys.argv[2])
    else:
        config = make_config_from_yaml(module, sys.argv[2:])

    display_config(config)

    if dryrun:
        return

    if deploy:
        create_deployment(module, config)
    else:
        run_flow(module, config)
def get_module(module_name: str) -> Optional[ModuleType]:
    """Import and return the flow module ``<package>.flows.<module_name>``.

    If the module cannot be found, the user is asked whether to create it from the
    package template. ``None`` is returned in that case whatever the answer, so the
    caller can stop.

    Args:
        module_name: Flow module name, with underscores rather than hyphens.

    Returns:
        The imported flow module, or ``None`` if it does not exist.
    """
    # ``import importlib`` alone does not guarantee the ``util`` submodule is loaded.
    import importlib.util

    # Resolve relative to the enclosing package rather than to ``__name__``: when run
    # via ``python -m``, ``__name__`` is "__main__", and a ``..`` lookup from it fails
    # with "attempted relative import beyond top-level package".
    flow_path = f".flows.{module_name}"

    if importlib.util.find_spec(flow_path, package=__package__) is not None:
        return importlib.import_module(flow_path, package=__package__)

    response = input(f"Module {module_name} does not exist. Create template for module [y/n]? ")

    # An empty answer counts as "no" instead of raising IndexError on ``response[0]``.
    if response.strip().lower().startswith("y"):
        create_flow_template(module_name)

    return None
def create_flow_template(module_name: str) -> None:
    """Write a new flow module ``flows/<module_name>.py`` from the package template.

    The file ``__template__.py`` next to this module is copied. Every occurrence of
    the placeholder ``name-of-flow`` is replaced by ``module_name`` with underscores
    turned into hyphens.
    """
    package_dir = os.path.dirname(os.path.abspath(__file__))
    template_path = f"{package_dir}/__template__.py"
    flow_path = f"{package_dir}/flows/{module_name}.py"
    flow_name = module_name.replace("_", "-")

    with open(template_path, "r", encoding="utf-8") as template_file:
        contents = template_file.read().replace("name-of-flow", flow_name)

    with open(flow_path, "w", encoding="utf-8") as flow_file:
        flow_file.write(contents)
def run_flow(module: ModuleType, config: DictConfig) -> None:
    """Run ``module.run_flow`` directly with the converted sections of ``config``.

    The ``context``, ``series`` and ``parameters`` sections are each converted with
    ``OmegaConf.to_object`` and passed positionally, in that order.
    """
    context, series, parameters = (
        OmegaConf.to_object(getattr(config, section))
        for section in ("context", "series", "parameters")
    )
    module.run_flow(context, series, parameters)
def create_deployment(module: ModuleType, config: DictConfig) -> None:
    """Create or update a Prefect deployment of ``module.run_flow`` for ``config``.

    The user is prompted for the deployment name and the work queue (an empty queue
    answer means ``"default"``). In the name, ``{{timestamp}}`` is replaced with the
    current date (``YYYY-MM-DD``), and ``{{name}}`` is replaced with ``series.name``
    when the series section has a ``name`` attribute. If ``context`` has a ``region``
    attribute, it is passed as the ``AWS_DEFAULT_REGION`` infrastructure env var.

    The MD5 hex digest of the fully resolved config YAML is used as the version:

    - deployment loads and its version differs: confirm, then overwrite it
    - deployment does not load and its version differs: build and apply a new one
    - version matches but the queue differs: confirm, then move it to the new queue
    - otherwise: report that an identical deployment already exists

    Any answer to a [y/n] prompt that does not start with "y" (case-insensitive,
    including an empty answer) cancels without changes.
    """

    def confirm(prompt: str) -> bool:
        # Empty input is "no" rather than an IndexError on ``response[0]``.
        return input(prompt).strip().lower().startswith("y")

    context = OmegaConf.to_object(config.context)
    series = OmegaConf.to_object(config.series)
    parameters = OmegaConf.to_object(config.parameters)
    flow_parameters = {"context": context, "series": series, "parameters": parameters}

    flow_name = module.__name__.split(".")[-1].replace("_", "-")

    name = input("Deployment name: ")
    name = name.replace("{{timestamp}}", datetime.now().strftime("%Y-%m-%d"))

    if series is not None and hasattr(series, "name"):
        name = name.replace("{{name}}", series.name)

    work_queue_name = input("Deployment queue (default if None): ")
    work_queue_name = "default" if work_queue_name == "" else work_queue_name

    infra_overrides = {}

    if context is not None and hasattr(context, "region"):
        infra_overrides = {"env": {"AWS_DEFAULT_REGION": context.region}}

    deployment = Deployment.build_from_flow(flow=module.run_flow, name=name)
    checksum = hashlib.md5(OmegaConf.to_yaml(config, resolve=True).encode("utf-8")).hexdigest()

    full_name = f"\033[1m{flow_name}/{name}\033[0m"

    if deployment.load() and deployment.version != checksum:
        if not confirm(f"Deployment {full_name} already exists. Overwrite [y/n]? "):
            return

        deployment.update(
            version=checksum,
            parameters=flow_parameters,
            work_queue_name=work_queue_name,
            infra_overrides=infra_overrides,
        )
        deployment.apply()

        print(f"Deployment {full_name} updated.")
    elif deployment.version != checksum:
        deployment = Deployment.build_from_flow(
            flow=module.run_flow,
            name=name,
            version=checksum,
            parameters=flow_parameters,
            work_queue_name=work_queue_name,
            infra_overrides=infra_overrides,
            apply=True,
        )

        print(f"Deployment {full_name} created in queue \033[92m{ work_queue_name }\033[0m.")
    elif deployment.work_queue_name != work_queue_name:
        if not confirm(f"Update {full_name} queue to \033[92m{ work_queue_name }\033[0m [y/n]? "):
            return

        deployment.update(work_queue_name=work_queue_name)
        deployment.apply()

        print(f"Deployment {full_name} queue updated.")
    else:
        print(f"Deployment {full_name} with same configuration already exists.")
if __name__ == "__main__":
    # Run the CLI when this module is executed as a script
    # (e.g. ``python -m cell_abm_pipeline``).
    main()