mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
171 lines
7 KiB
Python
171 lines
7 KiB
Python
from pathlib import Path
|
|
from tempfile import TemporaryDirectory
|
|
from typing import Any, Iterable, Optional, Sequence
|
|
import traceback
|
|
import yaml
|
|
from pytorch_fob.engine.run import Run
|
|
from pytorch_fob.engine.slurm import Slurm
|
|
from pytorch_fob.engine.utils import log_info, log_warn, seconds_to_str, some, str_to_seconds
|
|
|
|
|
|
# Module that get_command() launches via ``python -m`` when plot is False.
FOB_RUN_SCRIPT = "pytorch_fob.run_experiment"
|
|
# Module that get_command() launches via ``python -m`` when plot is True.
FOB_EVAL_SCRIPT = "pytorch_fob.evaluate_experiment"
|
|
|
|
|
|
def argcheck_allequal_engine(
    runs: list[Run],
    keys: list[str],
    reason: str = "'engine.run_scheduler=slurm_array'"
) -> None:
    """Ensure that all runs share the same engine values for the given keys.

    Every run is compared against the first run via ``run.engine[key]``.

    Args:
        runs: Runs whose engine settings are compared.
        keys: Engine keys that must be identical across all runs.
        reason: Description of the setting that imposes the requirement;
            only used to build the error message.

    Raises:
        ValueError: If any run differs from the first run for one of ``keys``.
    """
    # Fewer than two runs cannot disagree with each other. Returning early also
    # avoids the IndexError that `runs[0]` would raise for an empty list.
    if len(runs) < 2:
        return
    first, others = runs[0], runs[1:]
    # Same evaluation order as a nested loop (key outer, run inner); stops at
    # the first mismatch.
    consistent = all(
        run.engine[key] == first.engine[key]
        for key in keys
        for run in others
    )
    if not consistent:
        req = ", ".join("engine." + key for key in keys)
        raise ValueError(f"All runs must have the same values for {req} when using {reason}")
|
|
|
|
|
|
def export_experiment(run: Run, experiment: dict[str, Any]) -> Path:
    """Dump the experiment configuration as YAML into the run's directory.

    The run directory is created first (including missing parents) if needed.

    Args:
        run: Run whose ``run_dir`` receives the file.
        experiment: Experiment configuration, serialized with ``yaml.safe_dump``.

    Returns:
        Path of the written ``experiment.yaml`` file.
    """
    target_dir = run.run_dir
    target_dir.mkdir(parents=True, exist_ok=True)
    outfile = target_dir / "experiment.yaml"
    with outfile.open("w", encoding="utf8") as stream:
        yaml.safe_dump(experiment, stream)
    return outfile
|
|
|
|
|
|
def process_args(args: dict[str, str], run: Run) -> None:
    """Normalize and complete a dict of sbatch arguments in place.

    - ``time``: if present, it is converted to seconds (strings go through
      ``str_to_seconds``; other values are used as-is), scaled by
      ``run.engine.sbatch_time_factor`` and written back via ``seconds_to_str``.
    - ``gres``: set to ``gpu:<devices>`` unless ``gres`` or ``gpus`` is given.
    - ``ntasks-per-node``: set to the device count unless some key already
      starts with ``ntasks``.
    - ``cpus-per-task``: set to the worker count unless some key already
      starts with ``cpus``.

    Args:
        args: sbatch arguments; modified in place.
        run: Run providing ``engine.devices``, ``engine.workers`` and
            ``engine.sbatch_time_factor``.
    """
    def _has_key_prefixed(prefix: str) -> bool:
        # Looks at the current state of `args`, including keys added below.
        return any(name.startswith(prefix) for name in args)

    if "time" in args:
        requested = args["time"]
        if isinstance(requested, str):
            requested = str_to_seconds(requested)
        scaled = int(run.engine.sbatch_time_factor * requested)
        args["time"] = seconds_to_str(scaled)

    if not ("gres" in args or "gpus" in args):
        args["gres"] = f"gpu:{run.engine.devices}"

    if not _has_key_prefixed("ntasks"):
        args["ntasks-per-node"] = str(run.engine.devices)

    if not _has_key_prefixed("cpus"):
        args["cpus-per-task"] = str(run.engine.workers)
|
|
|
|
|
|
def wrap_template(template_path: Optional[Path], command: str, placeholder: str = "__FOB_COMMAND__") -> str:
    """Embed a command into an optional script template.

    Args:
        template_path: Template file to read (UTF-8), or ``None`` for no template.
        command: Command text to embed.
        placeholder: Marker in the template that is replaced by ``command``.

    Returns:
        ``command`` unchanged when no template is given. Otherwise the template
        with every occurrence of ``placeholder`` replaced by ``command``, or,
        if the template does not contain the placeholder, the template followed
        by the command on its own line.
    """
    if template_path is None:
        return command

    with open(template_path, "r", encoding="utf8") as template_file:
        template_text = template_file.read()

    if placeholder not in template_text:
        # No marker to substitute: append the command after the template.
        return f"{template_text}\n{command}\n"
    return template_text.replace(placeholder, command)
|
|
|
|
|
|
def get_command(experiment_file: Path, index: Optional[str], plot: bool) -> str:
    """Build the ``srun`` command line that runs or evaluates an experiment.

    Args:
        experiment_file: Path of the experiment YAML passed to the script.
        index: If not ``None``, appended as ``engine.run_scheduler=single:<index>``.
            It is deliberately left unquoted so that shell variables such as
            ``$SLURM_ARRAY_TASK_ID`` are still expanded by the job script.
        plot: If true, the evaluation module is launched; otherwise the run
            module is launched with ``engine.plot=false``.

    Returns:
        The shell command as a single string.
    """
    from shlex import quote  # local import keeps this block self-contained

    run_script = FOB_EVAL_SCRIPT if plot else FOB_RUN_SCRIPT
    disable_plot = "" if plot else "engine.plot=false"
    scheduler = "" if index is None else f"engine.run_scheduler=single:{index}"
    # Quote the path so directories containing spaces or shell metacharacters
    # do not split or alter the command; safe paths are emitted unchanged.
    experiment_arg = quote(str(experiment_file))
    return f"""srun python -m {run_script} {experiment_arg} {scheduler} {disable_plot}"""
|
|
|
|
|
|
def get_job_name(run: Run) -> str:
    """Return the job name ``FOB-<task name>-<optimizer name>`` for a run."""
    task_name = run.task.name
    optimizer_name = run.optimizer.name
    return "FOB-{}-{}".format(task_name, optimizer_name)
|
|
|
|
|
|
def get_slurm(job_name: str, args: dict[str, str], log_dir: Path, scripts_dir: Path) -> Slurm:
    """Create a ``Slurm`` object for the given job settings.

    Args:
        job_name: Name passed as first positional argument to ``Slurm``.
        args: Arguments passed as second positional argument to ``Slurm``.
        log_dir: Log directory; resolved and passed as a string.
        scripts_dir: Scripts directory; resolved and passed as a string.

    Returns:
        The constructed ``Slurm`` instance (created with ``bash_strict=False``).
    """
    resolved_log_dir = str(log_dir.resolve())
    resolved_scripts_dir = str(scripts_dir.resolve())
    return Slurm(
        job_name,
        args,
        log_dir=resolved_log_dir,
        scripts_dir=resolved_scripts_dir,
        bash_strict=False,  # TODO: maybe add arg or just remove 'nounset'
    )
|
|
|
|
|
|
def run_slurm(
    job_name: str,
    command: str,
    args: dict[str, str],
    log_dir: Path,
    save_sbatch_scripts: Optional[Path] = None,
    dependencies: Sequence[int] = tuple(),
    dependency_type: str = "afterok"
) -> Optional[int]:
    """Run a command through a ``Slurm`` object built by ``get_slurm``.

    Args:
        job_name: Name of the job.
        command: Command handed to ``Slurm.run``.
        args: Arguments handed to ``get_slurm``.
        log_dir: Log directory handed to ``get_slurm``.
        save_sbatch_scripts: Directory used as scripts directory. When ``None``,
            a temporary directory is used that only exists for the duration of
            the ``Slurm.run`` call.
        dependencies: Passed to ``Slurm.run`` as ``depends_on``.
        dependency_type: Passed to ``Slurm.run`` as ``dependency_type``.

    Returns:
        Whatever ``Slurm.run`` returns (callers in this file treat a
        non-``None`` value as the job id).
    """
    def _submit(scripts_dir: Path) -> Optional[int]:
        slurm = get_slurm(job_name, args, log_dir, scripts_dir=scripts_dir)
        return slurm.run(command, name_addition="", depends_on=dependencies, dependency_type=dependency_type)

    if save_sbatch_scripts is not None:
        return _submit(save_sbatch_scripts)

    # The submission has to happen while the temporary directory still exists.
    with TemporaryDirectory() as tmpdir:
        return _submit(Path(tmpdir).resolve())
|
|
|
|
|
|
def run_plotting_job(
    experiment_file: Path,
    args: dict[str, str],
    log_dir: Path,
    dependencies: Sequence[int],
    template: Optional[Path] = None
) -> None:
    """Submit the ``FOB-plot`` job that evaluates an experiment.

    The job's arguments are derived from ``args``: the time limit is set to
    5 minutes, ``array``/``gpus``/``gres`` and every ``ntasks*``/``cpus*`` key
    are removed, and a layout of one node, one task per node and two CPUs per
    task is requested. The job is submitted with dependency type ``afterany``.

    Args:
        experiment_file: Experiment YAML passed to the evaluation command.
        args: sbatch arguments to start from. They are copied, not modified.
        log_dir: Directory for the job's logs.
        dependencies: Job ids the plotting job depends on.
        template: Optional script template the command is wrapped in.
    """
    # Work on a copy: callers pass their run's own sbatch_args dict, which
    # would otherwise be rewritten with the plotting job's settings.
    plot_args = dict(args)
    plot_args["time"] = seconds_to_str(300)  # 5 minutes should be plenty of time to plot
    # no job array and no gpus needed for plotting
    for key in ("array", "gpus", "gres"):
        plot_args.pop(key, None)
    # just one cpu per node for plotting: drop any inherited task/cpu layout
    for key in [k for k in plot_args if k.startswith(("ntasks", "cpus"))]:
        plot_args.pop(key)
    plot_args["nodes"] = "1"
    plot_args["ntasks-per-node"] = "1"
    plot_args["cpus-per-task"] = "2"
    command = get_command(experiment_file, None, plot=True)
    command = wrap_template(template, command)
    run_slurm("FOB-plot", command, plot_args, log_dir, dependencies=dependencies, dependency_type="afterany")
|
|
|
|
|
|
def slurm_array(runs: list[Run], experiment: dict[str, Any]) -> None:
    """Submit all runs as one job with an ``array`` sbatch argument.

    All runs must agree on the engine settings listed below, so the first run
    is used as the source of the shared settings. The experiment file is
    exported for every run; the first run's file is the one passed to the
    command, together with ``$SLURM_ARRAY_TASK_ID`` as the run index. If the
    job was submitted and the first run has ``engine.plot`` set, a plotting job
    depending on it is submitted as well.

    Args:
        runs: Runs to submit.
        experiment: Experiment configuration exported for each run.

    Raises:
        ValueError: If the runs differ in one of the required engine settings.
    """
    argcheck_allequal_engine(
        runs,
        ["devices", "workers", "sbatch_args", "slurm_log_dir", "sbatch_script_template", "run_scheduler"],
    )
    first = runs[0]  # all runs have the same args
    sbatch_args = first.engine.sbatch_args
    log_dir = some(first.engine.slurm_log_dir, default=first.engine.output_dir / "slurm_logs")
    if "array" not in sbatch_args:
        sbatch_args["array"] = f"1-{len(runs)}"
    process_args(sbatch_args, first)

    # Export for every run, then keep only the first exported file.
    exported_files = [export_experiment(each, experiment).resolve() for each in runs]
    experiment_file = exported_files[0]

    array_command = get_command(experiment_file, "$SLURM_ARRAY_TASK_ID", plot=False)
    array_command = wrap_template(first.engine.sbatch_script_template, array_command)
    job_id = run_slurm(
        get_job_name(first),
        array_command,
        sbatch_args,
        log_dir,
        save_sbatch_scripts=first.engine.save_sbatch_scripts,
    )
    if job_id is None or not first.engine.plot:
        return
    run_plotting_job(experiment_file, sbatch_args, log_dir, [job_id], template=first.engine.sbatch_script_template)
|
|
|
|
|
|
def slurm_jobs(runs: list[Run], experiment: dict[str, Any]) -> list[int]:
    """Submit every run as its own job and optionally a final plotting job.

    For each run (numbered from 1) the sbatch arguments are processed, the
    experiment file is exported and a job is submitted. If at least one job was
    submitted and any run has ``engine.plot`` set, a plotting job depending on
    all submitted jobs is started. The plotting job reuses the sbatch
    arguments, log directory and experiment file of the last run in ``runs``.

    Args:
        runs: Runs to submit.
        experiment: Experiment configuration exported for each run.

    Returns:
        Ids of the submitted jobs (submissions that returned ``None`` are skipped).

    Raises:
        ValueError: If plotting is requested but the runs differ in
            ``slurm_log_dir`` or ``sbatch_script_template``.
    """
    submitted: list[int] = []
    experiment_file = Path()
    for position, run in enumerate(runs, start=1):
        args = run.engine.sbatch_args
        process_args(args, run)
        log_dir = some(run.engine.slurm_log_dir, default=run.run_dir / "slurm_logs")
        experiment_file = export_experiment(run, experiment).resolve()
        job_command = get_command(experiment_file, str(position), plot=False)
        job_command = wrap_template(run.engine.sbatch_script_template, job_command)
        job_id = run_slurm(
            get_job_name(run),
            job_command,
            args,
            log_dir,
            save_sbatch_scripts=run.engine.save_sbatch_scripts,
        )
        if job_id is not None:
            submitted.append(job_id)

    if not submitted or not any(r.engine.plot for r in runs):
        return submitted

    argcheck_allequal_engine(
        runs,
        ["slurm_log_dir", "sbatch_script_template"],
        reason="'engine.plot=true' with 'engine.run_scheduler=slurm_jobs'",
    )
    # `args` and `log_dir` still hold the values of the last loop iteration.
    run_plotting_job(experiment_file, args, log_dir, submitted, template=runs[0].engine.sbatch_script_template)
    return submitted
|
|
|
|
|
|
def sequential(runs: Iterable[Run], n_runs: int, experiment: dict[str, Any]):
    """Start the runs one after another in the current process.

    Before each run starts, its experiment file is exported. A ``RuntimeError``
    raised by ``run.start()`` is logged as a warning together with its
    traceback, and the loop continues with the next run.

    Args:
        runs: Runs to execute, in order.
        n_runs: Total number of runs; only used in log messages.
        experiment: Experiment configuration exported for each run.
    """
    for position, run in enumerate(runs, start=1):
        progress = f"{position}/{n_runs}"
        log_info(f"Starting run {progress}.")
        export_experiment(run, experiment)
        try:
            run.start()
        except RuntimeError:  # detect_anomaly raises RuntimeError
            trace = traceback.format_exc()
            log_warn(f"Run {progress} failed with {trace}.")
|