1"""
2Workflow for running containerized calculations using Fargate.
4This workflow is used to run registered calculation flows across different
5simulation conditions and random seeds in parallel. The configurations for the
6selected calculation are passed into the corresponding flows.
8Some calculations can be chunked (in which the calculation is only run on a
9subset of the cells for the given condition, seed, and tick) in order to further
10parallelize the flow. For chunked calculations, re-running the flow will merge
11completed chunks into a single file.
13The flow will aim to avoid re-running any existing calculations. Calculations
14are skipped if the calculation output file already exists, or if the specific
15chunk already exists. Calculations for additional ticks are appended into the
16existing calculation output file.
18If the submit tasks option is turned off, the flow will print the full pipeline
19command instead, which can then be run locally. Running commands locally can be
20useful for conditions that require more CPUs/memory than are available.
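
With task submission turned off, a printed command might look like the
following (module name and all values are illustrative):

    abmpipe calculate_neighbors :: context.working_location=s3://bucket series.name=EXAMPLE parameters.key=CONTROL parameters.seed=0 parameters.tick=0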

Note that this workflow works only if the working location is an S3 bucket.
"""

import importlib
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

import pandas as pd
from container_collection.fargate import (
    make_fargate_task,
    register_fargate_task,
    submit_fargate_task,
)
from io_collection.keys import check_key, make_key, remove_key
from io_collection.load import load_dataframe
from io_collection.save import save_dataframe
from prefect import flow, get_run_logger

from cell_abm_pipeline.__config__ import make_dotlist_from_config


class Calculation(Enum):
    """Registered calculation types."""

    COEFFICIENTS = ("calculate_coefficients", "COEFFICIENTS", True)

    NEIGHBORS = ("calculate_neighbors", "NEIGHBORS", False)

    POSITIONS = ("calculate_positions", "POSITIONS", False)

    PROPERTIES = ("calculate_properties", "PROPERTIES", True)

    IMAGE_PROPERTIES = ("calculate_image_properties", "PROPERTIES", False)


@dataclass
class ParametersConfig:
    """Parameter configuration for run fargate calculations flow."""

    image: str
    """Name of pipeline image."""

    ticks: list[int]
    """List of ticks to run flow on."""

    calculate: Optional[Calculation] = None
    """Calculation type."""

    chunk: Optional[int] = None
    """Chunk size, if possible for the given calculation type."""

    submit_tasks: bool = True
    """True to submit calculation tasks, False otherwise."""

    overrides: dict = field(default_factory=dict)
    """Overrides for the specific calculation type."""


@dataclass
class ContextConfig:
    """Context configuration for run fargate calculations flow."""

    working_location: str
    """Location for input and output files (local path or S3 bucket)."""

    account: str
    """AWS account number."""

    region: str
    """AWS region."""

    user: str
    """User name prefix."""

    vcpus: int
    """Requested number of vCPUs for AWS Fargate task."""

    memory: int
    """Requested memory for AWS Fargate task."""

    cluster: str
    """AWS Fargate cluster name."""

    security_groups: str
    """AWS Fargate security groups, colon-separated."""

    subnets: str
    """AWS Fargate subnets, colon-separated."""


@dataclass
class SeriesConfig:
    """Series configuration for run fargate calculations flow."""

    name: str
    """Name of the simulation series."""

    seeds: list[int]
    """List of series random seeds."""

    conditions: list[dict]
    """List of series condition dictionaries (must include unique condition "key")."""


@flow(name="run-fargate-calculations")
def run_flow(context: ContextConfig, series: SeriesConfig, parameters: ParametersConfig) -> None:
    """Main run fargate calculations flow."""

    # Check that a valid calculation type is selected.
    if parameters.calculate is None:
        logger = get_run_logger()
        logger.error(
            "No valid calculation type selected. Valid options: [ %s ]",
            " | ".join([member.name for member in Calculation]),
        )
        return

    # Check that the working location is valid.
    if not context.working_location.startswith("s3://"):
        logger = get_run_logger()
        logger.error("Fargate calculations can only be run with an S3 working location.")
        return

    # Get the calculation type.
    module_name, suffix, chunkable = parameters.calculate.value
    calc_path_key = make_key(series.name, "calculations", f"calculations.{suffix}")

    # Get the region suffix, if it exists.
    region = ""
    if "region" in parameters.overrides:
        region = f"_{parameters.overrides['region']}"

    # Create and register the task definition for the calculation.
    if parameters.submit_tasks:
        task_definition = make_fargate_task(
            module_name,
            parameters.image,
            context.account,
            context.region,
            context.user,
            context.vcpus,
            context.memory,
        )
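        # Register the task definition once; the returned ARN is reused for
        # every task submitted by this flow run.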
        task_definition_arn = register_fargate_task(task_definition)

    # Import the module for the specified calculation.
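    # The relative import resolves to a sibling flow module, e.g.
    # "..calculate_neighbors" resolves to cell_abm_pipeline.flows.calculate_neighbors.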
    module = importlib.import_module(f"..{module_name}", package=__name__)

    # Create the context and series configs for the calculation.
    context_config = module.ContextConfig(working_location=context.working_location)
    series_config = module.SeriesConfig(name=series.name)

    for condition in series.conditions:
        for seed in series.seeds:
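            # Simulation keys have the form "<series name>_<condition key>_<seed>",
            # e.g. "EXAMPLE_CONTROL_0001" (illustrative).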
            series_key = f"{series.name}_{condition['key']}_{seed:04d}"

            # If the calculation is chunkable, load results to identify chunks.
            if chunkable:
                results_key = make_key(series.name, "results", f"{series_key}.csv")
                results = load_dataframe(context.working_location, results_key)

            # Check if the compiled calculation result already exists.
            calc_key = make_key(calc_path_key, f"{series_key}{region}.{suffix}.csv")
            calc_key_exists = check_key(context.working_location, calc_key)

            # If the compiled calculation result already exists, load the calculated ticks.
            existing_ticks = []
            if calc_key_exists:
                existing_calc = load_dataframe(context.working_location, calc_key, usecols=["TICK"])
                existing_ticks = list(existing_calc["TICK"].unique())

            for tick in parameters.ticks:
                # Skip the tick if it exists in the compiled calculation result.
                if tick in existing_ticks:
                    continue

                # Check if the individual calculation result already exists.
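                # Tick keys have the form "<series key>_<tick><region>.<suffix>.csv",
                # with the tick zero-padded to six digits.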
                tick_key = make_key(calc_path_key, f"{series_key}_{tick:06d}{region}.{suffix}.csv")
                tick_key_exists = check_key(context.working_location, tick_key)

                # Skip the tick if the individual calculation result exists.
                if tick_key_exists:
                    continue

                completed_offset_keys = []
                missing_offsets_overrides = []

                # If the calculation is chunkable, get the completed and missing chunk offsets.
                if chunkable:
                    total = results[results["TICK"] == tick].shape[0]
                    chunk = parameters.chunk
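                    # Offsets partition the rows for this tick into chunks of size
                    # "chunk"; without a chunk size, a single offset of 0 covers
                    # the full calculation.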
                    all_offsets = list(range(0, total, chunk)) if chunk is not None else [0]

                    for offset in all_offsets:
                        if chunk is not None:
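                            # Chunk keys append zero-padded offset and chunk size,
                            # e.g. ".0000.0100.COEFFICIENTS.csv" (illustrative).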
                            offset_key = tick_key.replace(
                                f".{suffix}.csv",
                                f".{offset:04d}.{chunk:04d}.{suffix}.csv",
                            )
                            offset_key_exists = check_key(context.working_location, offset_key)

                            if offset_key_exists:
                                completed_offset_keys.append(offset_key)
                                continue

                        missing_offsets_overrides.append({"offset": offset, "chunk": chunk})
                else:
                    missing_offsets_overrides.append({})

                # Create commands and submit (or display) tasks.
                for offset_overrides in missing_offsets_overrides:
                    parameters_config = module.ParametersConfig(
                        key=condition["key"],
                        seed=seed,
                        tick=tick,
                        **offset_overrides,
                        **parameters.overrides,
                    )

                    config = {
                        "context": context_config,
                        "series": series_config,
                        "parameters": parameters_config,
                    }
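
                    # The full config is flattened into a dotlist of config
                    # overrides for the pipeline CLI.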
                    command = ["abmpipe", module_name, "::"] + make_dotlist_from_config(config)

                    if parameters.submit_tasks:
                        submit_fargate_task.with_options(retries=2, retry_delay_seconds=1)(
                            module_name,
                            task_definition_arn,
                            context.user,
                            context.cluster,
                            context.security_groups.split(":"),
                            context.subnets.split(":"),
                            command,
                        )
                    else:
                        print(" ".join(command))

                # If all chunk results exist, compile into unchunked result.
                if (
                    chunkable
                    and len(completed_offset_keys) == len(all_offsets)
                    and chunk is not None
                ):
                    tick_calcs = []

                    for key in completed_offset_keys:
                        tick_calcs.append(load_dataframe(context.working_location, key))

                    calc_dataframe = pd.concat(tick_calcs, ignore_index=True)
                    save_dataframe(context.working_location, tick_key, calc_dataframe, index=False)

                    for key in completed_offset_keys:
                        remove_key(context.working_location, key)