atropos/environments/dataset_environment/launch_local_dataset_run.py
2025-04-29 12:10:10 -07:00

154 lines
No EOL
4.3 KiB
Python

#!/usr/bin/env python3
"""
Local dataset training launcher.

Usage:
    python -m environments.dataset_environment.launch_local_dataset_run

This script does the following:
    1) Starts the Trajectory Handler API server via uvicorn.
    2) Launches the DatasetEnv in local serve mode.
    3) Imports and runs the example trainer (GRPO) directly in this process.

Requirements:
    - Run from the project root so example_trainer is on PYTHONPATH.
    - example_trainer/ is a valid Python package (with __init__.py).
"""
import os
import sys
import subprocess
import time
import atexit
import signal
import traceback
# Make the project root (two directories above this file) importable so that
# top-level packages such as example_trainer can be resolved.
_launcher_dir = os.path.dirname(__file__)
project_root = os.path.abspath(
    os.path.join(_launcher_dir, os.pardir, os.pardir)
)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# The trainer is pulled in with a regular import; if that fails, print a hint
# about the expected layout and stop with a non-zero exit status.
try:
    from example_trainer.grpo import TrainingConfig, train
except ImportError as import_error:
    print(f"Error importing example_trainer.grpo: {import_error}")
    print("Ensure you're running from project root and that example_trainer/ is a package.")
    sys.exit(1)
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
# Address the Trajectory Handler API server is bound to.
API_HOST = "127.0.0.1"
API_PORT = 8000

# Address used as the environment's --base_url; the port is also handed to the
# trainer as 'vllm_port'.
VLLM_HOST = "127.0.0.1"
VLLM_PORT = 9001

# The same model identifier is used for the model and for the tokenizer.
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
TOKENIZER_NAME = MODEL_NAME

# Keyword arguments expanded into TrainingConfig(**TRAINER_CONFIG).
TRAINER_CONFIG = dict(
    model_name=MODEL_NAME,
    training_steps=20,
    batch_size=2,
    gradient_accumulation_steps=2,
    seq_len=512,
    vllm_port=VLLM_PORT,
    vllm_restart_interval=10,
    use_wandb=False,
    wandb_project="",
    wandb_group="",
    save_path="./trained_model_checkpoints_local_test",
)

# Command-line flags appended to the DatasetEnv `serve` invocation.
# NOTE(review): --use_wandb is passed here while TRAINER_CONFIG sets
# use_wandb to False, and --slurm is passed by a local launcher -- confirm
# that both are intended.
DATASET_FLAGS = [
    # Rollout collection
    "--group_size", "4",
    "--max_num_workers", "2",
    "--rollout_server_url", f"http://{API_HOST}:{API_PORT}",
    "--tokenizer_name", TOKENIZER_NAME,
    "--use_wandb",
    "--wandb_name", "dataset_env_local_test",
    # Kept equal to the trainer's sequence length
    "--max_token_length", str(TRAINER_CONFIG["seq_len"]),
    "--ensure_scores_are_not_same",
    # Dataset selection
    "--dataset_name", "HuggingFaceH4/testing_self_instruct_process_essays",
    "--split", "train[:100]",
    "--prompt_field", "prompt",
    "--answer_field", "answer",
    "--reward_functions", "length",
    # Generation settings
    "--max_tokens", "128",
    "--temperature", "0.7",
    "--model_name", MODEL_NAME,
    "--base_url", f"http://{VLLM_HOST}:{VLLM_PORT}",
    "--slurm",
    "--testing",
]
# Popen handles of every background process started by this launcher, in
# start order; consumed by cleanup_processes() when the interpreter exits.
processes = []


def _stop_process(proc):
    """Terminate one tracked process, escalating to kill after 5 seconds."""
    pid = proc.pid
    if proc.poll() is not None:
        # A non-None poll() result means the process has already finished.
        print(f"PID {pid} already exited.")
        return

    print(f"Terminating PID {pid}...")
    proc.terminate()
    try:
        proc.wait(timeout=5)
    except subprocess.TimeoutExpired:
        # The process ignored terminate(); force it down and reap it.
        print(f"PID {pid} did not terminate; killing.")
        proc.kill()
        proc.wait()
        print(f"PID {pid} killed.")
    else:
        print(f"PID {pid} terminated.")


def cleanup_processes():
    """Stop every process in ``processes``, most recently started first.

    Each process that is still running is sent terminate() and given up to
    5 seconds to exit before kill() is used. Progress is reported on stdout.
    """
    print("\nCleaning up background processes...")
    for proc in processes[::-1]:
        _stop_process(proc)
    print("Cleanup complete.")


# Run the cleanup automatically when the interpreter shuts down.
atexit.register(cleanup_processes)
def handle_signal(sig, frame):
    """Signal handler: announce the signal and exit the interpreter.

    Raising SystemExit lets normal interpreter shutdown proceed, so handlers
    registered with atexit still run.

    NOTE(review): the exit status is 0 even though the run was interrupted --
    confirm that callers do not rely on a non-zero status here.
    """
    print(f"\nSignal {sig} received; exiting.")
    raise SystemExit(0)


# Route both Ctrl-C (SIGINT) and SIGTERM through the same handler.
for _signum in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_signum, handle_signal)
def main():
    """Start the API server and dataset environment, then run GRPO training.

    Steps:
        1. Launch the Trajectory Handler API server (uvicorn) as a subprocess.
        2. Launch the DatasetEnv ``serve`` command as a subprocess.
        3. Build a TrainingConfig from TRAINER_CONFIG and call train() in
           this process.

    Both subprocesses are appended to the module-level ``processes`` list so
    that cleanup_processes() can stop them at exit. An exception raised by
    train() is caught and its traceback printed; the function then returns
    normally.
    """
    # 1) Start the API server.
    # Both children are launched through sys.executable so they run under the
    # same interpreter (and virtualenv) as this launcher instead of whatever
    # `python` / `uvicorn` happens to be first on PATH.
    print("--- Starting Trajectory Handler API Server ---")
    api_cmd = [
        sys.executable, '-m', 'uvicorn',
        'atroposlib.api.server:app',
        '--host', API_HOST,
        '--port', str(API_PORT),
    ]
    print(f"$ {' '.join(api_cmd)}")
    api_proc = subprocess.Popen(api_cmd)
    processes.append(api_proc)
    # Fixed grace period for the server to come up before the next step.
    time.sleep(3)

    # 2) Start the dataset environment
    print("\n--- Starting Dataset Environment ---")
    env_cmd = [
        sys.executable, '-m',
        'environments.dataset_environment.dataset_env',
        'serve',
    ] + DATASET_FLAGS
    print(f"$ {' '.join(env_cmd)}")
    env_proc = subprocess.Popen(env_cmd)
    processes.append(env_proc)
    # Fixed grace period for the environment to start before training.
    time.sleep(3)

    # 3) Run the example trainer
    print("\n--- Running Example Trainer (GRPO) ---")
    config = TrainingConfig(**TRAINER_CONFIG)
    try:
        train(config)
    except Exception:
        # Top-level boundary: report the failure and fall through so that the
        # interpreter exits normally and background processes are cleaned up.
        print("Error during training:")
        traceback.print_exc()
    else:
        # Only claim completion when train() returned without raising.
        print("--- Training complete ---")
# Run the launcher only when executed as a script or via `python -m`,
# not when this module is imported.
if __name__ == "__main__":
    main()