Coverage for src/cell_abm_pipeline/flows/run_batch_simulations.py: 0%

126 statements  


1""" 

2Workflow for running containerized models using AWS Batch. 

3 

4Working location structure: 

5 

6.. code-block:: bash 

7 

8 (name) 

9 └── YYYY-MM-DD 

10 ├── inits 

11 │ └── (name)_(group)_(seed).(extension) 

12 └── inputs 

13 └── (name)_(group)_(index).xml 

14 

15The simulation series manifest, produced by the summarize manifest flow, is used 

16to identify which simulation conditions and seeds are missing. These conditions 

17and seeds are converted into input files using the given template file, grouped 

18by the specified job size. The relevant initialization and input files are then 

19saved to a dated directory. All simulations in the same group will use the same 

20initialization file for a given seed; if different initializations need to be 

21used for different conditions, assign the conditions to different groups. 
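For example, two conditions can be assigned to separate initialization groups
through the conditions, inits, and groups fields of the series configuration
(SeriesConfig, defined below). This is an illustrative sketch only; the names,
keys, seeds, and extensions are placeholder values:

.. code-block:: python

    series = SeriesConfig(
        name="EXAMPLE",
        manifest_key="EXAMPLE.csv",
        template_key="template.xml",
        seeds=[0, 1],
        conditions=[
            {"key": "condition_a", "group": "group_a"},
            {"key": "condition_b", "group": "group_b"},
        ],
        extensions=["CELLS", "LOCATIONS"],
        inits=[
            {
                "name": "EXAMPLE",
                "key": "init_a",
                "group": "group_a",
                "seeds": [0, 1],
                "extensions": ["json"],
            },
            {
                "name": "EXAMPLE",
                "key": "init_b",
                "group": "group_b",
                "seeds": [0, 1],
                "extensions": ["json"],
            },
        ],
        groups={"group_a": "", "group_b": ""},
    )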

Jobs are submitted to AWS Batch with the given infrastructure context. The jobs
are periodically queried for status at the specified retry delay interval, for
the specified number of retries. If jobs are still running after these retries
are complete, they are terminated only if the terminate option is enabled.
Output logs are also saved after these retries are complete. Note that if a job
is not complete when its logs are saved, only the logs available at that time
are saved.

Note that this workflow works only if the working location is an S3 bucket. For
local working locations, use the run docker simulations flow instead.
"""

from dataclasses import dataclass, field
from typing import Optional

from arcade_collection.input import group_template_conditions
from container_collection.batch import (
    check_batch_job,
    get_batch_logs,
    make_batch_job,
    register_batch_job,
    submit_batch_job,
    terminate_batch_job,
)
from container_collection.manifest import find_missing_conditions
from container_collection.template import generate_input_contents
from io_collection.keys import copy_key, make_key
from io_collection.load import load_dataframe, load_text
from io_collection.save import save_text
from prefect import flow, get_run_logger

from cell_abm_pipeline.tasks.physicell import render_physicell_template


@dataclass
class ParametersConfig:
    """Parameter configuration for run batch simulations flow."""

    model: str
    """Name of model."""

    image: str
    """Name of model image."""

    retries: int
    """Number of retries to check if jobs are complete."""

    retry_delay: int
    """Delay between retries in seconds."""

    seeds_per_job: int = 1
    """Number of seeds per job."""

    log_filter: str = ""
    """Filter pattern for logs."""

    terminate_jobs: bool = True
    """True if jobs should be terminated after total retry time, False otherwise."""

    save_logs: bool = True
    """True to save job logs, False otherwise."""

    clean_jobs: bool = True
    """True to clean up job files, False otherwise."""


@dataclass
class ContextConfig:
    """Context configuration for run batch simulations flow."""

    working_location: str
    """Location for input and output files (local path or S3 bucket)."""

    manifest_location: str
    """Location of manifest file (local path or S3 bucket)."""

    template_location: str
    """Location of template file (local path or S3 bucket)."""

    account: str
    """AWS account number."""

    region: str
    """AWS region."""

    user: str
    """User name prefix."""

    vcpus: int
    """Requested number of vcpus for AWS Batch job."""

    memory: int
    """Requested memory for AWS Batch job."""

    queue: str
    """Name of AWS Batch queue."""


@dataclass
class SeriesConfig:
    """Series configuration for run batch simulations flow."""

    name: str
    """Name of the simulation series."""

    manifest_key: str
    """Key for manifest file."""

    template_key: str
    """Key for template file."""

    seeds: list[int]
    """List of series random seeds."""

    conditions: list[dict]
    """List of series condition dictionaries (must include unique condition "key")."""

    extensions: list[str]
    """List of file extensions in complete run."""

    inits: list[dict] = field(default_factory=lambda: [])
    """Initialization keys and associated group names."""

    groups: dict[str, Optional[str]] = field(default_factory=lambda: {"_": ""})
    """Initialization groups, keyed by group name."""


@flow(name="run-batch-simulations")
def run_flow(context: ContextConfig, series: SeriesConfig, parameters: ParametersConfig) -> None:
    """Main run batch simulations flow."""

    if not context.working_location.startswith("s3://"):
        logger = get_run_logger()
        logger.error("Batch simulations can only be run with S3 working location.")
        return

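    # Load the simulation series manifest and the model input template from
    # their configured locations.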

    manifest = load_dataframe(context.manifest_location, series.manifest_key)
    template = load_text(context.template_location, series.template_key)

    all_job_arns: list[str] = []

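    # Generate inputs and submit jobs separately for each initialization group;
    # groups explicitly mapped to None are skipped.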

    for group in series.groups.keys():
        if series.groups[group] is None:
            continue

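        # Select the conditions and initialization entries assigned to this
        # group (the default "_" group matches all entries).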

        group_key = series.name if group == "_" else f"{series.name}_{group}"
        group_conditions = [
            condition
            for condition in series.conditions
            if group == "_" or condition["group"] == group
        ]
        group_inits = [init for init in series.inits if group == "_" or init["group"] == group]

        # Find missing conditions.
        missing_conditions = find_missing_conditions(
            manifest, series.name, group_conditions, series.seeds, series.extensions
        )

        if len(missing_conditions) == 0:
            continue

        # Convert missing conditions into model input files.
        input_contents: list[str] = []

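        # Only ARCADE and PhysiCell models are handled here; for any other model
        # type, no input files are generated and the group is skipped below.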

        if parameters.model.upper() == "ARCADE":
            condition_sets = group_template_conditions(missing_conditions, parameters.seeds_per_job)
            input_contents = generate_input_contents(template, condition_sets)
        elif parameters.model.upper() == "PHYSICELL":
            input_contents = render_physicell_template(template, missing_conditions, group_key)

        if len(input_contents) == 0:
            continue

        # Copy source init files to target init files.
        valid_seeds = {condition["seed"] for condition in missing_conditions}
        for init in group_inits:
            if len(valid_seeds.intersection(init["seeds"])) == 0:
                continue

            source_key = make_key(init["name"], "inits", f"inits.{parameters.model.upper()}")
            source = make_key(source_key, f"{init['name']}_{init['key']}")

            target_key = make_key(series.name, "{{timestamp}}", "inits")
            targets = [make_key(target_key, f"{group_key}_{seed:04d}") for seed in init["seeds"]]

            for target in targets:
                for ext in init["extensions"]:
                    copy_key(context.working_location, f"{source}.{ext}", f"{target}.{ext}")


        # Create job definition.
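        # The definition points at the model image in the user's ECR registry
        # and passes the dated working location and file set name to the
        # container through environment variables.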

        registry = f"{context.account}.dkr.ecr.{context.region}.amazonaws.com"
        job_key = make_key(context.working_location, series.name, "{{timestamp}}/")
        job_definition = make_batch_job(
            f"{context.user}_{group_key}",
            f"{registry}/{context.user}/{parameters.image}",
            context.vcpus,
            context.memory,
            [
                {"name": "SIMULATION_TYPE", "value": "AWS"},
                {"name": "BATCH_WORKING_URL", "value": job_key},
                {"name": "FILE_SET_NAME", "value": group_key},
            ],
            f"arn:aws:iam::{context.account}:role/BatchJobRole",
        )
        job_definition_arn = register_batch_job(job_definition)

        # Save input files.
        for index, input_content in enumerate(input_contents):
            input_key = make_key(series.name, "{{timestamp}}", "inputs", f"{group_key}_{index}.xml")
            save_text(context.working_location, input_key, input_content)


        # Submit jobs.
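        # The number of jobs matches the number of generated input files for
        # this group.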

        job_arns = submit_batch_job(
            group_key,
            job_definition_arn,
            context.user,
            context.queue,
            len(input_contents),
        )
        all_job_arns = all_job_arns + job_arns

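    # Monitor all submitted jobs: wait for each status check to finish, then
    # optionally terminate the job and save its logs.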

    for job_arn in all_job_arns:
        exitcode = check_batch_job.with_options(
            retries=parameters.retries, retry_delay_seconds=parameters.retry_delay
        ).submit(job_arn, parameters.retries)

        wait_for = [exitcode]

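        # Optionally terminate the job once the status check has finished.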

        if parameters.terminate_jobs:
            terminate_status = terminate_batch_job.submit(job_arn, wait_for=wait_for)
            wait_for = [terminate_status]

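        # Save whatever logs are available once the preceding tasks have finished.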

        if parameters.save_logs:
            logs = get_batch_logs.submit(job_arn, parameters.log_filter, wait_for=wait_for)
            log_key = make_key(series.name, "{{timestamp}}", "logs", f"{job_arn}.log")
            save_text.submit(context.working_location, log_key, logs)
            wait_for = [logs]