Coverage for src/cell_abm_pipeline/tasks/make_centroids_figure.py: 0%

40 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2024-06-05 19:14 +0000

1import matplotlib.colors as mcolors 

2import matplotlib.figure as mpl 

3import matplotlib.pyplot as plt 

4import numpy as np 

5import pandas as pd 

6from prefect import task 

7 

8 

def _marker_styles(
    tick_values: np.ndarray,
    accented: np.ndarray,
    lower: float,
    upper: float,
) -> tuple[np.ndarray, np.ndarray]:
    """Return per-point marker sizes (pre-squaring) and signed color values.

    ``progress`` runs from 0 at ``lower`` to 1 at ``upper``; sizes are
    ``1 + progress`` and colors are ``progress``, negated where ``accented``
    is True so those points land in the lower half of a [-1, 1] colormap.
    """
    span = upper - lower
    if span == 0:
        # Only one tick in the window (first frame, or window == 0): draw every
        # point at full size/color. Checking the span rather than the frame
        # index avoids a 0/0 division that would yield NaN sizes and colors.
        progress = np.ones(len(tick_values))
    else:
        progress = (tick_values - lower) / span
    sizes = 1 + progress
    colors = np.where(accented, -progress, progress)
    return sizes, colors


def _centroid_colormap() -> mcolors.LinearSegmentedColormap:
    """Build a colormap whose lower half is Reds_r and upper half is bone_r."""
    accent = plt.cm.Reds_r(np.linspace(0.2, 1, 128))
    base = plt.cm.bone_r(np.linspace(0.2, 1, 128))
    return mcolors.LinearSegmentedColormap.from_list(
        "custom", np.vstack((accent, base))
    )


def _format_timestamp(hours: float) -> str:
    """Format a duration in hours as e.g. ``"05H:30M"``, rounded to the nearest minute."""
    # Round once, directly to whole minutes, so minutes always stay in 0..59.
    total_minutes = int(round(float(hours) * 60))
    whole_hours, minutes = divmod(total_minutes, 60)
    return f"{whole_hours:02d}H:{minutes:02d}M"


@task
def make_centroids_figure(
    data: pd.DataFrame,
    frame: int,
    xlim: tuple[float, float],
    ylim: tuple[float, float],
    dt: float,
    window: int,
) -> mpl.Figure:
    """Plot centroids at ``frame`` together with a trail of earlier ticks.

    Rows from up to ``window`` distinct ticks before ``frame`` (plus ``frame``
    itself) are scattered; points from older ticks are drawn smaller and with
    colors nearer the middle of the colormap. Rows whose ``PHASE`` equals
    ``"PROLIFERATIVE_M"`` use the Reds_r (accent) half of the colormap, all
    others the bone_r half. A timestamp for ``frame`` is drawn bottom-left.

    Parameters
    ----------
    data
        Table with ``TICK``, ``CENTER_X``, ``CENTER_Y`` and ``PHASE`` columns.
    frame
        Tick to plot; must be one of the values in ``data["TICK"]``.
    xlim, ylim
        Axis limits as ``(min, max)``; the y axis is drawn inverted.
    dt
        Time per tick; the label interprets ``frame * dt`` as hours.
    window
        Number of distinct earlier ticks to include in the trail.

    Returns
    -------
    matplotlib.figure.Figure
        A 4x4 inch figure containing a single square axes.

    Raises
    ------
    ValueError
        If ``frame`` does not appear in ``data["TICK"]``.
    """
    fig = plt.figure(figsize=(4, 4), constrained_layout=True)
    ax = fig.add_subplot()

    ax.set_box_aspect(1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlim(xlim)
    # Passing (ylim[1], ylim[0]) puts ylim[1] at the bottom, inverting the y axis.
    ax.set_ylim((ylim[1], ylim[0]))

    all_ticks = sorted(data["TICK"].unique())
    try:
        index = all_ticks.index(frame)
    except ValueError:
        raise ValueError(f"frame {frame} not found in data['TICK']") from None

    lower = all_ticks[max(0, index - window)]
    upper = all_ticks[index]
    subset = data[data["TICK"].between(lower, upper)]

    sizes, colors = _marker_styles(
        subset["TICK"].to_numpy(),
        (subset["PHASE"] == "PROLIFERATIVE_M").to_numpy(),
        lower,
        upper,
    )

    ax.scatter(
        subset["CENTER_X"].to_numpy(),
        subset["CENTER_Y"].to_numpy(),
        s=sizes**2,  # scatter's ``s`` is a marker area in points**2
        c=colors,
        cmap=_centroid_colormap(),
        vmin=-1,
        vmax=1,
    )

    ax.text(
        0.02,
        0.02,
        _format_timestamp(upper * dt),
        fontfamily="monospace",
        fontsize=20,
        fontweight="bold",
        transform=ax.transAxes,
    )

    return fig