mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
65 lines
2.4 KiB
Python
65 lines
2.4 KiB
Python
from pathlib import Path
|
|
from typing import Any, Iterable, Optional
|
|
import re
|
|
import yaml
|
|
|
|
|
|
class YAMLParser:
    """Load YAML configuration files and apply command-line style overrides."""

    # Matches one dotted-path segment that ends in one or more list indices,
    # e.g. "layers[0]" or "grid[1][-1]". "name" is the part before the first
    # trailing index, "indices" is the whole run of "[n]" suffixes.
    _INDEXED_KEY_RE = re.compile(r"^(?P<name>.*?)(?P<indices>(?:\[-?\d+\])+)$")
    # Extracts each integer from a run of "[n]" suffixes.
    _INDEX_RE = re.compile(r"\[(-?\d+)\]")

    def parse_yaml(self, file: Path) -> Any:
        """Open *file* and return its contents parsed with ``yaml.safe_load``."""
        with open(file, "r", encoding="utf8") as f:
            return yaml.safe_load(f)

    def parse_yamls_and_extra_args(
        self,
        default_yaml: Path,
        custom_yaml: Optional[Path],
        additional_args: Iterable[str] = (),
    ) -> dict:
        """Build a config from a default YAML, an optional override YAML and CLI args.

        Both YAML files are expected to contain a mapping at the top level.
        Values from *custom_yaml* are merged hierarchically over
        *default_yaml*; then each ``dotted.path=value`` string in
        *additional_args* is applied on top.

        Args:
            default_yaml: Path of the base configuration file.
            custom_yaml: Optional path of a file whose values override the base.
            additional_args: Override strings such as ``"a.b=1"`` or ``"a.l[0]=x"``.

        Returns:
            The merged configuration dict.
        """
        config_to_use = self.parse_yaml(default_yaml)
        # An empty default file parses to None; start from an empty mapping so
        # overrides can still be applied and the return type stays a dict.
        if config_to_use is None:
            config_to_use = {}
        if custom_yaml is not None:
            user_yaml = self.parse_yaml(custom_yaml)
            # An empty override file simply means "nothing to override".
            if user_yaml is not None:
                # merge in place
                self.merge_dicts_hierarchical(lo=config_to_use, hi=user_yaml)
        self.parse_args_into_searchspace(config_to_use, additional_args)
        return config_to_use

    def parse_args_into_searchspace(self, searchspace: dict[str, Any], args: Iterable[str]) -> None:
        """
        Overwrite entries of *searchspace* in place from args of the form
        'this.that=something'. List indices are supported, including nested
        ones: 'this.that[0]=something', 'this.grid[1][-1]=something'.
        Each value is parsed with ``yaml.safe_load``.
        """
        for arg in args:
            self._parse_arg_into_searchspace(searchspace, arg)

    def _parse_arg_into_searchspace(self, searchspace: dict[str, Any], arg: str) -> None:
        """Apply a single ``dotted.path[idx]=value`` override to *searchspace* in place."""
        keys, raw_value = self._split_arg(arg)
        target = self._walk_to_parent(searchspace, keys)
        target[keys[-1]] = yaml.safe_load(raw_value)

    @classmethod
    def _split_arg(cls, arg: str) -> tuple[list[Any], str]:
        """Split an override string into its key path and raw (unparsed) value.

        Only the first '=' separates path from value, so values may themselves
        contain '='. Path segments are split on '.'; trailing '[n]' suffixes on
        a segment become integer keys following the segment name.

        Raises:
            ValueError: if *arg* contains no '='.
        """
        path, sep, raw_value = arg.partition("=")
        if not sep:
            raise ValueError(f"Override {arg!r} must have the form 'key.path=value'")
        keys: list[Any] = []
        for key in path.split("."):
            match = cls._INDEXED_KEY_RE.match(key)
            if match is None:
                keys.append(key)
                continue
            keys.append(match.group("name"))
            keys.extend(int(i) for i in cls._INDEX_RE.findall(match.group("indices")))
        return keys, raw_value

    @staticmethod
    def _walk_to_parent(searchspace: Any, keys: list[Any]) -> Any:
        """Return the container that should hold the last element of *keys*.

        Missing intermediate keys of dicts are created as empty dicts; list
        indices (and keys of non-dict containers) must already exist.
        """
        target = searchspace
        for key in keys[:-1]:
            if isinstance(target, dict) and key not in target:
                target[key] = {}
            target = target[key]
        return target

    def merge_dicts_hierarchical(self, lo: dict, hi: dict) -> None:
        """
        Recursively merge `hi` into `lo` in place.

        For every key in `hi`: if both `lo` and `hi` hold a dict under that
        key, they are merged recursively; otherwise the value from `hi`
        replaces (or adds) the entry in `lo`. Keys only in `lo` are kept.
        """
        for k, v in hi.items():
            if isinstance(v, dict) and isinstance(lo.get(k), dict):
                self.merge_dicts_hierarchical(lo[k], v)
            else:
                lo[k] = v