atropos/environments/optimizer/FOB/pytorch_fob/engine/run_schedulers.py
2025-05-18 16:36:28 -07:00

171 lines
7 KiB
Python

from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Iterable, Optional, Sequence
import traceback
import yaml
from pytorch_fob.engine.run import Run
from pytorch_fob.engine.slurm import Slurm
from pytorch_fob.engine.utils import log_info, log_warn, seconds_to_str, some, str_to_seconds
FOB_RUN_SCRIPT = "pytorch_fob.run_experiment"
FOB_EVAL_SCRIPT = "pytorch_fob.evaluate_experiment"
def argcheck_allequal_engine(
    runs: list[Run],
    keys: list[str],
    reason: str = "'engine.run_scheduler=slurm_array'"
) -> None:
    """Ensure the listed ``engine`` settings are identical across all runs.

    Args:
        runs: runs to compare; the first run is the reference.
        keys: engine config keys that must agree.
        reason: human-readable context inserted into the error message.

    Raises:
        ValueError: if any run differs from the first on any of the keys.
    """
    reference = runs[0]
    mismatch = any(
        any(run.engine[key] != reference.engine[key] for run in runs[1:])
        for key in keys
    )
    if mismatch:
        req = ", ".join("engine." + key for key in keys)
        raise ValueError(f"All runs must have the same values for {req} when using {reason}")
def export_experiment(run: Run, experiment: dict[str, Any]) -> Path:
    """Write the experiment config as YAML into the run's directory.

    Creates the run directory if necessary and returns the path to the
    written ``experiment.yaml`` file.
    """
    target = run.run_dir / "experiment.yaml"
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, "w", encoding="utf8") as fh:
        yaml.safe_dump(experiment, fh)
    return target
def process_args(args: dict[str, str], run: Run) -> None:
    """Fill in derived sbatch arguments in-place.

    Scales an existing 'time' limit by ``engine.sbatch_time_factor`` and
    derives GPU, task and CPU settings from the engine config unless the
    caller already provided them.
    """
    if "time" in args:
        raw_time = args["time"]
        # accept both "HH:MM:SS"-style strings and plain seconds
        total = str_to_seconds(raw_time) if isinstance(raw_time, str) else raw_time
        args["time"] = seconds_to_str(int(run.engine.sbatch_time_factor * total))
    if "gres" not in args and "gpus" not in args:
        args["gres"] = f"gpu:{run.engine.devices}"
    if not any(key.startswith("ntasks") for key in args):
        args["ntasks-per-node"] = str(run.engine.devices)
    if not any(key.startswith("cpus") for key in args):
        args["cpus-per-task"] = str(run.engine.workers)
def wrap_template(template_path: Optional[Path], command: str, placeholder: str = "__FOB_COMMAND__") -> str:
    """Embed *command* into the script template at *template_path*.

    If the template contains *placeholder*, the command is substituted for
    it; otherwise the command is appended after the template. With no
    template the command is returned unchanged.
    """
    if template_path is None:
        return command
    with open(template_path, "r", encoding="utf8") as fh:
        template = fh.read()
    if placeholder in template:
        return template.replace(placeholder, command)
    return f"{template}\n{command}\n"
def get_command(experiment_file: Path, index: Optional[str], plot: bool) -> str:
    """Build the srun command executed inside the sbatch script.

    *index* selects a single run via ``engine.run_scheduler=single:<index>``
    (pass None to omit); *plot* switches between the evaluation script and
    the run script (the latter with plotting disabled).
    """
    module = FOB_EVAL_SCRIPT if plot else FOB_RUN_SCRIPT
    plot_flag = "" if plot else "engine.plot=false"
    scheduler_arg = "" if index is None else f"engine.run_scheduler=single:{index}"
    return f"srun python -m {module} {experiment_file} {scheduler_arg} {plot_flag}"
def get_job_name(run: Run) -> str:
    """Slurm job name for *run*: ``FOB-<task>-<optimizer>``."""
    task_name = run.task.name
    optimizer_name = run.optimizer.name
    return f"FOB-{task_name}-{optimizer_name}"
def get_slurm(job_name: str, args: dict[str, str], log_dir: Path, scripts_dir: Path) -> Slurm:
    """Create a ``Slurm`` submission helper with resolved log/script paths."""
    resolved_logs = str(log_dir.resolve())
    resolved_scripts = str(scripts_dir.resolve())
    # bash_strict disabled so generated scripts tolerate unset variables
    # TODO: maybe add arg or just remove 'nounset'
    return Slurm(job_name, args, log_dir=resolved_logs, scripts_dir=resolved_scripts, bash_strict=False)
def run_slurm(
    job_name: str,
    command: str,
    args: dict[str, str],
    log_dir: Path,
    save_sbatch_scripts: Optional[Path] = None,
    dependencies: Sequence[int] = tuple(),
    dependency_type: str = "afterok"
) -> Optional[int]:
    """Submit *command* as a slurm batch job and return its job id.

    Args:
        job_name: slurm job name.
        command: shell command placed in the sbatch script.
        args: sbatch arguments (``--key=value`` pairs).
        log_dir: directory for slurm log files.
        save_sbatch_scripts: if given, the generated sbatch script is kept
            in this directory; otherwise it is written to a temporary
            directory that is deleted after submission.
        dependencies: job ids this job depends on.
        dependency_type: slurm dependency type (e.g. "afterok", "afterany").

    Returns:
        The submitted job id, or None if submission did not yield one.
    """
    def _submit(scripts_dir: Path) -> Optional[int]:
        # Build the sbatch script in scripts_dir and hand it to sbatch.
        slurm = get_slurm(job_name, args, log_dir, scripts_dir=scripts_dir)
        return slurm.run(command, name_addition="", depends_on=dependencies, dependency_type=dependency_type)

    if save_sbatch_scripts is not None:
        return _submit(save_sbatch_scripts)
    # scripts are throwaway: keep the temp dir alive until sbatch has run
    with TemporaryDirectory() as tmpdir:
        return _submit(Path(tmpdir).resolve())
def run_plotting_job(
    experiment_file: Path,
    args: dict[str, str],
    log_dir: Path,
    dependencies: Sequence[int],
    template: Optional[Path] = None
) -> None:
    """Submit a lightweight CPU-only job that plots the experiment results.

    Mutates *args* in-place: strips array/GPU/task settings and replaces
    them with a minimal single-node allocation. The job runs with an
    'afterany' dependency so plotting happens even if some runs failed.
    """
    args["time"] = seconds_to_str(300)  # 5 minutes should be plenty of time to plot
    # no job array and no gpus needed for plotting
    for key in ("array", "gpus", "gres"):
        args.pop(key, None)
    # drop all task/cpu settings, then request a minimal allocation
    stale = [key for key in args if key.startswith(("ntasks", "cpus"))]
    for key in stale:
        del args[key]
    args["nodes"] = "1"
    args["ntasks-per-node"] = "1"
    args["cpus-per-task"] = "2"
    plot_command = wrap_template(template, get_command(experiment_file, None, plot=True))
    run_slurm("FOB-plot", plot_command, args, log_dir, dependencies=dependencies, dependency_type="afterany")
def slurm_array(runs: list[Run], experiment: dict[str, Any]) -> None:
    """Submit all runs as a single slurm job array.

    Requires the scheduling-relevant engine settings to be identical across
    runs; each array task then executes one run selected via
    ``$SLURM_ARRAY_TASK_ID``. Optionally queues a plotting job afterwards.
    """
    equal_req = ["devices", "workers", "sbatch_args", "slurm_log_dir", "sbatch_script_template", "run_scheduler"]
    argcheck_allequal_engine(runs, equal_req)
    run = runs[0]  # all runs have the same args
    args = run.engine.sbatch_args
    log_dir = some(run.engine.slurm_log_dir, default=run.engine.output_dir / "slurm_logs")
    if "array" not in args:
        # one array task per run; slurm array indices here are 1-based
        args["array"] = f"1-{len(runs)}"
    process_args(args, run)
    # every run exports its own experiment.yaml, but the array job only needs
    # one path: each task picks its run via $SLURM_ARRAY_TASK_ID
    experiment_file = export_experiment(runs[0], experiment).resolve()
    for other in runs[1:]:
        export_experiment(other, experiment)
    command = get_command(experiment_file, "$SLURM_ARRAY_TASK_ID", plot=False)
    command = wrap_template(run.engine.sbatch_script_template, command)
    job_id = run_slurm(get_job_name(run), command, args, log_dir, save_sbatch_scripts=run.engine.save_sbatch_scripts)
    if job_id is not None and run.engine.plot:
        # plot once after the whole array finished (even partially failed)
        run_plotting_job(experiment_file, args, log_dir, [job_id], template=run.engine.sbatch_script_template)
def slurm_jobs(runs: list[Run], experiment: dict[str, Any]) -> list[int]:
    """Submit every run as its own slurm job and return the job ids.

    Each run gets its own sbatch submission (index i selects the run via
    'engine.run_scheduler=single:i'). If any run requests plotting, a single
    plotting job depending on all submitted jobs is queued afterwards.

    Returns:
        The job ids of all successfully submitted run jobs.
    """
    job_ids = []
    experiment_file = Path()
    for i, run in enumerate(runs, start=1):
        args = run.engine.sbatch_args
        process_args(args, run)
        log_dir = some(run.engine.slurm_log_dir, default=run.run_dir / "slurm_logs")
        experiment_file = export_experiment(run, experiment).resolve()
        command = get_command(experiment_file, str(i), plot=False)
        command = wrap_template(run.engine.sbatch_script_template, command)
        job_id = run_slurm(get_job_name(run), command, args, log_dir, save_sbatch_scripts=run.engine.save_sbatch_scripts)
        if job_id is not None:
            job_ids.append(job_id)
    # NOTE: args, log_dir and experiment_file deliberately leak out of the
    # loop above — the plotting job reuses the values of the LAST run. The
    # argcheck below ensures log_dir/template agree across runs, so this is
    # safe for those; args is simply the last run's (mutated) sbatch_args.
    if len(job_ids) > 0 and any(map(lambda r: r.engine.plot, runs)):
        equal_req = ["slurm_log_dir", "sbatch_script_template"]
        argcheck_allequal_engine(runs, equal_req, reason="'engine.plot=true' with 'engine.run_scheduler=slurm_jobs'")
        run_plotting_job(experiment_file, args, log_dir, job_ids, template=runs[0].engine.sbatch_script_template)
    return job_ids
def sequential(runs: Iterable[Run], n_runs: int, experiment: dict[str, Any]):
    """Execute all runs one after another in the current process.

    Exports each run's experiment.yaml before starting it. A failing run is
    logged (with its traceback) and does not abort the remaining runs.
    """
    for idx, current_run in enumerate(runs, start=1):
        log_info(f"Starting run {idx}/{n_runs}.")
        export_experiment(current_run, experiment)
        try:
            current_run.start()
        except RuntimeError as _e:  # detect_anomaly raises RuntimeError
            tb = traceback.format_exc()
            log_warn(f"Run {idx}/{n_runs} failed with {tb}.")