Coverage for src/cell_abm_pipeline/tasks/make_bar_figure.py: 0%

21 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2024-06-05 19:14 +0000

1import matplotlib.figure as mpl 

2import matplotlib.pyplot as plt 

3import numpy as np 

4from prefect import task 

5 

6 

@task
def make_bar_figure(keys: list[str], data: dict) -> mpl.Figure:
    """
    Create a grouped bar figure of mean values with standard deviation error bars.

    One group of bars is drawn per label, with one bar per key inside each
    group. The labels are taken from the entry of the first key, so every other
    key is expected to provide the same labels.

    Parameters
    ----------
    keys
        Keys into ``data``; each key contributes one bar to every group.
    data
        Nested mapping indexed as ``data[key][label]["mean"]`` and
        ``data[key][label]["std"]``. Values that are nan are drawn as 0.

    Returns
    -------
    :
        Figure containing the bar plot. If ``keys`` is empty, the figure
        contains a single empty axes.
    """

    fig = plt.figure(figsize=(4, 4), constrained_layout=True)

    ax = fig.add_subplot()
    ax.set_box_aspect(1)

    # Nothing to draw without keys. Returning here also avoids dividing by zero
    # when computing the bar width and indexing into an empty list of keys.
    if len(keys) == 0:
        return fig

    # Bars of one group share 90% of the unit spacing between adjacent labels.
    width = 0.9 / len(keys)

    # Shift from the first bar of a group to the center of the group, which is
    # where the tick for that label is placed.
    offset = (width * (len(keys) - 1)) / 2

    labels = list(data[keys[0]])
    group_positions = np.arange(len(labels))

    for index, key in enumerate(keys):
        means = [data[key][label]["mean"] for label in labels]
        stds = [data[key][label]["std"] for label in labels]

        # Replace nans with 0.
        means = [0 if np.isnan(mean) else mean for mean in means]
        stds = [0 if np.isnan(std) else std for std in stds]

        ax.bar(group_positions + index * width, means, yerr=stds, width=width)

    ax.set_xticks(group_positions + offset, labels, rotation=90)

    return fig