mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
Convert FOB submodule to regular folder
This commit is contained in:
parent
94f046ad40
commit
94825011a0
74 changed files with 4563 additions and 0 deletions
100
environments/optimizer/FOB/fix_folder_names.py
Normal file
100
environments/optimizer/FOB/fix_folder_names.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
"""
|
||||
This tool fixes folder names which are incorrect due to changes in the default.yaml
|
||||
"""
|
||||
import sys
|
||||
import argparse
|
||||
from pprint import pprint
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
from pytorch_fob.engine.engine import Engine
|
||||
|
||||
|
||||
def deep_diff(dict1, dict2):
|
||||
diff = {}
|
||||
|
||||
# Check keys in dict1 but not in dict2
|
||||
for key in dict1:
|
||||
if key not in dict2:
|
||||
if dict1[key] is not None:
|
||||
diff[key] = {'old_value': dict1[key], 'new_value': None}
|
||||
elif dict1[key] != dict2[key]:
|
||||
if isinstance(dict1[key], dict) and isinstance(dict2[key], dict):
|
||||
nested_diff = deep_diff(dict1[key], dict2[key])
|
||||
if nested_diff:
|
||||
diff[key] = nested_diff
|
||||
else:
|
||||
diff[key] = {'old_value': dict1[key], 'new_value': dict2[key]}
|
||||
|
||||
# Check keys in dict2 but not in dict1
|
||||
for key in dict2:
|
||||
if key not in dict1:
|
||||
if dict2[key] is not None:
|
||||
diff[key] = {'old_value': None, 'new_value': dict2[key]}
|
||||
|
||||
return diff
|
||||
|
||||
|
||||
def fix_recursive(path: Path, dry_run: bool, ignore_config_diff: bool):
|
||||
for file in path.iterdir():
|
||||
if file.name == "config.yaml":
|
||||
e = Engine()
|
||||
e.parse_experiment_from_file(file)
|
||||
runs = list(e.runs())
|
||||
if len(runs) != 1:
|
||||
print("Error config.yaml is invalid:", file, file=sys.stderr)
|
||||
sys.exit(1)
|
||||
r = runs[0]
|
||||
target = r.run_dir.name
|
||||
actual = file.parent.name
|
||||
if actual == target:
|
||||
continue
|
||||
target = file.parent.parent / target
|
||||
actual = file.parent
|
||||
print("folder name is wrong and needs fixing:\ncurrent:", actual, "\ncalculated:", target)
|
||||
computed_config = r.export_config_dict()
|
||||
with open(file, "r", encoding="utf8") as f:
|
||||
actual_config = yaml.safe_load(f)
|
||||
clean_computed_config = {"engine": {"devices": computed_config["engine"]["devices"]},
|
||||
"task": computed_config["task"],
|
||||
"optimizer": computed_config["optimizer"]}
|
||||
clean_actual_config = {"engine": {"devices": actual_config["engine"]["devices"]},
|
||||
"task": actual_config["task"],
|
||||
"optimizer": actual_config["optimizer"]}
|
||||
diff = deep_diff(clean_actual_config, clean_computed_config)
|
||||
if diff and (not ignore_config_diff):
|
||||
print("warning config dict differs!:")
|
||||
pprint(diff)
|
||||
print("skipping folder!")
|
||||
continue
|
||||
if not dry_run:
|
||||
print("renaming...")
|
||||
if target.exists():
|
||||
print("target path already exists, skipping...")
|
||||
continue
|
||||
actual.rename(target)
|
||||
|
||||
elif file.is_dir():
|
||||
fix_recursive(file, dry_run, ignore_config_diff)
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
base_folder: Path = args.base_folder
|
||||
ignore_config_diff = args.ignore_config_diff
|
||||
if ignore_config_diff:
|
||||
res = input("WARNING: ignoring the config dict diffs can be dangerous do you know what you are doing? [y/n]")
|
||||
if res.strip().lower() != "y":
|
||||
return
|
||||
fix_recursive(base_folder, args.dry_run, ignore_config_diff)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="folder name fix tool"
|
||||
)
|
||||
parser.add_argument("base_folder", type=Path,
|
||||
help="Folder with experiments, will check recursively")
|
||||
parser.add_argument("--dry_run", action="store_true",
|
||||
help="Just print what would be changed, do not change any files")
|
||||
parser.add_argument("--ignore_config_diff", action="store_true",
|
||||
help="Ignores config dict difference WARNING: experiment could be totally different!\
|
||||
Use with care!")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
Loading…
Add table
Add a link
Reference in a new issue