Coverage for src/cell_abm_pipeline/flows/run_fargate_calculations.py: 0%

129 statements  

coverage.py v7.1.0, created at 2024-06-05 19:14 +0000

1""" 

2Workflow for running containerized calculations using Fargate. 

3 

4This workflow is used to run registered calculation flows across different 

5simulation conditions and random seeds in parallel. The configurations for the 

6selected calculation are passed into the corresponding flows. 

7 

8Some calculations can be chunked (in which the calculation is only run on a 

9subset of the cells for the given condition, seed, and tick) in order to further 

10parallelize the flow. For chunked calculations, re-running the flow will merge 

11completed chunks into a single file. 

12 

13The flow will aim to avoid re-running any existing calculations. Calculations 

14are skipped if the calculation output file already exists, or if the specific 

15chunk already exists. Calculations for additional ticks are appended into the 

16existing calculation output file. 

17 

18If the submit tasks option is turned off, the flow will print the full pipeline 

19command instead, which can then be run locally. Running commands locally can be 

20useful for conditions that require more CPUs/memory than are available. 

21 

22Note that this workflow works only if working location is an S3 bucket. 

23""" 

24 
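# When parameters.submit_tasks is False, the flow prints the full pipeline command for
# each pending calculation instead of submitting it as a Fargate task. A hypothetical
# example of one such printed command (the exact dotlist entries depend on the selected
# calculation's configs and on the output of make_dotlist_from_config):
#
#   abmpipe calculate_neighbors :: context.working_location=s3://example-bucket
#       series.name=EXAMPLE parameters.key=control parameters.seed=0 parameters.tick=10
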

import importlib
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

import pandas as pd
from container_collection.fargate import (
    make_fargate_task,
    register_fargate_task,
    submit_fargate_task,
)
from io_collection.keys import check_key, make_key, remove_key
from io_collection.load import load_dataframe
from io_collection.save import save_dataframe
from prefect import flow, get_run_logger

from cell_abm_pipeline.__config__ import make_dotlist_from_config



class Calculation(Enum):
    """Registered calculation types."""

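    # Each value is (module name, output key suffix, whether the calculation is chunkable).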

    COEFFICIENTS = ("calculate_coefficients", "COEFFICIENTS", True)

    NEIGHBORS = ("calculate_neighbors", "NEIGHBORS", False)

    POSITIONS = ("calculate_positions", "POSITIONS", False)

    PROPERTIES = ("calculate_properties", "PROPERTIES", True)

    IMAGE_PROPERTIES = ("calculate_image_properties", "PROPERTIES", False)



@dataclass
class ParametersConfig:
    """Parameter configuration for run fargate calculations flow."""

    image: str
    """Name of pipeline image."""

    ticks: list[int]
    """List of ticks to run flow on."""

    calculate: Optional[Calculation] = None
    """Calculation type."""

    chunk: Optional[int] = None
    """Chunk size, if possible for the given calculation type."""

    submit_tasks: bool = True
    """True to submit calculation tasks, False otherwise."""

    overrides: dict = field(default_factory=lambda: {})
    """Overrides for the specific calculation type."""



@dataclass
class ContextConfig:
    """Context configuration for run fargate calculations flow."""

    working_location: str
    """Location for input and output files (must be an S3 bucket for this flow)."""

    account: str
    """AWS account number."""

    region: str
    """AWS region."""

    user: str
    """User name prefix."""

    vcpus: int
    """Requested number of vCPUs for AWS Fargate task."""

    memory: int
    """Requested memory for AWS Fargate task."""

    cluster: str
    """AWS Fargate cluster name."""

    security_groups: str
    """AWS Fargate security groups, separated by colons."""

    subnets: str
    """AWS Fargate subnets, separated by colons."""



@dataclass
class SeriesConfig:
    """Series configuration for run fargate calculations flow."""

    name: str
    """Name of the simulation series."""

    seeds: list[int]
    """List of series random seeds."""

    conditions: list[dict]
    """List of series condition dictionaries (must include unique condition "key")."""


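# A minimal sketch (hypothetical values) of building the configs and calling the flow
# directly; in practice the configs are typically populated from pipeline configuration:
#
#   context = ContextConfig(
#       working_location="s3://example-bucket",
#       account="123456789012",
#       region="us-west-2",
#       user="example-user",
#       vcpus=1,
#       memory=4096,
#       cluster="example-cluster",
#       security_groups="sg-aaaa:sg-bbbb",
#       subnets="subnet-aaaa:subnet-bbbb",
#   )
#   series = SeriesConfig(name="EXAMPLE", seeds=[0], conditions=[{"key": "control"}])
#   parameters = ParametersConfig(
#       image="example-image:latest",
#       ticks=[0, 10],
#       calculate=Calculation.NEIGHBORS,
#       submit_tasks=False,
#   )
#   run_flow(context, series, parameters)
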

@flow(name="run-fargate-calculations")
def run_flow(context: ContextConfig, series: SeriesConfig, parameters: ParametersConfig) -> None:
    """Main run fargate calculations flow."""

    # Check that a valid calculation type is selected.
    if parameters.calculate is None:
        logger = get_run_logger()
        logger.error(
            "No valid calculation type selected. Valid options: [ %s ]",
            " | ".join([member.name for member in Calculation]),
        )
        return

    # Check that the working location is valid.
    if not context.working_location.startswith("s3://"):
        logger = get_run_logger()
        logger.error("Fargate calculations can only be run with S3 working location.")
        return

    # Get the calculation type.
    module_name, suffix, chunkable = parameters.calculate.value
    calc_path_key = make_key(series.name, "calculations", f"calculations.{suffix}")
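
    # Hypothetical key layout for series "EXAMPLE", condition "control", seed 0, tick 10,
    # offset 0, and chunk 100 with the COEFFICIENTS calculation (assuming make_key joins
    # path components with "/"):
    #   compiled:  EXAMPLE/calculations/calculations.COEFFICIENTS/EXAMPLE_control_0000.COEFFICIENTS.csv
    #   per tick:  EXAMPLE/calculations/calculations.COEFFICIENTS/EXAMPLE_control_0000_000010.COEFFICIENTS.csv
    #   per chunk: EXAMPLE/calculations/calculations.COEFFICIENTS/EXAMPLE_control_0000_000010.0000.0100.COEFFICIENTS.csv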


    # Get the region suffix, if it exists.
    region = ""
    if "region" in parameters.overrides:
        region = f"_{parameters.overrides['region']}"


    # Create and register the task definition for the calculation.
    if parameters.submit_tasks:
        task_definition = make_fargate_task(
            module_name,
            parameters.image,
            context.account,
            context.region,
            context.user,
            context.vcpus,
            context.memory,
        )
        task_definition_arn = register_fargate_task(task_definition)

    # Import the module for the specified calculation.
    module = importlib.import_module(f"..{module_name}", package=__name__)

    # Create the context and series configs for the calculation.
    context_config = module.ContextConfig(working_location=context.working_location)
    series_config = module.SeriesConfig(name=series.name)

    for condition in series.conditions:
        for seed in series.seeds:
            series_key = f"{series.name}_{condition['key']}_{seed:04d}"

            # If the calculation is chunkable, load results to identify chunks.
            if chunkable:
                results_key = make_key(series.name, "results", f"{series_key}.csv")
                results = load_dataframe(context.working_location, results_key)

            # Check if the compiled calculation result already exists.
            calc_key = make_key(calc_path_key, f"{series_key}{region}.{suffix}.csv")
            calc_key_exists = check_key(context.working_location, calc_key)

            # If the compiled calculation result already exists, load the calculated ticks.
            existing_ticks = []
            if calc_key_exists:
                existing_calc = load_dataframe(context.working_location, calc_key, usecols=["TICK"])
                existing_ticks = list(existing_calc["TICK"].unique())

            for tick in parameters.ticks:
                # Skip the tick if it exists in the compiled calculation result.
                if tick in existing_ticks:
                    continue

                # Check if the individual calculation result already exists.
                tick_key = make_key(calc_path_key, f"{series_key}_{tick:06d}{region}.{suffix}.csv")
                tick_key_exists = check_key(context.working_location, tick_key)

                # Skip the tick if the individual calculation result exists.
                if tick_key_exists:
                    continue

                completed_offset_keys = []
                missing_offsets_overrides = []

                # If the calculation is chunkable, get the completed and missing chunk offsets.
                if chunkable:
                    total = results[results["TICK"] == tick].shape[0]
                    chunk = parameters.chunk
                    all_offsets = list(range(0, total, chunk)) if chunk is not None else [0]

                    for offset in all_offsets:
                        if chunk is not None:
                            offset_key = tick_key.replace(
                                f".{suffix}.csv",
                                f".{offset:04d}.{chunk:04d}.{suffix}.csv",
                            )
                            offset_key_exists = check_key(context.working_location, offset_key)

                            if offset_key_exists:
                                completed_offset_keys.append(offset_key)
                                continue

                        missing_offsets_overrides.append({"offset": offset, "chunk": chunk})
                else:
                    missing_offsets_overrides.append({})

                # Create commands and submit (or display) tasks.
                for offset_overrides in missing_offsets_overrides:
                    parameters_config = module.ParametersConfig(
                        key=condition["key"],
                        seed=seed,
                        tick=tick,
                        **offset_overrides,
                        **parameters.overrides,
                    )

                    config = {
                        "context": context_config,
                        "series": series_config,
                        "parameters": parameters_config,
                    }

                    command = ["abmpipe", module_name, "::"] + make_dotlist_from_config(config)

                    if parameters.submit_tasks:
                        submit_fargate_task.with_options(retries=2, retry_delay_seconds=1)(
                            module_name,
                            task_definition_arn,
                            context.user,
                            context.cluster,
                            context.security_groups.split(":"),
                            context.subnets.split(":"),
                            command,
                        )
                    else:
                        print(" ".join(command))

                # If all chunk results exist, compile into unchunked result.
                if (
                    chunkable
                    and len(completed_offset_keys) == len(all_offsets)
                    and chunk is not None
                ):
                    tick_calcs = []

                    for key in completed_offset_keys:
                        tick_calcs.append(load_dataframe(context.working_location, key))

                    calc_dataframe = pd.concat(tick_calcs, ignore_index=True)
                    save_dataframe(context.working_location, tick_key, calc_dataframe, index=False)

                    for key in completed_offset_keys:
                        remove_key(context.working_location, key)