Coverage for src/cell_abm_pipeline/__config__.py: 0%
86 statements
« prev ^ index » next coverage.py v7.1.0, created at 2024-06-05 19:14 +0000
1import os
2import re
3from dataclasses import dataclass, field, fields, make_dataclass
4from types import ModuleType
5from typing import Any, Union
7from hydra import compose, initialize_config_dir
8from hydra.core.config_store import ConfigStore
9from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf
# Hydra defaults list for the composed "Config" node: "_self_" comes first,
# followed by one required (MISSING) selection for each config group.
defaults = ["_self_"] + [{group: MISSING} for group in ("context", "series", "parameters")]
@dataclass
class ContextConfig:
    """Minimal schema for the ``context`` config group."""

    # Identifier of the selected context.
    name: str
@dataclass
class SeriesConfig:
    """Minimal schema for the ``series`` config group."""

    # Identifier of the selected series.
    name: str
@dataclass
class ParametersConfig:
    """Minimal schema for the ``parameters`` config group."""

    # Identifier of the selected parameter set.
    name: str
def _format_dotlist_value(value: Any) -> str:
    """Render one leaf value in override syntax (``null`` for None, ``[a,b]`` for lists)."""
    if value is None:
        return "null"
    if isinstance(value, list):
        # Recurse so None elements become "null" and nested lists keep list syntax.
        return "[" + ",".join(_format_dotlist_value(item) for item in value) + "]"
    return str(value)


def make_dotlist_from_config(config: dict) -> list[str]:
    """Flatten a nested config into a list of ``dotted.key=value`` strings.

    The config is first normalized with ``OmegaConf.structured`` and
    ``OmegaConf.to_container``. Nested mappings are expanded into dotted keys.
    Lists are rendered as ``[a,b,...]``, and ``None`` values (including list
    elements) are rendered as ``null``.

    Entries are produced by popping from the end of a work stack. The output
    order is therefore LIFO-like with respect to the input mapping and should
    not be relied upon.

    Args:
        config: Mapping (or structured object) accepted by ``OmegaConf.structured``.

    Returns:
        One ``key=value`` string per leaf value.
    """
    container = OmegaConf.to_container(OmegaConf.structured(config))

    assert isinstance(container, dict)

    queue = list(container.items())
    dotlist = []

    while queue:
        key, value = queue.pop()

        if isinstance(value, dict):
            # Push nested entries back onto the stack with dotted keys; extend
            # in place instead of rebuilding the whole list each time.
            queue.extend((f"{key}.{subkey}", subvalue) for subkey, subvalue in value.items())
        else:
            dotlist.append(f"{key}={_format_dotlist_value(value)}")

    return dotlist
def make_config_from_dotlist(module: ModuleType, args: list[str]) -> DictConfig:
    """Assemble the full config from ``group.key=value`` style arguments.

    Each group ("context", "series", "parameters") is built by
    ``generate_config`` from the matching schema class on ``module`` and the
    arguments addressed to that group.

    Args:
        module: Module providing ``ContextConfig``, ``SeriesConfig`` and
            ``ParametersConfig`` schema classes.
        args: Dotlist override strings.

    Returns:
        A ``DictConfig`` with one entry per group.
    """
    group_schemas = {
        "context": module.ContextConfig,
        "series": module.SeriesConfig,
        "parameters": module.ParametersConfig,
    }

    # Groups are generated in the same order as they are listed above.
    return OmegaConf.create(
        {group: generate_config(schema, group, args) for group, schema in group_schemas.items()}
    )
def make_config_from_yaml(module: ModuleType, args: list[str]) -> DictConfig:
    """Compose the full config with Hydra from YAML files in ``./configs``.

    A ``Config`` dataclass is created with the module-level ``defaults`` list
    and the ``context``, ``series`` and ``parameters`` schema classes from
    ``module``. It is registered in Hydra's ``ConfigStore`` under the name
    ``"config"`` and composed with ``args`` as overrides.

    Args:
        module: Module providing ``ContextConfig``, ``SeriesConfig`` and
            ``ParametersConfig`` schema classes.
        args: Hydra override strings passed to ``compose``.

    Returns:
        The composed ``DictConfig``.
    """
    import copy

    config_store = ConfigStore.instance()

    config_dataclass = make_dataclass(
        "Config",
        [
            # Deep-copy so each Config gets its own defaults list instead of a
            # reference to the shared module-level list.
            ("defaults", list[Any], field(default_factory=lambda: copy.deepcopy(defaults))),
            ("context", module.ContextConfig, MISSING),
            ("series", module.SeriesConfig, MISSING),
            ("parameters", module.ParametersConfig, MISSING),
        ],
    )

    config_store.store(name="config", node=config_dataclass)

    config_dir = os.path.join(os.path.abspath(os.getcwd()), "configs")

    # Use the context-manager form so global Hydra state is restored on exit.
    # A bare initialize_config_dir() call leaves GlobalHydra initialized, and
    # calling this function a second time in the same process would raise.
    with initialize_config_dir(config_dir=config_dir, version_base=None):
        config = compose(config_name="config", overrides=args)

    return config
def make_config_from_file(module: ModuleType, file: str) -> DictConfig:
    """Load a YAML file and validate each group section against its schema.

    The file must contain top-level ``context``, ``series`` and ``parameters``
    sections. Each section is passed through ``load_config`` together with the
    matching schema class from ``module``.

    Args:
        module: Module providing ``ContextConfig``, ``SeriesConfig`` and
            ``ParametersConfig`` schema classes.
        file: Path of the YAML file to load.

    Returns:
        A ``DictConfig`` with one entry per group.
    """
    loaded = OmegaConf.load(file)

    group_schemas = {
        "context": module.ContextConfig,
        "series": module.SeriesConfig,
        "parameters": module.ParametersConfig,
    }

    merged = {}
    for group, schema in group_schemas.items():
        # getattr mirrors attribute-style access (loaded.context, ...).
        merged[group] = load_config(schema, getattr(loaded, group))

    return OmegaConf.create(merged)
def generate_config(config_class: Any, group: str, args: list[str]) -> DictConfig:
    """Build a structured config for one group from dotlist arguments.

    Only arguments of the form ``"<group>.<key>=<value>"`` are used. The
    ``"<group>."`` prefix is removed, and the remaining ``key=value`` strings
    are merged into ``OmegaConf.structured(config_class)``. Arguments for
    other groups are ignored.

    Args:
        config_class: Schema dataclass for the group.
        group: Group name, e.g. ``"context"``.
        args: Dotlist override strings for all groups.

    Returns:
        The structured ``DictConfig`` with the group's overrides applied.
    """
    prefix = f"{group}."
    # Match on "group." (not just "group") so e.g. "contextual.x=1" is not
    # routed to "context". Slice the prefix off instead of str.strip(".") so
    # trailing dots in values (e.g. "context.path=.") are preserved.
    dotlist = [arg[len(prefix):] for arg in args if arg.startswith(prefix)]
    config = OmegaConf.structured(config_class)
    config.merge_with_dotlist(dotlist)
    return config
def load_config(
    schema: Any, config: Union[ListConfig, DictConfig]
) -> Union[ListConfig, DictConfig]:
    """Merge ``config`` onto ``schema`` after dropping keys the schema does not define.

    Keys of ``config`` that are not field names of the ``schema`` dataclass are
    discarded, presumably so the structured merge does not reject unknown keys.
    The caller's ``config`` node is not modified; filtering happens on a copy.

    Args:
        schema: Dataclass (type or instance) used as the structured base.
        config: Values for one group. Key filtering requires a ``DictConfig``;
            a ``ListConfig`` is not supported.

    Returns:
        ``OmegaConf.merge(schema, <filtered copy of config>)``.
    """
    schema_fields = [schema_field.name for schema_field in fields(schema)]
    # masked_copy returns a new DictConfig holding only the listed keys that
    # are present in ``config``; the original node is left untouched.
    filtered = OmegaConf.masked_copy(config, schema_fields)
    return OmegaConf.merge(schema, filtered)
def display_config(config: DictConfig) -> None:
    """Print ``config`` as resolved YAML with simple lists collapsed inline.

    Consecutive block-sequence lines indented by 2-6 whitespace characters
    whose item is a single simple token are replaced by one ``[a, b, ...]``
    line. That line is then joined onto the preceding ``key:`` line.

    Args:
        config: Config to print (interpolations are resolved).
    """
    # Simple list item: digits, ASCII letters, underscore, '.', '*', quote, '-'.
    # The range is spelled out as A-Za-z_ because "A-z" would also match
    # the characters [ \ ] ^ ` that sit between 'Z' and 'a'.
    list_item = re.compile(r"^[\s]{2,6}\- ([\dA-Za-z_\.\*\'\-]+)$")

    config_lines = []
    list_entries = []

    for line in OmegaConf.to_yaml(config, resolve=True).split("\n"):
        match = list_item.findall(line)

        if match:
            list_entries.append(match[0])
            continue

        if list_entries:
            # A run of list items just ended: emit it as a single inline list.
            config_lines.append("[" + ", ".join(list_entries) + "]")
            list_entries = []

        config_lines.append(line)

    # Flush a trailing run, in case the text does not end with a non-item line.
    if list_entries:
        config_lines.append("[" + ", ".join(list_entries) + "]")

    config_display = "\n".join(config_lines)
    config_display = config_display.replace(":\n[", ": [")

    print(config_display)