#!/usr/bin/env python3 """ Local dataset training launcher. Usage: python -m environments.dataset_environment.launch_local_dataset_run This script does: 1) Starts the Trajectory Handler API server via uvicorn 2) Launches the DatasetEnv in local serve mode 3) Imports and runs the example trainer (GRPO) directly Requirements: - Run from project root so example_trainer is on PYTHONPATH - example_trainer/ is a valid Python package (with __init__.py) """ import atexit import os import signal import subprocess import sys import time import traceback # Ensure project root is on PYTHONPATH project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) if project_root not in sys.path: sys.path.insert(0, project_root) # Import trainer via standard module import try: from example_trainer.grpo import TrainingConfig, train except ImportError as e: print(f"Error importing example_trainer.grpo: {e}") print( "Ensure you're running from project root and that example_trainer/ is a package." ) sys.exit(1) # ----------------------------------------------------------------------------- # Configuration # ----------------------------------------------------------------------------- API_HOST = "127.0.0.1" API_PORT = 8000 VLLM_HOST = "127.0.0.1" VLLM_PORT = 9001 MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" TOKENIZER_NAME = MODEL_NAME TRAINER_CONFIG = { "model_name": MODEL_NAME, "training_steps": 20, "batch_size": 2, "gradient_accumulation_steps": 2, "seq_len": 512, "vllm_port": VLLM_PORT, "vllm_restart_interval": 10, "use_wandb": False, "wandb_project": "", "wandb_group": "", "save_path": "./trained_model_checkpoints_local_test", } # Flags for launching DatasetEnv serve DATASET_FLAGS = [ "--group_size", "4", "--max_num_workers", "2", "--rollout_server_url", f"http://{API_HOST}:{API_PORT}", "--tokenizer_name", TOKENIZER_NAME, "--use_wandb", "--wandb_name", "dataset_env_local_test", "--max_token_length", str(TRAINER_CONFIG["seq_len"]), "--ensure_scores_are_not_same", "--dataset_name", "HuggingFaceH4/testing_self_instruct_process_essays", "--split", "train[:100]", "--prompt_field", "prompt", "--answer_field", "answer", "--reward_functions", "length", "--max_tokens", "128", "--temperature", "0.7", "--model_name", MODEL_NAME, "--base_url", f"http://{VLLM_HOST}:{VLLM_PORT}", "--slurm", "--testing", ] # Track background processes for cleanup processes = [] def cleanup_processes(): print("\nCleaning up background processes...") for p in reversed(processes): if p.poll() is None: print(f"Terminating PID {p.pid}...") p.terminate() try: p.wait(timeout=5) print(f"PID {p.pid} terminated.") except subprocess.TimeoutExpired: print(f"PID {p.pid} did not terminate; killing.") p.kill() p.wait() print(f"PID {p.pid} killed.") else: print(f"PID {p.pid} already exited.") print("Cleanup complete.") atexit.register(cleanup_processes) def handle_signal(sig, frame): print(f"\nSignal {sig} received; exiting.") sys.exit(0) signal.signal(signal.SIGINT, handle_signal) signal.signal(signal.SIGTERM, handle_signal) def main(): # 1) Start the API server print("--- Starting Trajectory Handler API Server ---") api_cmd = [ "uvicorn", "atroposlib.api.server:app", "--host", API_HOST, "--port", str(API_PORT), ] print(f"$ {' '.join(api_cmd)}") api_proc = subprocess.Popen(api_cmd) processes.append(api_proc) time.sleep(3) # 2) Start the dataset environment print("\n--- Starting Dataset Environment ---") env_cmd = [ "python", "-m", "environments.dataset_environment.dataset_env", "serve", ] + DATASET_FLAGS print(f"$ {' '.join(env_cmd)}") 
env_proc = subprocess.Popen(env_cmd) processes.append(env_proc) time.sleep(3) # 3) Run the example trainer print("\n--- Running Example Trainer (GRPO) ---") config = TrainingConfig(**TRAINER_CONFIG) try: train(config) except Exception: print("Error during training:") traceback.print_exc() print("--- Training complete ---") if __name__ == "__main__": main()