mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-27 17:23:08 +00:00
first commit
This commit is contained in:
commit
621d00dd80
89 changed files with 15315 additions and 0 deletions
154
environments/dataset_environment/launch_local_dataset_run.py
Normal file
154
environments/dataset_environment/launch_local_dataset_run.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Local dataset training launcher.
|
||||
|
||||
Usage:
|
||||
python -m environments.dataset_environment.launch_local_dataset_run
|
||||
|
||||
This script does:
|
||||
1) Starts the Trajectory Handler API server via uvicorn
|
||||
2) Launches the DatasetEnv in local serve mode
|
||||
3) Imports and runs the example trainer (GRPO) directly
|
||||
|
||||
Requirements:
|
||||
- Run from project root so example_trainer is on PYTHONPATH
|
||||
- example_trainer/ is a valid Python package (with __init__.py)
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import time
|
||||
import atexit
|
||||
import signal
|
||||
import traceback
|
||||
|
||||
# Make the repository root importable. It sits two directory levels above
# this file (environments/dataset_environment/ -> project root).
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# The trainer is pulled in through a normal package import. If that fails,
# print a hint about the expected layout and stop with a non-zero status.
try:
    from example_trainer.grpo import TrainingConfig, train
except ImportError as import_error:
    print(f"Error importing example_trainer.grpo: {import_error}")
    print("Ensure you're running from project root and that example_trainer/ is a package.")
    raise SystemExit(1)
|
||||
|
||||
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
# Address the Trajectory Handler API server is bound to; also handed to the
# environment as its rollout server URL.
API_HOST = "127.0.0.1"
API_PORT = 8000

# Address used for the environment's --base_url; the port is also passed to
# the trainer as `vllm_port`.
VLLM_HOST = "127.0.0.1"
VLLM_PORT = 9001

# The tokenizer is taken from the same checkpoint as the model.
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
TOKENIZER_NAME = MODEL_NAME

# Keyword arguments expanded into TrainingConfig(**TRAINER_CONFIG).
TRAINER_CONFIG = dict(
    model_name=MODEL_NAME,
    training_steps=20,
    batch_size=2,
    gradient_accumulation_steps=2,
    seq_len=512,
    vllm_port=VLLM_PORT,
    vllm_restart_interval=10,
    use_wandb=False,
    wandb_project="",
    wandb_group="",
    save_path="./trained_model_checkpoints_local_test",
)

# Command-line flags appended to the DatasetEnv `serve` invocation.
DATASET_FLAGS = [
    # Rollout / worker settings
    "--group_size", "4",
    "--max_num_workers", "2",
    "--rollout_server_url", f"http://{API_HOST}:{API_PORT}",
    "--tokenizer_name", TOKENIZER_NAME,
    # NOTE(review): wandb is switched on here while TRAINER_CONFIG sets
    # use_wandb to False -- confirm this asymmetry is intended.
    "--use_wandb",
    "--wandb_name", "dataset_env_local_test",
    # Keep the environment's token budget equal to the trainer's seq_len.
    "--max_token_length", f"{TRAINER_CONFIG['seq_len']}",
    "--ensure_scores_are_not_same",
    # Dataset selection
    "--dataset_name", "HuggingFaceH4/testing_self_instruct_process_essays",
    "--split", "train[:100]",
    "--prompt_field", "prompt",
    "--answer_field", "answer",
    # Scoring and sampling
    "--reward_functions", "length",
    "--max_tokens", "128",
    "--temperature", "0.7",
    # Inference endpoint
    "--model_name", MODEL_NAME,
    "--base_url", f"http://{VLLM_HOST}:{VLLM_PORT}",
    # NOTE(review): --slurm is passed although this is a local launcher --
    # verify against the environment's CLI that it is required here.
    "--slurm",
    "--testing",
]
|
||||
|
||||
# Registry of subprocess.Popen handles started by this launcher; the exit
# hook below walks it to shut everything down.
processes = []


def cleanup_processes():
    """Stop every tracked background process, most recently started first.

    A process that is still running is sent ``terminate()`` and given five
    seconds to exit; if it has not exited by then it is killed and waited
    for. Processes that already exited are only reported.
    """
    print("\nCleaning up background processes...")
    for proc in processes[::-1]:
        pid = proc.pid
        if proc.poll() is not None:
            print(f"PID {pid} already exited.")
            continue

        print(f"Terminating PID {pid}...")
        proc.terminate()
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            # Graceful shutdown timed out: escalate to a hard kill.
            print(f"PID {pid} did not terminate; killing.")
            proc.kill()
            proc.wait()
            print(f"PID {pid} killed.")
        else:
            print(f"PID {pid} terminated.")
    print("Cleanup complete.")


# Run the cleanup whenever the interpreter shuts down normally.
atexit.register(cleanup_processes)
|
||||
|
||||
|
||||
def handle_signal(sig, frame):
    """Signal handler: announce the signal and exit the interpreter.

    Args:
        sig: Number of the signal that was delivered.
        frame: Stack frame active when the signal arrived (unused).

    Raises:
        SystemExit: always, with status 0.
    """
    # NOTE(review): exiting with status 0 makes an interrupted run look
    # successful to a calling shell -- confirm this is intended.
    print(f"\nSignal {sig} received; exiting.")
    raise SystemExit(0)


# Route both Ctrl-C and termination requests through the same handler.
for _signum in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_signum, handle_signal)
|
||||
|
||||
|
||||
def _start_background(label, cmd, settle_seconds=3):
    """Start ``cmd`` as a tracked background process and verify it stays up.

    The Popen handle is appended to the module-level ``processes`` list
    before waiting, so it is always visible to the cleanup routine.

    Args:
        label: Human-readable name used in the failure message.
        cmd: Argument list passed to ``subprocess.Popen``.
        settle_seconds: Seconds to wait before checking the process is alive.

    Returns:
        The ``subprocess.Popen`` handle of the running process.

    Raises:
        RuntimeError: if the process has already exited after the wait.
    """
    print(f"$ {' '.join(cmd)}")
    proc = subprocess.Popen(cmd)
    processes.append(proc)
    time.sleep(settle_seconds)
    exit_code = proc.poll()
    if exit_code is not None:
        # Fail fast: continuing without this service only produces a more
        # confusing error later on.
        raise RuntimeError(
            f"{label} exited during startup with code {exit_code}"
        )
    return proc


def main():
    """Launch the API server and dataset environment, then run the trainer.

    Steps:
      1) Start the Trajectory Handler API server via uvicorn.
      2) Start the DatasetEnv in ``serve`` mode with ``DATASET_FLAGS``.
      3) Build a ``TrainingConfig`` from ``TRAINER_CONFIG`` and call ``train``.

    Raises:
        RuntimeError: if a background process exits during its startup wait.
        SystemExit: with status 1 if training raises an exception.
    """
    # 1) Start the API server
    print("--- Starting Trajectory Handler API Server ---")
    api_cmd = [
        'uvicorn',
        'atroposlib.api.server:app',
        '--host', API_HOST,
        '--port', str(API_PORT),
    ]
    _start_background("Trajectory Handler API server", api_cmd)

    # 2) Start the dataset environment
    print("\n--- Starting Dataset Environment ---")
    # Use the interpreter running this script instead of whatever `python`
    # resolves to on PATH (it may be missing or belong to another
    # environment).
    interpreter = sys.executable or 'python'
    env_cmd = [
        interpreter,
        '-m', 'environments.dataset_environment.dataset_env',
        'serve',
        *DATASET_FLAGS,
    ]
    _start_background("Dataset environment", env_cmd)

    # 3) Run the example trainer
    print("\n--- Running Example Trainer (GRPO) ---")
    config = TrainingConfig(**TRAINER_CONFIG)
    try:
        train(config)
    except Exception:
        print("Error during training:")
        traceback.print_exc()
        # Report the failure through the exit status instead of falling
        # through to the success message.
        sys.exit(1)
    print("--- Training complete ---")


if __name__ == '__main__':
    main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue