mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-25 17:10:42 +00:00
Convert FOB submodule to regular folder
This commit is contained in:
parent
94f046ad40
commit
94825011a0
74 changed files with 4563 additions and 0 deletions
65
environments/optimizer/FOB/pytorch_fob/engine/parser.py
Normal file
65
environments/optimizer/FOB/pytorch_fob/engine/parser.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
from pathlib import Path
|
||||
from typing import Any, Iterable, Optional
|
||||
import re
|
||||
import yaml
|
||||
|
||||
|
||||
class YAMLParser():
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def parse_yaml(self, file: Path) -> Any:
|
||||
"""
|
||||
Opens and parses a YAML file.
|
||||
"""
|
||||
with open(file, "r", encoding="utf8") as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
def parse_yamls_and_extra_args(self,
|
||||
default_yaml: Path,
|
||||
custom_yaml: Optional[Path],
|
||||
additional_args: Iterable[str] = tuple()
|
||||
) -> dict:
|
||||
"""assumes that there is a dict in the yaml"""
|
||||
config_to_use = self.parse_yaml(default_yaml)
|
||||
if custom_yaml is not None:
|
||||
user_yaml = self.parse_yaml(custom_yaml)
|
||||
# merge in place
|
||||
self.merge_dicts_hierarchical(lo=config_to_use, hi=user_yaml)
|
||||
self.parse_args_into_searchspace(config_to_use, additional_args)
|
||||
return config_to_use
|
||||
|
||||
def parse_args_into_searchspace(self, searchspace: dict[str, Any], args: Iterable[str]):
|
||||
"""
|
||||
Overwrites args given in the form of 'this.that=something'. Also supports lists: 'this.that[0]=something'
|
||||
"""
|
||||
for arg in args:
|
||||
self._parse_arg_into_searchspace(searchspace, arg)
|
||||
|
||||
def _parse_arg_into_searchspace(self, searchspace: dict[str, Any], arg: str):
|
||||
keys, value = arg.split("=")
|
||||
keys = keys.split(".")
|
||||
keys_with_list_indices = []
|
||||
for key in keys:
|
||||
match = re.search(r"^(.*?)\[(\-?\d+)\]$", key)
|
||||
if match:
|
||||
keys_with_list_indices.append(match.group(1))
|
||||
keys_with_list_indices.append(int(match.group(2)))
|
||||
else:
|
||||
keys_with_list_indices.append(key)
|
||||
target = searchspace
|
||||
for key in keys_with_list_indices[:-1]:
|
||||
if isinstance(target, dict) and key not in target:
|
||||
target[key] = {}
|
||||
target = target[key]
|
||||
target[keys_with_list_indices[-1]] = yaml.safe_load(value)
|
||||
|
||||
def merge_dicts_hierarchical(self, lo: dict, hi: dict):
|
||||
"""
|
||||
Overwrites values in `lo` with values from `hi` if they are present in both/
|
||||
"""
|
||||
for k, v in hi.items():
|
||||
if isinstance(v, dict) and isinstance(lo.get(k, None), dict):
|
||||
self.merge_dicts_hierarchical(lo[k], v)
|
||||
else:
|
||||
lo[k] = v
|
||||
Loading…
Add table
Add a link
Reference in a new issue