Coverage for src/io_collection/save/save_figure.py: 100%
24 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-09-25 19:09 +0000
1from __future__ import annotations
3import io
4from pathlib import Path
5from typing import TYPE_CHECKING
7from io_collection.save.save_buffer import _save_buffer_to_s3
9if TYPE_CHECKING:
10 import matplotlib.figure as mpl
# MIME type sent as the S3 content type for each supported figure extension.
# Keys are the extensions without their leading dot.
CONTENT_TYPES = {
    "png": "image/png",
    "jpeg": "image/jpeg",
    "jpg": "image/jpeg",
    "svg": "image/svg+xml",
}

# Accepted key suffixes, derived from CONTENT_TYPES so the two tables cannot
# drift apart. Dict insertion order is preserved, so this tuple's order (which
# is also the order shown in validation error messages) follows the dict above.
EXTENSIONS = tuple(f".{extension}" for extension in CONTENT_TYPES)
def save_figure(
    location: str, key: str, figure: mpl.Figure, **kwargs: bool | str | float | list
) -> None:
    """
    Save matplotlib figure to key at specified location.

    Method will save to the S3 bucket if the location begins with the **s3://**
    protocol, otherwise it assumes the location is a local path.

    Parameters
    ----------
    location
        Object location (local path or S3 bucket).
    key
        Object key ending in `.png`, `.jpeg`, `.jpg`, or `.svg` (the extension
        match is case-sensitive).
    figure
        Figure instance to save.
    **kwargs
        Additional parameters for saving figure. The keyword arguments are
        passed to `matplotlib.pyplot.savefig`.

    Raises
    ------
    ValueError
        If the key does not end in one of the supported extensions.
    """

    # Validate up front so neither backend is reached with an unsupported key.
    if not key.endswith(EXTENSIONS):
        extensions = " | ".join([ext[1:] for ext in EXTENSIONS])
        message = f"key [ {key} ] must have [ {extensions} ] extension"
        raise ValueError(message)

    if location.startswith("s3://"):
        # Strip the protocol prefix; what remains is the bucket name.
        _save_figure_to_s3(location[5:], key, figure, **kwargs)
    else:
        _save_figure_to_fs(location, key, figure, **kwargs)
def _save_figure_to_fs(
    path: str, key: str, figure: mpl.Figure, **kwargs: bool | str | float | list
) -> None:
    """
    Write a matplotlib figure to ``<path>/<key>`` on the local file system.

    Any missing parent directories of the destination file are created first.

    Parameters
    ----------
    path
        Local directory that the key is resolved against.
    key
        Object key ending in `.png`, `.jpeg`, `.jpg`, or `.svg`.
    figure
        Figure instance to save.
    **kwargs
        Extra keyword arguments forwarded unchanged to `figure.savefig`.
    """

    destination = Path(path).joinpath(key)

    # Ensure the directory tree exists; an already-present tree is fine.
    destination.parent.mkdir(exist_ok=True, parents=True)

    figure.savefig(destination, **kwargs)  # type: ignore[arg-type]
def _save_figure_to_s3(
    bucket: str, key: str, figure: mpl.Figure, **kwargs: bool | str | float | list
) -> None:
    """
    Save matplotlib figure to key in AWS S3 bucket.

    The figure is rendered into an in-memory buffer. The buffer is uploaded
    with a content type looked up from the key's extension.

    Because an in-memory buffer has no file name, matplotlib cannot infer the
    output format from it. The format is therefore taken from the key's
    extension unless the caller passes an explicit ``format`` keyword argument.

    Parameters
    ----------
    bucket
        AWS S3 bucket name.
    key
        Object key ending in `.png`, `.jpeg`, `.jpg`, or `.svg`.
    figure
        Figure instance to save.
    **kwargs
        Additional parameters for saving figure. The keyword arguments are
        passed to `matplotlib.pyplot.savefig`. An explicit ``format`` here
        overrides the extension-derived format, while the content type still
        follows the key extension.
    """

    extension = key.split(".")[-1]
    content_type = CONTENT_TYPES[extension]

    # Without an explicit format, savefig on a file-like object falls back to
    # rcParams["savefig.format"] (PNG by default). That would upload PNG bytes
    # under a `.svg`/`.jpg` key with a mismatched content type.
    kwargs.setdefault("format", extension)

    with io.BytesIO() as buffer:
        figure.savefig(buffer, **kwargs)  # type: ignore[arg-type]
        _save_buffer_to_s3(bucket, key, buffer, content_type)