Coverage for src/cell_abm_pipeline/flows/plot_cell_shapes.py: 0%

228 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2024-06-05 19:14 +0000

1""" 

2Workflow for plotting cell shapes. 

3 

4Working location structure: 

5 

6.. code-block:: bash 

7 

8 (name) 

9 ├── groups 

10 │ └── groups.CELL_SHAPES 

11 │ ├── (name).feature_correlations.(key).(region).csv 

12 │ ├── (name).feature_distributions.(feature).json 

13 │ ├── (name).mode_correlations.csv 

14 │ ├── (name).population_counts.(tick).csv 

15 │ ├── (name).population_stats.json 

16 │ ├── (name).shape_average.(key).(projection).json 

17 │ ├── (name).shape_errors.json 

18 │ ├── (name).shape_modes.(key).(region).(mode).(projection).json 

19 │ └── (name).variance_explained.csv 

20 └── plots 

21 └── plots.CELL_SHAPES 

22 ├── (name).feature_correlations.(key).(region).png 

23 ├── (name).feature_distributions.(feature).png 

24 ├── (name).mode_correlations.(key).(key).png 

25 ├── (name).population_counts.(tick).png 

26 ├── (name).population_stats.png 

27 ├── (name).shape_average.(key).(projection).svg 

28 ├── (name).shape_errors.png 

29 ├── (name).shape_modes.(key).(region).(mode).(projection).(point).svg 

30 └── (name).variance_explained.png 

31 

32Plots use grouped data from **groups.CELL_SHAPES**. Plots are saved to 

33**plots.CELL_SHAPES**. 

34""" 

35 

36from dataclasses import dataclass, field 

37 

38import numpy as np 

39from io_collection.keys import make_key 

40from io_collection.load import load_dataframe, load_json 

41from io_collection.save import save_figure, save_text 

42from prefect import flow 

43 

44from cell_abm_pipeline.flows.analyze_cell_shapes import PCA_COMPONENTS 

45from cell_abm_pipeline.flows.group_cell_shapes import ( 

46 CORRELATION_PROPERTIES, 

47 DISTRIBUTION_PROPERTIES, 

48 PROJECTIONS, 

49) 

50from cell_abm_pipeline.tasks import ( 

51 build_svg_image, 

52 make_bar_figure, 

53 make_heatmap_figure, 

54 make_histogram_figure, 

55 make_line_figure, 

56) 

57 

# Names of all plots this flow can produce; used as the default plot selection.
PLOTS: list[str] = [
    "feature_correlations",
    "feature_distributions",
    "mode_correlations",
    "population_counts",
    "population_stats",
    "shape_average",
    "shape_errors",
    "shape_modes",
    "variance_explained",
]

# Default stroke color for each subcellular region in shape mode plots.
REGION_COLORS: dict[str, str] = {
    "DEFAULT": "#FF00FF",
    "NUCLEUS": "#00FFFF",
}

# Default line colors for variance explained plots, selected by condition key position.
KEY_COLORS: list[str] = [
    "#7F3C8D", "#11A579", "#3969AC", "#F2B701",
    "#E73F74", "#80BA5A", "#E68310", "#008695",
    "#CF1C90", "#f97b72", "#4b4b8f", "#A5AA99",
]

87 

88 

@dataclass
class ParametersConfigFeatureCorrelations:
    """Parameter configuration for plot cell shapes subflow - feature correlations."""

    # Copy the module-level default so that mutating one instance's list cannot
    # change CORRELATION_PROPERTIES or the defaults of other instances.
    properties: list[str] = field(default_factory=lambda: list(CORRELATION_PROPERTIES))
    """List of shape properties."""

    regions: list[str] = field(default_factory=lambda: ["DEFAULT"])
    """List of subcellular regions."""

    components: int = PCA_COMPONENTS
    """Number of principal components (i.e. shape modes)."""

101 

102 

@dataclass
class ParametersConfigFeatureDistributions:
    """Parameter configuration for plot cell shapes subflow - feature distributions."""

    # Copy the module-level default so that mutating one instance's list cannot
    # change DISTRIBUTION_PROPERTIES or the defaults of other instances.
    properties: list[str] = field(default_factory=lambda: list(DISTRIBUTION_PROPERTIES))
    """List of shape properties."""

    # NOTE(review): the default "(region)" looks like a placeholder rather than a
    # real region name - confirm against the grouped file names.
    regions: list[str] = field(default_factory=lambda: ["(region)"])
    """List of subcellular regions."""

    components: int = PCA_COMPONENTS
    """Number of principal components (i.e. shape modes)."""

115 

116 

@dataclass
class ParametersConfigModeCorrelations:
    """Settings for the mode correlations subflow of the plot cell shapes flow."""

    components: int = PCA_COMPONENTS
    """How many principal components (shape modes) to include on each heatmap axis."""

123 

124 

@dataclass
class ParametersConfigPopulationCounts:
    """Settings for the population counts subflow of the plot cell shapes flow."""

    tick: int = 0
    """Simulation tick whose population counts are plotted."""

131 

132 

@dataclass
class ParametersConfigPopulationStats:
    """Settings for the population stats subflow of the plot cell shapes flow (no options)."""

136 

137 

@dataclass
class ParametersConfigShapeAverage:
    """Parameter configuration for plot cell shapes subflow - shape average."""

    # Copy the module-level default so that mutating one instance's list cannot
    # change PROJECTIONS or the defaults of other instances.
    projections: list[str] = field(default_factory=lambda: list(PROJECTIONS))
    """List of shape projections."""

    box: tuple[int, int] = field(default_factory=lambda: (100, 100))
    """Size of bounding box."""

    scale: float = 1
    """Scaling for image."""

150 

151 

@dataclass
class ParametersConfigShapeErrors:
    """Settings for the shape errors subflow of the plot cell shapes flow (no options)."""

155 

156 

@dataclass
class ParametersConfigShapeModes:
    """Parameter configuration for plot cell shapes subflow - shape modes."""

    # NOTE(review): the default "(region)" looks like a placeholder rather than a
    # real region name - confirm against the grouped file names.
    regions: list[str] = field(default_factory=lambda: ["(region)"])
    """List of subcellular regions."""

    components: int = PCA_COMPONENTS
    """Number of principal components (i.e. shape modes)."""

    # Mutable module-level defaults are copied so that changing one instance's
    # list or dict cannot change PROJECTIONS, REGION_COLORS, or other instances.
    projections: list[str] = field(default_factory=lambda: list(PROJECTIONS))
    """List of shape projections."""

    point: float = 0
    """Selected shape mode map point."""

    box: tuple[int, int] = field(default_factory=lambda: (100, 100))
    """Size of bounding box."""

    scale: float = 1
    """Scaling for image."""

    colors: dict[str, str] = field(default_factory=lambda: dict(REGION_COLORS))
    """Colors for each region."""

181 

182 

@dataclass
class ParametersConfigVarianceExplained:
    """Parameter configuration for plot cell shapes subflow - variance explained."""

    components: int = PCA_COMPONENTS
    """Number of principal components (i.e. shape modes)."""

    # Copy the module-level default so that mutating one instance's list cannot
    # change KEY_COLORS or the defaults of other instances.
    colors: list[str] = field(default_factory=lambda: list(KEY_COLORS))
    """Colors for each key."""

192 

193 

@dataclass
class ParametersConfig:
    """
    Parameter configuration for plot cell shapes flow.

    Nested subflow configurations are created through ``default_factory`` so that each
    instance gets its own objects. Plain dataclass instances are unhashable, and using
    one directly as a field default raises ``ValueError`` at class creation on
    Python 3.11 and later.
    """

    # Copy the module-level default so instances never share (or mutate) PLOTS.
    plots: list[str] = field(default_factory=lambda: list(PLOTS))
    """List of cell shape plots."""

    feature_correlations: ParametersConfigFeatureCorrelations = field(
        default_factory=ParametersConfigFeatureCorrelations
    )
    """Parameters for plot feature correlations subflow."""

    feature_distributions: ParametersConfigFeatureDistributions = field(
        default_factory=ParametersConfigFeatureDistributions
    )
    """Parameters for plot feature distributions subflow."""

    mode_correlations: ParametersConfigModeCorrelations = field(
        default_factory=ParametersConfigModeCorrelations
    )
    """Parameters for plot mode correlations subflow."""

    population_counts: ParametersConfigPopulationCounts = field(
        default_factory=ParametersConfigPopulationCounts
    )
    """Parameters for plot population counts subflow."""

    population_stats: ParametersConfigPopulationStats = field(
        default_factory=ParametersConfigPopulationStats
    )
    """Parameters for plot population stats subflow."""

    shape_average: ParametersConfigShapeAverage = field(
        default_factory=ParametersConfigShapeAverage
    )
    """Parameters for plot shape average subflow."""

    shape_errors: ParametersConfigShapeErrors = field(
        default_factory=ParametersConfigShapeErrors
    )
    """Parameters for plot shape errors subflow."""

    shape_modes: ParametersConfigShapeModes = field(default_factory=ParametersConfigShapeModes)
    """Parameters for plot shape modes subflow."""

    variance_explained: ParametersConfigVarianceExplained = field(
        default_factory=ParametersConfigVarianceExplained
    )
    """Parameters for plot variance explained subflow."""

231 

232 

@dataclass
class ContextConfig:
    """Context settings for the plot cell shapes flow."""

    working_location: str
    """Where input and output files live (local path or S3 bucket)."""

239 

240 

@dataclass
class SeriesConfig:
    """Series settings for the plot cell shapes flow."""

    name: str
    """Simulation series name."""

    conditions: list[dict]
    """Series condition dictionaries (each must include a unique condition "key")."""

250 

251 

@flow(name="plot-cell-shapes")
def run_flow(context: ContextConfig, series: SeriesConfig, parameters: ParametersConfig) -> None:
    """
    Main plot cell shapes flow.

    Runs each of the following subflows, in this order, when its plot name is listed in
    the plots parameter. Each subflow receives the parameters attribute of the same name:

    - :py:func:`run_flow_plot_feature_correlations`
    - :py:func:`run_flow_plot_feature_distributions`
    - :py:func:`run_flow_plot_mode_correlations`
    - :py:func:`run_flow_plot_population_counts`
    - :py:func:`run_flow_plot_population_stats`
    - :py:func:`run_flow_plot_shape_average`
    - :py:func:`run_flow_plot_shape_errors`
    - :py:func:`run_flow_plot_shape_modes`
    - :py:func:`run_flow_plot_variance_explained`
    """

    # Plot name (also the parameters attribute name) paired with its subflow.
    subflows = (
        ("feature_correlations", run_flow_plot_feature_correlations),
        ("feature_distributions", run_flow_plot_feature_distributions),
        ("mode_correlations", run_flow_plot_mode_correlations),
        ("population_counts", run_flow_plot_population_counts),
        ("population_stats", run_flow_plot_population_stats),
        ("shape_average", run_flow_plot_shape_average),
        ("shape_errors", run_flow_plot_shape_errors),
        ("shape_modes", run_flow_plot_shape_modes),
        ("variance_explained", run_flow_plot_variance_explained),
    )

    for plot, subflow in subflows:
        if plot in parameters.plots:
            subflow(context, series, getattr(parameters, plot))

296 

297 

@flow(name="plot-cell-shapes_plot-feature-correlations")
def run_flow_plot_feature_correlations(
    context: ContextConfig, series: SeriesConfig, parameters: ParametersConfigFeatureCorrelations
) -> None:
    """
    Plot cell shapes subflow for feature correlations.

    For each condition key and region, loads the grouped feature correlations table and
    saves a heatmap of absolute correlations between shape properties and shape modes.
    """

    keys = [condition["key"] for condition in series.conditions]
    group_key = make_key(series.name, "groups", "groups.CELL_SHAPES")
    plot_key = make_key(series.name, "plots", "plots.CELL_SHAPES")

    properties = [prop.upper() for prop in parameters.properties]
    modes = [f"PC{component + 1}" for component in range(parameters.components)]

    for key in keys:
        for region in parameters.regions:
            file_stem = f"{series.name}.feature_correlations.{key}.{region}"

            group = load_dataframe(
                context.working_location, make_key(group_key, f"{file_stem}.csv")
            )

            # Index by (property, mode) so each heatmap cell is a direct lookup.
            indexed = group.set_index(["property", "mode"]).sort_index()

            heatmap_rows = []
            for prop in properties:
                heatmap_rows.append(
                    [abs(indexed.loc[prop, mode]["correlation"]) for mode in modes]
                )

            figure = make_heatmap_figure(properties, modes, heatmap_rows)
            save_figure(
                context.working_location, make_key(plot_key, f"{file_stem}.png"), figure
            )

329 

330 

@flow(name="plot-cell-shapes_plot-feature-distributions")
def run_flow_plot_feature_distributions(
    context: ContextConfig, series: SeriesConfig, parameters: ParametersConfigFeatureDistributions
) -> None:
    """
    Plot cell shapes subflow for feature distributions.

    Features are every property and region combination followed by every shape mode. For
    each feature, loads the grouped distribution and saves a histogram figure across all
    condition keys.
    """

    keys = [condition["key"] for condition in series.conditions]
    group_key = make_key(series.name, "groups", "groups.CELL_SHAPES")
    plot_key = make_key(series.name, "plots", "plots.CELL_SHAPES")

    property_features = [
        f"{prop}.{region}" for prop in parameters.properties for region in parameters.regions
    ]
    mode_features = [f"PC{component + 1}" for component in range(parameters.components)]

    for feature in property_features + mode_features:
        file_stem = f"{series.name}.feature_distributions.{feature.upper()}"

        group = load_json(context.working_location, make_key(group_key, f"{file_stem}.json"))

        assert isinstance(group, dict)

        figure = make_histogram_figure(keys, group)
        save_figure(context.working_location, make_key(plot_key, f"{file_stem}.png"), figure)

360 

361 

@flow(name="plot-cell-shapes_plot-mode-correlations")
def run_flow_plot_mode_correlations(
    context: ContextConfig, series: SeriesConfig, parameters: ParametersConfigModeCorrelations
) -> None:
    """
    Plot cell shapes subflow for mode correlations.

    Loads the grouped mode correlations table once, then saves one heatmap of absolute
    correlations between shape modes for every ordered pair of distinct keys. Keys are
    "reference" followed by the condition keys.
    """

    keys = ["reference"] + [condition["key"] for condition in series.conditions]
    group_key = make_key(series.name, "groups", "groups.CELL_SHAPES")
    plot_key = make_key(series.name, "plots", "plots.CELL_SHAPES")

    modes = [f"PC{component + 1}" for component in range(parameters.components)]

    group = load_dataframe(
        context.working_location,
        make_key(group_key, f"{series.name}.mode_correlations.csv"),
    )

    # Ordered (source, target) pairs, skipping comparisons of a key with itself.
    key_pairs = [
        (source_key, target_key)
        for source_key in keys
        for target_key in keys
        if source_key != target_key
    ]

    for source_key, target_key in key_pairs:
        is_pair = (group["source_key"] == source_key) & (group["target_key"] == target_key)
        indexed = group[is_pair].set_index(["source_mode", "target_mode"]).sort_index()

        heatmap_rows = []
        for source_mode in modes:
            heatmap_rows.append(
                [abs(indexed.loc[source_mode, target_mode]["correlation"]) for target_mode in modes]
            )

        file_name = f"{series.name}.mode_correlations.{source_key}.{target_key}.png"
        save_figure(
            context.working_location,
            make_key(plot_key, file_name),
            make_heatmap_figure(modes, modes, heatmap_rows),
        )

404 

405 

@flow(name="plot-cell-shapes_plot-population-counts")
def run_flow_plot_population_counts(
    context: ContextConfig, series: SeriesConfig, parameters: ParametersConfigPopulationCounts
) -> None:
    """
    Plot cell shapes subflow for population counts.

    Loads the grouped population counts for the selected tick and saves a bar figure of
    the mean and sample standard deviation of counts for each condition key.
    """

    keys = [condition["key"] for condition in series.conditions]
    group_key = make_key(series.name, "groups", "groups.CELL_SHAPES")
    plot_key = make_key(series.name, "plots", "plots.CELL_SHAPES")

    # Input and output share the same name apart from the extension.
    file_stem = f"{series.name}.population_counts.{parameters.tick:06d}"

    group = load_dataframe(context.working_location, make_key(group_key, f"{file_stem}.csv"))

    key_group = {}
    for key in keys:
        counts = group[group["key"] == key]["count"]
        key_group[key] = {"COUNT": {"mean": counts.mean(), "std": counts.std(ddof=1)}}

    figure = make_bar_figure(keys, key_group)
    save_figure(context.working_location, make_key(plot_key, f"{file_stem}.png"), figure)

436 

437 

@flow(name="plot-cell-shapes_plot-population-stats")
def run_flow_plot_population_stats(
    context: ContextConfig, series: SeriesConfig, parameters: ParametersConfigPopulationStats
) -> None:
    """
    Plot cell shapes subflow for population stats.

    Loads the grouped population stats and saves a bar figure covering the entries for
    the series condition keys. The parameters object is not read.
    """

    keys = [condition["key"] for condition in series.conditions]
    group_key = make_key(series.name, "groups", "groups.CELL_SHAPES")
    plot_key = make_key(series.name, "plots", "plots.CELL_SHAPES")

    file_stem = f"{series.name}.population_stats"

    group = load_json(context.working_location, make_key(group_key, f"{file_stem}.json"))

    # Keep only the series keys, in series order.
    key_group = {}
    for key in keys:
        key_group[key] = group[key]

    figure = make_bar_figure(keys, key_group)
    save_figure(context.working_location, make_key(plot_key, f"{file_stem}.png"), figure)

460 

461 

@flow(name="plot-cell-shapes_plot-shape-average")
def run_flow_plot_shape_average(
    context: ContextConfig, series: SeriesConfig, parameters: ParametersConfigShapeAverage
) -> None:
    """
    Plot cell shapes subflow for shape average.

    For each condition key and projection, loads the grouped shape average and saves an
    SVG with the original slice (solid black), the reconstructed slice (dashed red), and
    the original extents (thin gray).
    """

    keys = [condition["key"] for condition in series.conditions]
    group_key = make_key(series.name, "groups", "groups.CELL_SHAPES")
    plot_key = make_key(series.name, "plots", "plots.CELL_SHAPES")

    for key in keys:
        for projection in parameters.projections:
            file_stem = f"{series.name}.shape_average.{key}.{projection.upper()}"

            group = load_json(
                context.working_location, make_key(group_key, f"{file_stem}.json")
            )

            assert isinstance(group, dict)

            elements = []

            for item in group["original_slice"]:
                elements.append({"points": item, "stroke": "#000", "stroke-width": 0.2})

            for item in group["reconstructed_slice"]:
                elements.append(
                    {
                        "points": item,
                        "stroke": "#f00",
                        "stroke-width": 0.2,
                        "stroke-dasharray": "0.2,0.2",
                    }
                )

            for extent in group["original_extent"].values():
                for item in extent:
                    elements.append({"points": item, "stroke": "#999", "stroke-width": 0.05})

            # Only the "top" projection is drawn unrotated.
            if projection == "top":
                rotate = 0
            else:
                rotate = 90

            width, height = parameters.box
            image = build_svg_image(elements, width, height, rotate, parameters.scale)
            save_text(context.working_location, make_key(plot_key, f"{file_stem}.svg"), image)

515 

516 

@flow(name="plot-cell-shapes_plot-shape-errors")
def run_flow_plot_shape_errors(
    context: ContextConfig, series: SeriesConfig, parameters: ParametersConfigShapeErrors
) -> None:
    """
    Plot cell shapes subflow for shape errors.

    Loads the grouped shape errors and saves a bar figure covering the entries for the
    series condition keys. The parameters object is not read.
    """

    keys = [condition["key"] for condition in series.conditions]
    group_key = make_key(series.name, "groups", "groups.CELL_SHAPES")
    plot_key = make_key(series.name, "plots", "plots.CELL_SHAPES")

    file_stem = f"{series.name}.shape_errors"

    group = load_json(context.working_location, make_key(group_key, f"{file_stem}.json"))

    # Keep only the series keys, in series order.
    key_group = {}
    for key in keys:
        key_group[key] = group[key]

    figure = make_bar_figure(keys, key_group)
    save_figure(context.working_location, make_key(plot_key, f"{file_stem}.png"), figure)

539 

540 

@flow(name="plot-cell-shapes_plot-shape-modes")
def run_flow_plot_shape_modes(
    context: ContextConfig, series: SeriesConfig, parameters: ParametersConfigShapeModes
) -> None:
    """
    Plot cell shapes subflow for shape modes.

    For each condition key, shape mode, and projection, overlays the outlines of all
    requested regions at the selected map point and saves them as a single SVG. The
    output file is named after the last region in the regions parameter. Nothing is
    plotted when no regions are given.
    """

    # Without regions there is nothing to draw and no region to name the output after.
    if not parameters.regions:
        return

    group_key = make_key(series.name, "groups", "groups.CELL_SHAPES")
    plot_key = make_key(series.name, "plots", "plots.CELL_SHAPES")
    keys = [condition["key"] for condition in series.conditions]

    # The point label depends only on the parameters, so it is built once up front.
    if parameters.point > 0:
        point_key = "P" + f"{round(parameters.point*100):03d}"
    elif parameters.point < 0:
        point_key = "N" + f"{round(-parameters.point*100):03d}"
    else:
        point_key = "ZERO"

    # All regions are drawn into one image, which is saved under the last region's name.
    output_region = parameters.regions[-1]

    for key in keys:
        for component in range(parameters.components):
            mode = f"PC{component + 1}"

            for projection in parameters.projections:
                projection_key = projection.upper()
                rotate = 0 if projection == "top" else 90
                elements: list[dict] = []

                for region in parameters.regions:
                    full_key = f"{key}.{region}.{mode}.{projection_key}"

                    group = load_json(
                        context.working_location,
                        make_key(group_key, f"{series.name}.shape_modes.{full_key}.json"),
                    )

                    assert isinstance(group, list)

                    # NOTE(review): points are matched by exact equality, so the selected
                    # point must equal a stored value exactly - confirm stored precision.
                    elements.extend(
                        {
                            "points": item["projection"][0],
                            "stroke": parameters.colors[region],
                            "stroke-width": 2,
                        }
                        for item in group
                        if item["point"] == parameters.point
                    )

                output_key = f"{key}.{output_region}.{mode}.{projection_key}"

                save_text(
                    context.working_location,
                    make_key(plot_key, f"{series.name}.shape_modes.{output_key}.{point_key}.svg"),
                    build_svg_image(elements, *parameters.box, rotate, parameters.scale),
                )

591 

592 

@flow(name="plot-cell-shapes_plot-variance-explained")
def run_flow_plot_variance_explained(
    context: ContextConfig, series: SeriesConfig, parameters: ParametersConfigVarianceExplained
) -> None:
    """
    Plot cell shapes subflow for variance explained.

    Loads the grouped variance explained table and saves a line figure of cumulative
    variance against component number, with one line per condition key. At most the
    configured number of components is plotted. Colors are chosen by key position and
    repeat if there are more keys than colors.
    """

    group_key = make_key(series.name, "groups", "groups.CELL_SHAPES")
    plot_key = make_key(series.name, "plots", "plots.CELL_SHAPES")
    keys = [condition["key"] for condition in series.conditions]

    group = load_dataframe(
        context.working_location,
        make_key(group_key, f"{series.name}.variance_explained.csv"),
    )

    group_flat = []

    for index, key in enumerate(keys):
        # NOTE(review): rows are ordered by the "mode" column; if modes are stored as
        # strings such as "PC10", this ordering is lexicographic - confirm the dtype.
        variance = group[group["key"] == key].sort_values("mode")["variance"].values

        # Limit to the configured number of components rather than a fixed count of 8.
        cumulative = np.cumsum(variance)[: parameters.components]

        group_flat.append(
            {
                # Derive x from y so both always have the same length.
                "x": list(range(1, len(cumulative) + 1)),
                "y": cumulative,
                # Wrap around the palette instead of failing when keys outnumber colors.
                "color": parameters.colors[index % len(parameters.colors)],
            }
        )

    save_figure(
        context.working_location,
        make_key(plot_key, f"{series.name}.variance_explained.png"),
        make_line_figure(group_flat),
    )