mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-25 17:10:42 +00:00
Convert FOB submodule to regular folder
This commit is contained in:
parent 94f046ad40
commit 94825011a0
74 changed files with 4563 additions and 0 deletions
228 environments/optimizer/FOB/pytorch_fob/engine/utils.py Normal file
@@ -0,0 +1,228 @@
import logging
from pathlib import Path
from typing import Any, Callable, Iterable, Optional, Type
import json
import math
import signal

import torch
from lightning_utilities.core.rank_zero import rank_zero_only, rank_zero_info, rank_zero_debug, log


def set_loglevel(level: str):
    pytorch_logger = logging.getLogger("lightning.pytorch")
    match level:
        case "debug":
            pytorch_logger.setLevel(logging.DEBUG)
        case "info":
            pytorch_logger.setLevel(logging.INFO)
        case "warn":
            pytorch_logger.setLevel(logging.WARNING)
        case "error":
            pytorch_logger.setLevel(logging.ERROR)
        case "silent":
            pytorch_logger.setLevel(logging.CRITICAL)


@rank_zero_only
def rank_zero_print(*args: Any, **kwargs: Any):
    return print(*args, **kwargs)


@rank_zero_only
def log_warn(msg: str, *args: Any, prefix: str = "[FOB WARNING] ", **kwargs: Any):
    return log.warning(f"{prefix}{msg}", *args, **kwargs)


def log_info(msg: str, *args: Any, prefix: str = "[FOB INFO] ", **kwargs: Any):
    return rank_zero_info(f"{prefix}{msg}", *args, **kwargs)


def log_debug(msg: str, *args: Any, prefix: str = "[FOB DEBUG] ", **kwargs: Any):
    return rank_zero_debug(f"{prefix}{msg}", *args, **kwargs)


def write_results(results, filepath: Path):
    with open(filepath, "w", encoding="utf8") as f:
        json.dump(results, f, indent=4)
    print(f"Saved results into {filepath}.")


def wrap_list(x: Any) -> list[Any]:
    if isinstance(x, list):
        return x
    return [x]


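# Editor's usage sketch (not part of the original file): wrap_list normalizes
# a scalar into a single-element list and passes lists through unchanged.
#
# >>> wrap_list(3)
# [3]
# >>> wrap_list([3, 4])
# [3, 4]
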
def calculate_steps(epochs: int, datapoints: int, devices: int, batch_size: int) -> int:
    return math.ceil(datapoints / batch_size / devices) * epochs


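# Editor's worked example (not part of the original file): steps per epoch are
# ceil(datapoints / batch_size / devices), e.g. ceil(50000 / 128 / 4) = 98,
# so 10 epochs yield 980 optimizer steps.
#
# >>> calculate_steps(epochs=10, datapoints=50_000, devices=4, batch_size=128)
# 980
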
def some(*args, default):
    """
    Returns the first argument that is not None, or `default` if all arguments are None.
    """
    if len(args) < 1:
        return default
    first, *rest = args
    if first is not None:
        return first
    return some(*rest, default=default)


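# Editor's usage sketch (not part of the original file): `some` scans its
# arguments left to right and falls back to `default` when all are None.
#
# >>> some(None, None, 5, default=0)
# 5
# >>> some(None, default=0)
# 0
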
def maybe_abspath(path: Optional[str | Path]) -> Optional[Path]:
    if path is None:
        return None
    return Path(path).resolve()


def findfirst(f: Callable, xs: Iterable):
    for x in xs:
        if f(x):
            return x
    return None


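# Editor's usage sketch (not part of the original file): findfirst returns the
# first element matching the predicate, or None when nothing matches.
#
# >>> findfirst(lambda x: x % 2 == 0, [1, 3, 4, 6])
# 4
# >>> findfirst(lambda x: x > 10, [1, 2, 3]) is None
# True
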
def trainer_strategy(devices: int | list[int] | str) -> str:
    if isinstance(devices, str):
        return "auto"
    ndevices = devices if isinstance(devices, int) else len(devices)
    return "ddp" if ndevices > 1 else "auto"


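# Editor's usage sketch (not part of the original file): a single device (or a
# string device spec) maps to "auto", multiple devices map to "ddp".
#
# >>> trainer_strategy(1)
# 'auto'
# >>> trainer_strategy([0, 1])
# 'ddp'
# >>> trainer_strategy("auto")
# 'auto'
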
def gpu_suited_for_compile():
    if torch.cuda.is_available():
        # compute capabilities 7.0, 8.0 and 9.0 correspond to V100, A100 and H100 GPUs
        device_cap = torch.cuda.get_device_capability()
        return device_cap in ((7, 0), (8, 0), (9, 0))
    return False


def precision_with_fallback(precision: str) -> str:
    """
    Checks whether CUDA supports bf16; if CUDA is unavailable or does not
    support bf16, returns the 16-bit equivalent (e.g. "bf16" -> "16").
    """
    if not torch.cuda.is_available():
        log_warn("Warning: No CUDA available. Results can be different!")
        # only strip the "bf" prefix if it is actually present
        return precision[2:] if precision.startswith("bf") else precision
    if precision.startswith("bf") and not torch.cuda.is_bf16_supported():
        log_warn("Warning: GPU does not support bfloat16. Results can be different!")
        return precision[2:]
    return precision


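# Editor's usage sketch (not part of the original file); the result is
# hardware-dependent. On a machine whose GPU lacks bf16 support (or has no
# CUDA at all), "bf16-mixed" degrades to "16-mixed":
#
# >>> precision_with_fallback("bf16-mixed")  # e.g. on a pre-Ampere GPU
# '16-mixed'
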
def str_to_seconds(s: str) -> int:
    parts = s.split(":")
    assert len(parts) == 3, f"Invalid time format: {s}. Use 'HH:MM:SS'."
    return int(parts[0]) * 3600 + int(parts[1]) * 60 + int(parts[2])


def seconds_to_str(total_seconds: int, sep: str = ":") -> str:
    hours, rest = divmod(total_seconds, 3600)
    minutes, seconds = divmod(rest, 60)
    return sep.join(map(lambda x: str(x).zfill(2), [hours, minutes, seconds]))


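# Editor's usage sketch (not part of the original file): the two helpers are
# inverses for well-formed 'HH:MM:SS' strings.
#
# >>> str_to_seconds("01:30:05")
# 5405
# >>> seconds_to_str(5405)
# '01:30:05'
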
def begin_timeout(delay=10, show_threads=False):
    if show_threads:
        import sys
        import traceback
        import threading
        thread_names = {t.ident: t.name for t in threading.enumerate()}
        for thread_id, frame in sys._current_frames().items():
            print(f"Thread {thread_names.get(thread_id, thread_id)}:")
            traceback.print_stack(frame)
            print()
    signal.alarm(delay)  # timeout after `delay` seconds


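# Editor's note (not part of the original file): signal.alarm is Unix-only and
# delivers SIGALRM after `delay` seconds; since no handler is installed here,
# the signal's default action terminates the process. Intended as a
# last-resort hang breaker, e.g.:
#
# >>> begin_timeout(delay=30)  # kills the process if it is still alive in 30s
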
def path_to_str_inside_dict(d: dict) -> dict:
    return convert_type_inside_dict(d, Path, str)


def convert_type_inside_dict(d: dict, src: Type, tgt: Type) -> dict:
    ret = {}
    for k, v in d.items():
        if isinstance(v, dict):
            v = convert_type_inside_dict(v, src, tgt)
        if isinstance(v, src):
            ret[k] = tgt(v)
        else:
            ret[k] = v
    return ret


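# Editor's usage sketch (not part of the original file, output shown for a
# POSIX system): recursively converts every value of type `src` to `tgt`,
# as used by path_to_str_inside_dict above.
#
# >>> convert_type_inside_dict({"out": Path("/tmp"), "sub": {"log": Path("/var")}}, src=Path, tgt=str)
# {'out': '/tmp', 'sub': {'log': '/var'}}
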
def dict_differences(custom: dict[str, Any], default: dict[str, Any]) -> dict[str, Any]:
    """
    Recursively returns a dictionary with the items in `custom` that are different or missing from `default`.

    Example:
    >>> dict_differences({"hi": 3, "bla": {"a": 2, "b": 2}}, {"hi": 2, "bla": {"a": 1, "b": 2}})
    {'hi': 3, 'bla': {'a': 2}}
    """
    diff: dict[str, Any] = {}
    for key, value in custom.items():
        if key in default:
            default_value = default[key]
            if default_value == value:
                continue
            if isinstance(value, dict) and isinstance(default_value, dict):
                diff[key] = dict_differences(value, default_value)
            else:
                diff[key] = value
        else:
            diff[key] = value
    return diff


def concatenate_dict_keys(
        d: dict[str, Any],
        parent_key: str = "",
        sep: str = ".",
        exclude_keys: Iterable[str] = tuple()
) -> dict[str, Any]:
    """
    Example:
    >>> concatenate_dict_keys({ "A": { "B": { "C": 1, "D": 2 }, "E": { "F": 3 } } })
    {'A.B.C': 1, 'A.B.D': 2, 'A.E.F': 3}
    >>> concatenate_dict_keys({ "A": { "B": { "C": 1, "D": 2 }, "E": { "F": 3 } } }, exclude_keys=["B"])
    {'A.E.F': 3}
    """
    result = {}
    for k, v in d.items():
        if k in exclude_keys:
            continue
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            nested_result = concatenate_dict_keys(v, new_key, sep, exclude_keys)
            result.update(nested_result)
        else:
            result[new_key] = v
    return result


def sort_dict_recursively(d: dict) -> dict:
    sorted_dict = {}
    for k, v in sorted(d.items()):
        if isinstance(v, dict):
            sorted_dict[k] = sort_dict_recursively(v)
        else:
            sorted_dict[k] = v
    return sorted_dict


class EndlessList(list):
    """
    Returns the first element when indexed out of bounds; otherwise behaves like a regular list.
    """
    def __getitem__(self, index):
        if index >= len(self) and len(self) > 0:
            return self[0]
        return super().__getitem__(index)


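# Editor's usage sketch (not part of the original file):
#
# >>> schedule = EndlessList(["warmup", "decay"])
# >>> schedule[1]
# 'decay'
# >>> schedule[5]  # out-of-bounds reads wrap to the first element
# 'warmup'
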
class AttributeDict(dict):
    """
    Dictionary whose keys can also be read as attributes.
    """
    def __getattribute__(self, key: str) -> Any:
        try:
            return super().__getattribute__(key)
        except AttributeError:
            pass
        return super().__getitem__(key)
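# Editor's usage sketch (not part of the original file): keys double as
# attributes, while normal dict access still works.
#
# >>> cfg = AttributeDict({"lr": 0.1, "epochs": 5})
# >>> cfg.lr
# 0.1
# >>> cfg["epochs"]
# 5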