Coverage for src/cell_abm_pipeline/flows/run_batch_simulations.py: 0%
126 statements
« prev ^ index » next coverage.py v7.1.0, created at 2024-06-05 19:14 +0000
« prev ^ index » next coverage.py v7.1.0, created at 2024-06-05 19:14 +0000
"""
Workflow for running containerized models using AWS Batch.

Working location structure:

.. code-block:: bash

    (name)
    └── YYYY-MM-DD
        ├── inits
        │   └── (name)_(group)_(seed).(extension)
        ├── inputs
        │   └── (name)_(group)_(index).xml
        └── logs
            └── (job arn).log

The simulation series manifest, produced by the summarize manifest flow, is used
to identify which simulation conditions and seeds are missing. These conditions
and seeds are converted into input files using the given template file, grouped
by the specified job size. The relevant initialization and input files are then
saved to a dated directory. All simulations in the same group will use the same
initialization file for a given seed; if different initializations need to be
used for different conditions, assign the conditions to different groups.

Jobs are submitted to AWS Batch with the given infrastructure context. The jobs
are periodically queried for status at the specified retry delay interval, for
the specified number of retries. Jobs that are still running after these retries
are complete are terminated only if job termination is enabled (the default).
If log saving is enabled, output logs are saved after these retries are
complete. Note that if a job is not complete when the logs are saved, only the
logs available at that time will be saved.

Note that this workflow works only if the working location is an S3 bucket. For
local working locations, use the run docker simulations flow instead.
"""
34from dataclasses import dataclass, field
35from typing import Optional
37from arcade_collection.input import group_template_conditions
38from container_collection.batch import (
39 check_batch_job,
40 get_batch_logs,
41 make_batch_job,
42 register_batch_job,
43 submit_batch_job,
44 terminate_batch_job,
45)
46from container_collection.manifest import find_missing_conditions
47from container_collection.template import generate_input_contents
48from io_collection.keys import copy_key, make_key
49from io_collection.load import load_dataframe, load_text
50from io_collection.save import save_text
51from prefect import flow, get_run_logger
53from cell_abm_pipeline.tasks.physicell import render_physicell_template
@dataclass
class ParametersConfig:
    """Parameter configuration for run batch simulations flow."""

    model: str
    """Model name, matched case-insensitively (e.g. "ARCADE" or "PHYSICELL")."""

    image: str
    """Name of the model container image within the user's ECR registry."""

    retries: int
    """How many times job completion is re-checked before moving on."""

    retry_delay: int
    """Seconds to wait between consecutive job completion checks."""

    seeds_per_job: int = 1
    """Maximum number of seeds grouped into a single job input file."""

    log_filter: str = ""
    """Pattern used to filter job log events before they are saved."""

    terminate_jobs: bool = True
    """Whether to terminate jobs once all completion checks are exhausted."""

    save_logs: bool = True
    """Whether to fetch and save job logs once completion checks are done."""

    clean_jobs: bool = True
    """Whether to clean up job files (not currently read by the run flow)."""
@dataclass
class ContextConfig:
    """Context configuration for run batch simulations flow."""

    working_location: str
    """Location for input and output files (must be an S3 bucket, ``s3://...``).

    The run batch simulations flow logs an error and exits if this does not
    start with ``s3://``.
    """

    manifest_location: str
    """Location of manifest file (local path or S3 bucket)."""

    template_location: str
    """Location of template file (local path or S3 bucket)."""

    account: str
    """AWS account number (used for the ECR registry and the job IAM role ARN)."""

    region: str
    """AWS region (used for the ECR registry host)."""

    user: str
    """User name prefix for job names and the image repository path."""

    vcpus: int
    """Requested number of vcpus for AWS Batch job."""

    memory: int
    """Requested memory for AWS Batch job (presumably MiB, as in AWS Batch; confirm)."""

    queue: str
    """Name of AWS Batch queue."""
@dataclass
class SeriesConfig:
    """Series configuration for run batch simulations flow."""

    name: str
    """Name of the simulation series."""

    manifest_key: str
    """Key for manifest file."""

    template_key: str
    """Key for template file."""

    seeds: list[int]
    """List of series random seeds."""

    conditions: list[dict]
    """List of series condition dictionaries (must include unique condition "key").

    When groups other than the default ``"_"`` group are used, each condition
    must also include a "group" entry naming the group it belongs to.
    """

    extensions: list[str]
    """List of file extensions in complete run."""

    inits: list[dict] = field(default_factory=list)
    """Initialization entries and their associated group names.

    Each entry is read for "name", "key", "seeds", and "extensions" (plus
    "group" when groups other than ``"_"`` are used).
    """

    groups: dict[str, Optional[str]] = field(default_factory=lambda: {"_": ""})
    """Initialization groups, keyed by group name.

    Groups whose value is None are skipped. The default single group ``"_"``
    includes every condition and init.
    """
_SUPPORTED_MODELS = ("ARCADE", "PHYSICELL")
"""Model names (upper case) that the run batch simulations flow can render inputs for."""


def _make_input_contents(
    model: str,
    template: str,
    conditions: list[dict],
    seeds_per_job: int,
    group_key: str,
) -> list[str]:
    """
    Render model input file contents for the given conditions.

    Parameters
    ----------
    model
        Upper-case model name (one of ``_SUPPORTED_MODELS``).
    template
        Input file template text.
    conditions
        Conditions (with seeds) to render into input files.
    seeds_per_job
        Maximum number of seeds per input file (ARCADE only).
    group_key
        Series/group file set name (PhysiCell only).

    Returns
    -------
    :
        List of input file contents; empty for an unrecognized model.
    """

    if model == "ARCADE":
        condition_sets = group_template_conditions(conditions, seeds_per_job)
        return generate_input_contents(template, condition_sets)

    if model == "PHYSICELL":
        return render_physicell_template(template, conditions, group_key)

    return []


def _copy_group_inits(
    working_location: str,
    series_name: str,
    group_key: str,
    model: str,
    inits: list[dict],
    needed_seeds: set,
) -> None:
    """
    Copy source init files into the dated inits directory for a group.

    Init entries that share no seed with ``needed_seeds`` are skipped. For the
    remaining entries, files are copied for every seed listed on the entry.

    Parameters
    ----------
    working_location
        Location that holds both the source and target init files.
    series_name
        Name of the simulation series.
    group_key
        Series/group file set name used to name the target files.
    model
        Upper-case model name, used to locate the source init directory.
    inits
        Init entries (with "name", "key", "seeds", "extensions") for the group.
    needed_seeds
        Seeds that have missing simulations in this group.
    """

    target_key = make_key(series_name, "{{timestamp}}", "inits")

    for init in inits:
        if not needed_seeds.intersection(init["seeds"]):
            continue

        source_key = make_key(init["name"], "inits", f"inits.{model}")
        source = make_key(source_key, f"{init['name']}_{init['key']}")

        for seed in init["seeds"]:
            target = make_key(target_key, f"{group_key}_{seed:04d}")
            for ext in init["extensions"]:
                copy_key(working_location, f"{source}.{ext}", f"{target}.{ext}")


@flow(name="run-batch-simulations")
def run_flow(context: ContextConfig, series: SeriesConfig, parameters: ParametersConfig) -> None:
    """
    Main run batch simulations flow.

    For each enabled group, finds conditions missing from the manifest, renders
    them into input files, copies the relevant init files, registers a job
    definition, and submits the jobs to AWS Batch. All submitted jobs are then
    checked for completion, optionally terminated, and optionally have their
    logs saved.

    Parameters
    ----------
    context
        Infrastructure context; the working location must be an S3 bucket.
    series
        Simulation series configuration.
    parameters
        Flow parameters (model, image, retry and job handling options).
    """

    logger = get_run_logger()

    if not context.working_location.startswith("s3://"):
        logger.error("Batch simulations can only be run with S3 working location.")
        return

    model = parameters.model.upper()

    # Without this check an unknown model silently renders no inputs and no jobs run.
    if model not in _SUPPORTED_MODELS:
        logger.error(
            "Unsupported model [ %s ]; expected one of: %s.",
            parameters.model,
            ", ".join(_SUPPORTED_MODELS),
        )
        return

    manifest = load_dataframe(context.manifest_location, series.manifest_key)
    template = load_text(context.template_location, series.template_key)

    # These do not depend on the group, so compute them once.
    registry = f"{context.account}.dkr.ecr.{context.region}.amazonaws.com"
    job_key = make_key(context.working_location, series.name, "{{timestamp}}/")

    all_job_arns: list[str] = []

    for group, group_value in series.groups.items():
        # Groups mapped to None are disabled.
        if group_value is None:
            continue

        # The default "_" group covers every condition and init.
        group_key = series.name if group == "_" else f"{series.name}_{group}"
        group_conditions = [
            condition
            for condition in series.conditions
            if group == "_" or condition["group"] == group
        ]
        group_inits = [init for init in series.inits if group == "_" or init["group"] == group]

        # Find missing conditions.
        missing_conditions = find_missing_conditions(
            manifest, series.name, group_conditions, series.seeds, series.extensions
        )

        if not missing_conditions:
            continue

        # Convert missing conditions into model input files.
        input_contents = _make_input_contents(
            model, template, missing_conditions, parameters.seeds_per_job, group_key
        )

        if not input_contents:
            continue

        # Copy source init files to target init files.
        needed_seeds = {condition["seed"] for condition in missing_conditions}
        _copy_group_inits(
            context.working_location, series.name, group_key, model, group_inits, needed_seeds
        )

        # Create job definition.
        job_definition = make_batch_job(
            f"{context.user}_{group_key}",
            f"{registry}/{context.user}/{parameters.image}",
            context.vcpus,
            context.memory,
            [
                {"name": "SIMULATION_TYPE", "value": "AWS"},
                {"name": "BATCH_WORKING_URL", "value": job_key},
                {"name": "FILE_SET_NAME", "value": group_key},
            ],
            f"arn:aws:iam::{context.account}:role/BatchJobRole",
        )
        job_definition_arn = register_batch_job(job_definition)

        # Save input files.
        for index, input_content in enumerate(input_contents):
            input_key = make_key(series.name, "{{timestamp}}", "inputs", f"{group_key}_{index}.xml")
            save_text(context.working_location, input_key, input_content)

        # Submit jobs (one per input file).
        job_arns = submit_batch_job(
            group_key,
            job_definition_arn,
            context.user,
            context.queue,
            len(input_contents),
        )
        all_job_arns.extend(job_arns)

    # NOTE(review): parameters.clean_jobs is not read here, so no job cleanup
    # step runs; confirm whether cleanup is expected.
    check_job = check_batch_job.with_options(
        retries=parameters.retries, retry_delay_seconds=parameters.retry_delay
    )

    for job_arn in all_job_arns:
        exitcode = check_job.submit(job_arn, parameters.retries)

        wait_for = [exitcode]

        if parameters.terminate_jobs:
            terminate_status = terminate_batch_job.submit(job_arn, wait_for=wait_for)
            wait_for = [terminate_status]

        if parameters.save_logs:
            logs = get_batch_logs.submit(job_arn, parameters.log_filter, wait_for=wait_for)
            # NOTE(review): if job_arn is a full ARN, it contains ":" and "/",
            # so this key nests under a "job/" prefix; confirm that is intended.
            log_key = make_key(series.name, "{{timestamp}}", "logs", f"{job_arn}.log")
            save_text.submit(context.working_location, log_key, logs)