Coverage for src/cell_abm_pipeline/flows/plot_cell_shapes.py: 0%
228 statements
« prev ^ index » next coverage.py v7.1.0, created at 2024-06-05 19:14 +0000
« prev ^ index » next coverage.py v7.1.0, created at 2024-06-05 19:14 +0000
"""
Workflow for plotting cell shapes.

Working location structure:

.. code-block:: bash

    (name)
    ├── groups
    │   └── groups.CELL_SHAPES
    │       ├── (name).feature_correlations.(key).(region).csv
    │       ├── (name).feature_distributions.(feature).json
    │       ├── (name).mode_correlations.csv
    │       ├── (name).population_counts.(tick).csv
    │       ├── (name).population_stats.json
    │       ├── (name).shape_average.(key).(projection).json
    │       ├── (name).shape_errors.json
    │       ├── (name).shape_modes.(key).(region).(mode).(projection).json
    │       └── (name).variance_explained.csv
    └── plots
        └── plots.CELL_SHAPES
            ├── (name).feature_correlations.(key).(region).png
            ├── (name).feature_distributions.(feature).png
            ├── (name).mode_correlations.(key).(key).png
            ├── (name).population_counts.(tick).png
            ├── (name).population_stats.png
            ├── (name).shape_average.(key).(projection).svg
            ├── (name).shape_errors.png
            ├── (name).shape_modes.(key).(region).(mode).(projection).(point).svg
            └── (name).variance_explained.png

Plots use grouped data from **groups.CELL_SHAPES**. Plots are saved to
**plots.CELL_SHAPES**.
"""
36from dataclasses import dataclass, field
38import numpy as np
39from io_collection.keys import make_key
40from io_collection.load import load_dataframe, load_json
41from io_collection.save import save_figure, save_text
42from prefect import flow
44from cell_abm_pipeline.flows.analyze_cell_shapes import PCA_COMPONENTS
45from cell_abm_pipeline.flows.group_cell_shapes import (
46 CORRELATION_PROPERTIES,
47 DISTRIBUTION_PROPERTIES,
48 PROJECTIONS,
49)
50from cell_abm_pipeline.tasks import (
51 build_svg_image,
52 make_bar_figure,
53 make_heatmap_figure,
54 make_histogram_figure,
55 make_line_figure,
56)
# Names of every plot the main flow knows how to dispatch.
PLOTS: list[str] = [
    "feature_correlations",
    "feature_distributions",
    "mode_correlations",
    "population_counts",
    "population_stats",
    "shape_average",
    "shape_errors",
    "shape_modes",
    "variance_explained",
]

# Hex color assigned to each subcellular region name.
REGION_COLORS: dict[str, str] = {
    "DEFAULT": "#FF00FF",
    "NUCLEUS": "#00FFFF",
}

# Palette of hex colors, looked up by the position of a condition key.
KEY_COLORS: list[str] = [
    "#7F3C8D",
    "#11A579",
    "#3969AC",
    "#F2B701",
    "#E73F74",
    "#80BA5A",
    "#E68310",
    "#008695",
    "#CF1C90",
    "#f97b72",
    "#4b4b8f",
    "#A5AA99",
]
@dataclass
class ParametersConfigFeatureCorrelations:
    """Parameter configuration for plot cell shapes subflow - feature correlations."""

    # Copy the module-level constant so that mutating one instance's list
    # cannot alter the constant itself or any other instance's default.
    properties: list[str] = field(default_factory=lambda: list(CORRELATION_PROPERTIES))
    """List of shape properties."""

    regions: list[str] = field(default_factory=lambda: ["DEFAULT"])
    """List of subcellular regions."""

    components: int = PCA_COMPONENTS
    """Number of principal components (i.e. shape modes)."""
@dataclass
class ParametersConfigFeatureDistributions:
    """Parameter configuration for plot cell shapes subflow - feature distributions."""

    # Copy the module-level constant so that mutating one instance's list
    # cannot alter the constant itself or any other instance's default.
    properties: list[str] = field(default_factory=lambda: list(DISTRIBUTION_PROPERTIES))
    """List of shape properties."""

    # "DEFAULT" matches the default region of the feature correlations
    # configuration; the previous default was the literal documentation
    # placeholder "(region)", which is not an actual region name.
    regions: list[str] = field(default_factory=lambda: ["DEFAULT"])
    """List of subcellular regions."""

    components: int = PCA_COMPONENTS
    """Number of principal components (i.e. shape modes)."""
@dataclass
class ParametersConfigModeCorrelations:
    """Settings for the mode correlations subflow of the plot cell shapes flow."""

    components: int = PCA_COMPONENTS
    """How many principal components (shape modes) to include."""
@dataclass
class ParametersConfigPopulationCounts:
    """Settings for the population counts subflow of the plot cell shapes flow."""

    tick: int = 0
    """Simulation tick whose population counts are plotted."""
@dataclass
class ParametersConfigPopulationStats:
    """Settings for the population stats subflow of the plot cell shapes flow (no fields)."""
@dataclass
class ParametersConfigShapeAverage:
    """Parameter configuration for plot cell shapes subflow - shape average."""

    # Copy the module-level constant so that mutating one instance's list
    # cannot alter the constant itself or any other instance's default.
    projections: list[str] = field(default_factory=lambda: list(PROJECTIONS))
    """List of shape projections."""

    # Tuples are immutable, so a plain default is safe and needs no factory.
    box: tuple[int, int] = (100, 100)
    """Size of bounding box."""

    scale: float = 1
    """Scaling for image."""
@dataclass
class ParametersConfigShapeErrors:
    """Settings for the shape errors subflow of the plot cell shapes flow (no fields)."""
@dataclass
class ParametersConfigShapeModes:
    """Parameter configuration for plot cell shapes subflow - shape modes."""

    # The default region must be a key of the default ``colors`` mapping,
    # because the shape modes subflow looks up ``colors[region]``. The previous
    # default was the literal documentation placeholder "(region)", which is
    # not a key of REGION_COLORS and raised KeyError when defaults were used.
    regions: list[str] = field(default_factory=lambda: ["DEFAULT"])
    """List of subcellular regions."""

    components: int = PCA_COMPONENTS
    """Number of principal components (i.e. shape modes)."""

    # Copy the module-level constant so that mutating one instance's list
    # cannot alter the constant itself or any other instance's default.
    projections: list[str] = field(default_factory=lambda: list(PROJECTIONS))
    """List of shape projections."""

    point: float = 0
    """Selected shape mode map point."""

    # Tuples are immutable, so a plain default is safe and needs no factory.
    box: tuple[int, int] = (100, 100)
    """Size of bounding box."""

    scale: float = 1
    """Scaling for image."""

    # Copied for the same reason as ``projections``.
    colors: dict[str, str] = field(default_factory=lambda: dict(REGION_COLORS))
    """Colors for each region."""
@dataclass
class ParametersConfigVarianceExplained:
    """Parameter configuration for plot cell shapes subflow - variance explained."""

    components: int = PCA_COMPONENTS
    """Number of principal components (i.e. shape modes)."""

    # Copy the module-level palette so that mutating one instance's list
    # cannot alter the constant itself or any other instance's default.
    colors: list[str] = field(default_factory=lambda: list(KEY_COLORS))
    """Colors for each key."""
@dataclass
class ParametersConfig:
    """
    Parameter configuration for plot cell shapes flow.

    Every subflow configuration is created through ``default_factory``. A
    dataclass instance used directly as a field default would be shared by all
    ``ParametersConfig`` objects and, because such instances are unhashable,
    is rejected by the ``dataclass`` decorator on Python 3.11 and later.
    """

    # Copy the module-level constant so that mutating one instance's list
    # cannot alter the constant itself or any other instance's default.
    plots: list[str] = field(default_factory=lambda: list(PLOTS))
    """List of cell shape plots."""

    feature_correlations: ParametersConfigFeatureCorrelations = field(
        default_factory=ParametersConfigFeatureCorrelations
    )
    """Parameters for plot feature correlations subflow."""

    feature_distributions: ParametersConfigFeatureDistributions = field(
        default_factory=ParametersConfigFeatureDistributions
    )
    """Parameters for plot feature distributions subflow."""

    mode_correlations: ParametersConfigModeCorrelations = field(
        default_factory=ParametersConfigModeCorrelations
    )
    """Parameters for plot mode correlations subflow."""

    population_counts: ParametersConfigPopulationCounts = field(
        default_factory=ParametersConfigPopulationCounts
    )
    """Parameters for plot population counts subflow."""

    population_stats: ParametersConfigPopulationStats = field(
        default_factory=ParametersConfigPopulationStats
    )
    """Parameters for plot population stats subflow."""

    shape_average: ParametersConfigShapeAverage = field(
        default_factory=ParametersConfigShapeAverage
    )
    """Parameters for plot shape average subflow."""

    shape_errors: ParametersConfigShapeErrors = field(default_factory=ParametersConfigShapeErrors)
    """Parameters for plot shape errors subflow."""

    shape_modes: ParametersConfigShapeModes = field(default_factory=ParametersConfigShapeModes)
    """Parameters for plot shape modes subflow."""

    variance_explained: ParametersConfigVarianceExplained = field(
        default_factory=ParametersConfigVarianceExplained
    )
    """Parameters for plot variance explained subflow."""
@dataclass
class ContextConfig:
    """Where the plot cell shapes flow reads its inputs and writes its outputs."""

    working_location: str
    """Local path or S3 bucket holding the input and output files."""
@dataclass
class SeriesConfig:
    """Identifies the simulation series handled by the plot cell shapes flow."""

    name: str
    """Simulation series name."""

    conditions: list[dict]
    """Condition dictionaries for the series; each must include a unique condition "key"."""
@flow(name="plot-cell-shapes")
def run_flow(context: ContextConfig, series: SeriesConfig, parameters: ParametersConfig) -> None:
    """
    Main plot cell shapes flow.

    Runs each of the following subflows, in this order, when its plot name is
    listed in ``parameters.plots``:

    - :py:func:`run_flow_plot_feature_correlations`
    - :py:func:`run_flow_plot_feature_distributions`
    - :py:func:`run_flow_plot_mode_correlations`
    - :py:func:`run_flow_plot_population_counts`
    - :py:func:`run_flow_plot_population_stats`
    - :py:func:`run_flow_plot_shape_average`
    - :py:func:`run_flow_plot_shape_errors`
    - :py:func:`run_flow_plot_shape_modes`
    - :py:func:`run_flow_plot_variance_explained`
    """

    # Plot name -> subflow, in execution order. The plot name is also the name
    # of the attribute on ``parameters`` that holds the subflow's settings.
    subflows = {
        "feature_correlations": run_flow_plot_feature_correlations,
        "feature_distributions": run_flow_plot_feature_distributions,
        "mode_correlations": run_flow_plot_mode_correlations,
        "population_counts": run_flow_plot_population_counts,
        "population_stats": run_flow_plot_population_stats,
        "shape_average": run_flow_plot_shape_average,
        "shape_errors": run_flow_plot_shape_errors,
        "shape_modes": run_flow_plot_shape_modes,
        "variance_explained": run_flow_plot_variance_explained,
    }

    for plot, subflow in subflows.items():
        if plot in parameters.plots:
            subflow(context, series, getattr(parameters, plot))
@flow(name="plot-cell-shapes_plot-feature-correlations")
def run_flow_plot_feature_correlations(
    context: ContextConfig, series: SeriesConfig, parameters: ParametersConfigFeatureCorrelations
) -> None:
    """
    Plot cell shapes subflow for feature correlations.

    For every condition key and region, loads the grouped feature correlations
    table and saves a heatmap of the absolute correlation between each
    (upper-cased) property and each shape mode.
    """

    group_key = make_key(series.name, "groups", "groups.CELL_SHAPES")
    plot_key = make_key(series.name, "plots", "plots.CELL_SHAPES")
    keys = [condition["key"] for condition in series.conditions]

    mode_names = [f"PC{component + 1}" for component in range(parameters.components)]
    property_names = [prop.upper() for prop in parameters.properties]

    for key in keys:
        for region in parameters.regions:
            table_key = make_key(group_key, f"{series.name}.feature_correlations.{key}.{region}.csv")
            table = load_dataframe(context.working_location, table_key)

            # Index by (property, mode) so each heatmap cell is a direct lookup.
            indexed = table.set_index(["property", "mode"]).sort_index()

            heatmap_rows = []
            for prop in property_names:
                row = [abs(indexed.loc[prop, mode]["correlation"]) for mode in mode_names]
                heatmap_rows.append(row)

            figure_key = make_key(plot_key, f"{series.name}.feature_correlations.{key}.{region}.png")
            figure = make_heatmap_figure(property_names, mode_names, heatmap_rows)
            save_figure(context.working_location, figure_key, figure)
@flow(name="plot-cell-shapes_plot-feature-distributions")
def run_flow_plot_feature_distributions(
    context: ContextConfig, series: SeriesConfig, parameters: ParametersConfigFeatureDistributions
) -> None:
    """
    Plot cell shapes subflow for feature distributions.

    Features are every ``property.region`` combination followed by the shape
    modes ``PC1`` .. ``PC<components>``. For each (upper-cased) feature, the
    grouped distribution is loaded and saved as a histogram figure.
    """

    group_key = make_key(series.name, "groups", "groups.CELL_SHAPES")
    plot_key = make_key(series.name, "plots", "plots.CELL_SHAPES")
    keys = [condition["key"] for condition in series.conditions]

    property_features = [
        f"{prop}.{region}" for prop in parameters.properties for region in parameters.regions
    ]
    mode_features = [f"PC{component + 1}" for component in range(parameters.components)]

    for feature in property_features + mode_features:
        feature_key = feature.upper()

        distribution_key = make_key(
            group_key, f"{series.name}.feature_distributions.{feature_key}.json"
        )
        group = load_json(context.working_location, distribution_key)

        assert isinstance(group, dict)

        figure_key = make_key(plot_key, f"{series.name}.feature_distributions.{feature_key}.png")
        figure = make_histogram_figure(keys, group)
        save_figure(context.working_location, figure_key, figure)
@flow(name="plot-cell-shapes_plot-mode-correlations")
def run_flow_plot_mode_correlations(
    context: ContextConfig, series: SeriesConfig, parameters: ParametersConfigModeCorrelations
) -> None:
    """
    Plot cell shapes subflow for mode correlations.

    Loads the grouped mode correlations table once, then, for every ordered
    pair of distinct keys ("reference" plus the condition keys), saves a
    heatmap of the absolute correlation between source and target shape modes.
    """

    group_key = make_key(series.name, "groups", "groups.CELL_SHAPES")
    plot_key = make_key(series.name, "plots", "plots.CELL_SHAPES")
    keys = ["reference"] + [condition["key"] for condition in series.conditions]

    modes = [f"PC{component + 1}" for component in range(parameters.components)]

    group = load_dataframe(
        context.working_location,
        make_key(group_key, f"{series.name}.mode_correlations.csv"),
    )

    # A key is never compared against itself.
    key_pairs = [(source, target) for source in keys for target in keys if source != target]

    for source_key, target_key in key_pairs:
        is_pair = (group["source_key"] == source_key) & (group["target_key"] == target_key)
        pair_sorted = group[is_pair].set_index(["source_mode", "target_mode"]).sort_index()

        heatmap_rows = []
        for source_mode in modes:
            row = [
                abs(pair_sorted.loc[source_mode, target_mode]["correlation"])
                for target_mode in modes
            ]
            heatmap_rows.append(row)

        figure_key = make_key(
            plot_key, f"{series.name}.mode_correlations.{source_key}.{target_key}.png"
        )
        figure = make_heatmap_figure(modes, modes, heatmap_rows)
        save_figure(context.working_location, figure_key, figure)
@flow(name="plot-cell-shapes_plot-population-counts")
def run_flow_plot_population_counts(
    context: ContextConfig, series: SeriesConfig, parameters: ParametersConfigPopulationCounts
) -> None:
    """
    Plot cell shapes subflow for population counts.

    Loads the grouped population counts for the selected tick and saves a bar
    figure of the mean and standard deviation (``ddof=1``) of the counts for
    each condition key.
    """

    group_key = make_key(series.name, "groups", "groups.CELL_SHAPES")
    plot_key = make_key(series.name, "plots", "plots.CELL_SHAPES")
    keys = [condition["key"] for condition in series.conditions]

    group = load_dataframe(
        context.working_location,
        make_key(group_key, f"{series.name}.population_counts.{parameters.tick:06d}.csv"),
    )

    key_group = {}
    for key in keys:
        # Select the counts for this key once and reuse them for both statistics.
        counts = group[group["key"] == key]["count"]
        key_group[key] = {"COUNT": {"mean": counts.mean(), "std": counts.std(ddof=1)}}

    figure_key = make_key(plot_key, f"{series.name}.population_counts.{parameters.tick:06d}.png")
    figure = make_bar_figure(keys, key_group)
    save_figure(context.working_location, figure_key, figure)
@flow(name="plot-cell-shapes_plot-population-stats")
def run_flow_plot_population_stats(
    context: ContextConfig, series: SeriesConfig, parameters: ParametersConfigPopulationStats
) -> None:
    """
    Plot cell shapes subflow for population stats.

    Loads the grouped population stats, keeps the entries for the series
    condition keys, and saves them as a bar figure.
    """

    group_key = make_key(series.name, "groups", "groups.CELL_SHAPES")
    plot_key = make_key(series.name, "plots", "plots.CELL_SHAPES")
    keys = [condition["key"] for condition in series.conditions]

    stats_key = make_key(group_key, f"{series.name}.population_stats.json")
    group = load_json(context.working_location, stats_key)

    # Restrict (and order) the loaded stats to the condition keys of the series.
    key_group = {}
    for key in keys:
        key_group[key] = group[key]

    figure_key = make_key(plot_key, f"{series.name}.population_stats.png")
    figure = make_bar_figure(keys, key_group)
    save_figure(context.working_location, figure_key, figure)
@flow(name="plot-cell-shapes_plot-shape-average")
def run_flow_plot_shape_average(
    context: ContextConfig, series: SeriesConfig, parameters: ParametersConfigShapeAverage
) -> None:
    """
    Plot cell shapes subflow for shape average.

    For every condition key and projection, loads the grouped average shape
    and saves an SVG that layers, in order, the original slice outlines, the
    reconstructed slice outlines (dashed), and the original extent outlines.
    """

    group_key = make_key(series.name, "groups", "groups.CELL_SHAPES")
    plot_key = make_key(series.name, "plots", "plots.CELL_SHAPES")
    keys = [condition["key"] for condition in series.conditions]

    for key in keys:
        for projection in parameters.projections:
            group = load_json(
                context.working_location,
                make_key(group_key, f"{series.name}.shape_average.{key}.{projection.upper()}.json"),
            )

            assert isinstance(group, dict)

            elements = []

            for item in group["original_slice"]:
                elements.append({"points": item, "stroke": "#000", "stroke-width": 0.2})

            for item in group["reconstructed_slice"]:
                elements.append(
                    {
                        "points": item,
                        "stroke": "#f00",
                        "stroke-width": 0.2,
                        "stroke-dasharray": "0.2,0.2",
                    }
                )

            for extent in group["original_extent"].values():
                for item in extent:
                    elements.append({"points": item, "stroke": "#999", "stroke-width": 0.05})

            # Only the "top" projection is drawn unrotated.
            if projection == "top":
                rotate = 0
            else:
                rotate = 90

            save_text(
                context.working_location,
                make_key(plot_key, f"{series.name}.shape_average.{key}.{projection.upper()}.svg"),
                build_svg_image(elements, *parameters.box, rotate, parameters.scale),
            )
@flow(name="plot-cell-shapes_plot-shape-errors")
def run_flow_plot_shape_errors(
    context: ContextConfig, series: SeriesConfig, parameters: ParametersConfigShapeErrors
) -> None:
    """
    Plot cell shapes subflow for shape errors.

    Loads the grouped shape errors, keeps the entries for the series condition
    keys, and saves them as a bar figure.
    """

    group_key = make_key(series.name, "groups", "groups.CELL_SHAPES")
    plot_key = make_key(series.name, "plots", "plots.CELL_SHAPES")
    keys = [condition["key"] for condition in series.conditions]

    errors_key = make_key(group_key, f"{series.name}.shape_errors.json")
    group = load_json(context.working_location, errors_key)

    # Restrict (and order) the loaded errors to the condition keys of the series.
    key_group = {}
    for key in keys:
        key_group[key] = group[key]

    figure_key = make_key(plot_key, f"{series.name}.shape_errors.png")
    figure = make_bar_figure(keys, key_group)
    save_figure(context.working_location, figure_key, figure)
@flow(name="plot-cell-shapes_plot-shape-modes")
def run_flow_plot_shape_modes(
    context: ContextConfig, series: SeriesConfig, parameters: ParametersConfigShapeModes
) -> None:
    """
    Plot cell shapes subflow for shape modes.

    For every condition key, shape mode, and projection, loads the grouped
    shape modes of each region, keeps the projections whose map point equals
    ``parameters.point``, and saves them together as a single SVG. Outlines are
    colored by region using ``parameters.colors``.

    The output file name contains the last entry of ``parameters.regions``,
    even though outlines from all regions are drawn into the image. If no
    regions are given there is nothing to draw and the subflow returns without
    writing any files.
    """

    group_key = make_key(series.name, "groups", "groups.CELL_SHAPES")
    plot_key = make_key(series.name, "plots", "plots.CELL_SHAPES")
    keys = [condition["key"] for condition in series.conditions]

    # Without any region, no outline can be drawn and no file name can be
    # built (previously this surfaced as an unbound local variable error).
    if not parameters.regions:
        return

    # The point suffix depends only on the parameters, so build it once
    # instead of once per (key, component, projection) combination.
    if parameters.point > 0:
        point_key = "P" + f"{round(parameters.point*100):03d}"
    elif parameters.point < 0:
        point_key = "N" + f"{round(-parameters.point*100):03d}"
    else:
        point_key = "ZERO"

    for key in keys:
        for component in range(parameters.components):
            for projection in parameters.projections:
                # Only the "top" projection is drawn unrotated.
                rotate = 0 if projection == "top" else 90
                elements: list[dict] = []
                full_key = ""

                for region in parameters.regions:
                    full_key = f"{key}.{region}.PC{component+1}.{projection.upper()}"

                    group = load_json(
                        context.working_location,
                        make_key(group_key, f"{series.name}.shape_modes.{full_key}.json"),
                    )

                    assert isinstance(group, list)

                    elements.extend(
                        {
                            "points": item["projection"][0],
                            "stroke": parameters.colors[region],
                            "stroke-width": 2,
                        }
                        for item in group
                        if item["point"] == parameters.point
                    )

                # ``full_key`` is now the key of the last region (see docstring).
                save_text(
                    context.working_location,
                    make_key(plot_key, f"{series.name}.shape_modes.{full_key}.{point_key}.svg"),
                    build_svg_image(elements, *parameters.box, rotate, parameters.scale),
                )
@flow(name="plot-cell-shapes_plot-variance-explained")
def run_flow_plot_variance_explained(
    context: ContextConfig, series: SeriesConfig, parameters: ParametersConfigVarianceExplained
) -> None:
    """
    Plot cell shapes subflow for variance explained.

    Loads the grouped variance explained table and saves a line figure with
    one line per condition key showing the cumulative variance over the shape
    modes (rows sorted by the ``mode`` column). At most
    ``parameters.components`` modes are plotted per key. Line colors are taken
    from ``parameters.colors`` by key position, wrapping around when there are
    more keys than colors.
    """

    group_key = make_key(series.name, "groups", "groups.CELL_SHAPES")
    plot_key = make_key(series.name, "plots", "plots.CELL_SHAPES")
    keys = [condition["key"] for condition in series.conditions]

    group = load_dataframe(
        context.working_location,
        make_key(group_key, f"{series.name}.variance_explained.csv"),
    )

    group_flat = []
    for index, key in enumerate(keys):
        variance = group[group["key"] == key].sort_values("mode")["variance"].values

        # Limit to the configured number of components (previously the x axis
        # was hard-coded to 8 points regardless of ``parameters.components``).
        cumulative = np.cumsum(variance)[: parameters.components]

        group_flat.append(
            {
                # Derive x from y so both always have the same length.
                "x": list(range(1, len(cumulative) + 1)),
                "y": cumulative,
                "color": parameters.colors[index % len(parameters.colors)],
            }
        )

    save_figure(
        context.working_location,
        make_key(plot_key, f"{series.name}.variance_explained.png"),
        make_line_figure(group_flat),
    )