atropos/environments/optimizer/FOB/pytorch_fob/engine/parser.py
2025-05-18 16:36:28 -07:00

65 lines
2.4 KiB
Python

from pathlib import Path
from typing import Any, Iterable, Optional
import re
import yaml
class YAMLParser:
    """Load YAML configuration files and apply dotted ``key=value`` overrides.

    The parser keeps no state of its own; every method works purely on the
    arguments it is given.
    """

    # Splits one dotted-path segment such as ``name[0][-1]`` into the bare
    # name (group 1) and its trailing run of integer subscripts (group 2).
    _INDEXED_KEY = re.compile(r"^(.*?)((?:\[-?\d+\])+)$")
    # Extracts each integer out of a run of ``[n]`` subscripts.
    _SUBSCRIPT = re.compile(r"\[(-?\d+)\]")

    def parse_yaml(self, file: Path) -> Any:
        """
        Opens and parses a YAML file.

        The file is read as UTF-8 text and handed to ``yaml.safe_load``; the
        loaded object is returned unchanged.
        """
        with open(file, "r", encoding="utf8") as f:
            return yaml.safe_load(f)

    def parse_yamls_and_extra_args(self,
                                   default_yaml: Path,
                                   custom_yaml: Optional[Path],
                                   additional_args: Iterable[str] = tuple()
                                   ) -> dict:
        """
        Builds one config from a default YAML, an optional custom YAML and overrides.

        Precedence, lowest to highest: `default_yaml`, `custom_yaml`, then each
        entry of `additional_args` (strings of the form 'this.that=something').
        Assumes that the top level of each YAML file is a dict; a file that
        loads as `None` is treated as an empty dict.

        Returns the merged config (the object loaded from `default_yaml`,
        modified in place).
        """
        config_to_use = self.parse_yaml(default_yaml)
        if config_to_use is None:
            # Nothing was loaded: start from an empty config so that the merge
            # and the overrides below still have a dict to write into.
            config_to_use = {}
        if custom_yaml is not None:
            user_yaml = self.parse_yaml(custom_yaml)
            if user_yaml is not None:
                # merge in place
                self.merge_dicts_hierarchical(lo=config_to_use, hi=user_yaml)
        self.parse_args_into_searchspace(config_to_use, additional_args)
        return config_to_use

    def parse_args_into_searchspace(self, searchspace: dict[str, Any], args: Iterable[str]):
        """
        Overwrites args given in the form of 'this.that=something'. Also supports lists: 'this.that[0]=something'

        Overrides are applied to `searchspace` in place, in iteration order, so
        a later entry wins over an earlier one that targets the same key.

        Raises ValueError if an entry contains no '='.
        """
        for arg in args:
            self._parse_arg_into_searchspace(searchspace, arg)

    @classmethod
    def _split_key_path(cls, dotted: str) -> list:
        """
        Turns 'this.that[0][1]' into ['this', 'that', 0, 1].

        Segments are separated by '.'; integer subscripts at the end of a
        segment become separate `int` path elements following the segment name.
        """
        path: list = []
        for segment in dotted.split("."):
            indexed = cls._INDEXED_KEY.match(segment)
            if indexed is None:
                path.append(segment)
                continue
            path.append(indexed.group(1))
            path.extend(int(number) for number in cls._SUBSCRIPT.findall(indexed.group(2)))
        return path

    def _parse_arg_into_searchspace(self, searchspace: dict[str, Any], arg: str):
        """
        Applies a single 'this.that=something' override to `searchspace` in place.

        The text after the first '=' is parsed with `yaml.safe_load`. Missing
        intermediate keys of a dict are created as empty dicts on the way down.

        Raises ValueError if `arg` contains no '='.
        """
        # Split on the first '=' only, so the value itself may contain '='.
        keys, separator, value = arg.partition("=")
        if not separator:
            raise ValueError(f"expected an override of the form 'key=value', got {arg!r}")
        *parents, leaf = self._split_key_path(keys)
        target = searchspace
        for key in parents:
            if isinstance(target, dict) and key not in target:
                target[key] = {}
            target = target[key]
        target[leaf] = yaml.safe_load(value)

    def merge_dicts_hierarchical(self, lo: dict, hi: dict):
        """
        Overwrites values in `lo` with values from `hi` if they are present in both.

        Keys that exist only in `hi` are added to `lo`. When both sides hold a
        dict under the same key, the two dicts are merged recursively; in every
        other case the value from `hi` replaces the one in `lo`. `lo` is
        modified in place and nothing is returned.
        """
        for k, v in hi.items():
            if isinstance(v, dict) and isinstance(lo.get(k, None), dict):
                self.merge_dicts_hierarchical(lo[k], v)
            else:
                lo[k] = v