Coverage for src/cell_abm_pipeline/__config__.py: 0%

86 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2024-06-05 19:14 +0000

1import os 

2import re 

3from dataclasses import dataclass, field, fields, make_dataclass 

4from types import ModuleType 

5from typing import Any, Union 

6 

7from hydra import compose, initialize_config_dir 

8from hydra.core.config_store import ConfigStore 

9from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf 

10 

# Hydra defaults list for the composed "Config" schema: merge this config's own
# fields first ("_self_"), then one entry per config group whose option is
# MISSING, i.e. a choice for that group must be supplied through overrides.
defaults = ["_self_"] + [{group: MISSING} for group in ("context", "series", "parameters")]

17 

18 

@dataclass
class ContextConfig:
    """Base schema for the ``context`` config group; ``name`` is required."""

    name: str

22 

23 

@dataclass
class SeriesConfig:
    """Base schema for the ``series`` config group; ``name`` is required."""

    name: str

27 

28 

@dataclass
class ParametersConfig:
    """Base schema for the ``parameters`` config group; ``name`` is required."""

    name: str

32 

33 

def make_dotlist_from_config(config: dict) -> list[str]:
    """Flatten a structured config into a list of ``dotted.key=value`` strings.

    The config is converted with ``OmegaConf.structured`` and then to a plain
    container. Nested mappings are walked depth-first and their keys joined
    with ``.``. Lists are rendered inline as ``key=[a,b,...]``, and ``None``
    (at any level, including inside lists) is rendered as ``null``.

    Entries are emitted in the same order as before this change: the work
    stack is consumed from its end.

    Raises:
        TypeError: if the config does not convert to a mapping.
    """
    container = OmegaConf.to_container(OmegaConf.structured(config))

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if not isinstance(container, dict):
        raise TypeError(f"expected a mapping config, got {type(container).__name__}")

    def render(item: Any) -> str:
        # YAML spells the null value ``null``; ``str(None)`` would yield the
        # string "None", which re-parses as a string rather than null.
        return "null" if item is None else str(item)

    stack = list(container.items())
    dotlist = []

    while stack:
        key, value = stack.pop()

        if isinstance(value, dict):
            stack.extend((f"{key}.{subkey}", subvalue) for subkey, subvalue in value.items())
        elif isinstance(value, list):
            dotlist.append(f"{key}=[{','.join(render(item) for item in value)}]")
        else:
            dotlist.append(f"{key}={render(value)}")

    return dotlist

55 

56 

def make_config_from_dotlist(module: ModuleType, args: list[str]) -> DictConfig:
    """Assemble the full config from ``group.key=value`` style arguments.

    Each of the ``context``, ``series`` and ``parameters`` groups is built with
    ``generate_config`` from the matching schema class defined on ``module``.
    The three groups are then combined into one ``DictConfig``.
    """
    schemas = {
        "context": module.ContextConfig,
        "series": module.SeriesConfig,
        "parameters": module.ParametersConfig,
    }

    groups = {group: generate_config(schema, group, args) for group, schema in schemas.items()}

    return OmegaConf.create(groups)

71 

72 

def make_config_from_yaml(module: ModuleType, args: list[str]) -> DictConfig:
    """Compose a config with Hydra from the ``configs`` directory under the cwd.

    A ``Config`` dataclass is built from the module's ``ContextConfig``,
    ``SeriesConfig`` and ``ParametersConfig`` schemas plus the module-level
    ``defaults`` list. It is registered in the Hydra ``ConfigStore`` as
    ``"config"`` and composed with ``args`` as overrides.

    Hydra is initialized only for the duration of the ``compose`` call, so the
    function can be called more than once per process.
    """
    import copy

    config_store = ConfigStore.instance()

    config_dataclass = make_dataclass(
        "Config",
        [
            # Deep-copy so each instance gets its own defaults list instead of
            # sharing (and possibly mutating) the module-level one.
            ("defaults", list[Any], field(default_factory=lambda: copy.deepcopy(defaults))),
            ("context", module.ContextConfig, MISSING),
            ("series", module.SeriesConfig, MISSING),
            ("parameters", module.ParametersConfig, MISSING),
        ],
    )

    config_store.store(name="config", node=config_dataclass)

    config_dir = os.path.join(os.path.abspath(os.getcwd()), "configs")

    # Used as a context manager, initialize_config_dir restores the previous
    # global Hydra state on exit. A bare call leaves Hydra initialized, and a
    # second call would then fail.
    with initialize_config_dir(config_dir, version_base=None):
        config = compose(config_name="config", overrides=args)

    return config

94 

95 

def make_config_from_file(module: ModuleType, file: str) -> DictConfig:
    """Load a saved config file and validate each group against its schema.

    The file must contain ``context``, ``series`` and ``parameters`` sections.
    Each section is filtered and merged onto the matching schema class from
    ``module`` via ``load_config``.
    """
    contents = OmegaConf.load(file)

    group_schemas = (
        ("context", module.ContextConfig),
        ("series", module.SeriesConfig),
        ("parameters", module.ParametersConfig),
    )

    merged = {}
    for group, schema in group_schemas:
        merged[group] = load_config(schema, getattr(contents, group))

    return OmegaConf.create(merged)

112 

113 

def generate_config(config_class: Any, group: str, args: list[str]) -> DictConfig:
    """Build a structured config for one group from dotted command-line args.

    Only arguments of the form ``<group>.<key>=<value>`` are used. The
    ``<group>.`` prefix is removed and the remainder is merged as a dotlist
    onto ``OmegaConf.structured(config_class)``.

    Arguments that merely start with the same characters (e.g. ``contextual.x``
    for group ``context``) are ignored. Values keep any leading or trailing
    dots, because only the prefix is sliced off.
    """
    prefix = f"{group}."
    dotlist = [arg[len(prefix):] for arg in args if arg.startswith(prefix)]

    config = OmegaConf.structured(config_class)
    config.merge_with_dotlist(dotlist)
    return config

119 

120 

def load_config(
    schema: Any, config: Union[ListConfig, DictConfig]
) -> Union[ListConfig, DictConfig]:
    """Merge ``config`` onto ``schema`` after dropping keys the schema lacks.

    ``schema`` is a dataclass; only keys matching one of its field names are
    kept. Note that the extra keys are deleted from ``config`` in place before
    ``OmegaConf.merge(schema, config)`` is returned.
    """
    present_keys = list(config.keys())
    declared = {schema_field.name for schema_field in fields(schema)}

    for key in (k for k in present_keys if k not in declared):
        del config[key]

    return OmegaConf.merge(schema, config)

132 

133 

# One scalar item of a YAML block sequence indented 2-6 spaces, e.g. "  - 0.5"
# or "    - 'abc'". The class spells out A-Za-z_ because the range A-z would
# also admit "[", "\", "]", "^" and the backtick.
_LIST_ITEM_PATTERN = re.compile(r"^[\s]{2,6}\- ([\dA-Za-z_\.\*\'\-]+)$")


def display_config(config: DictConfig) -> None:
    """Print ``config`` as resolved YAML with scalar lists collapsed inline.

    Consecutive block-sequence lines such as::

        key:
          - 1
          - 2

    are printed as ``key: [1, 2]``. All other lines are printed unchanged.
    """
    config_lines: list[str] = []
    list_entries: list[str] = []

    for line in OmegaConf.to_yaml(config, resolve=True).split("\n"):
        match = _LIST_ITEM_PATTERN.findall(line)

        if match:
            list_entries.append(match[0])
            continue

        # A non-item line closes any run of collected list items.
        if list_entries:
            config_lines.append("[" + ", ".join(list_entries) + "]")
            list_entries = []
        config_lines.append(line)

    # to_yaml output ends with a newline, so the final "" line normally flushes
    # the run above; flush here as well so a trailing run is never dropped.
    if list_entries:
        config_lines.append("[" + ", ".join(list_entries) + "]")

    # Join "key:" with the collapsed list that follows it on the next line.
    config_display = "\n".join(config_lines).replace(":\n[", ": [")

    print(config_display)