run pre-commit on all files

This commit is contained in:
dmahan93 2025-05-09 09:54:20 -05:00
parent b959c30ebf
commit 40b12dae60
17 changed files with 169 additions and 118 deletions

View file

@@ -2,7 +2,6 @@ import argparse
import time
import requests
import wandb

View file

@@ -387,8 +387,10 @@ class BaseEnv(ABC):
# Now register the env...
while True:
data = await self._register_env()
if data['status'] != "success":
logging.warning(f"Waiting to register the env due to status {data['status']}")
if data["status"] != "success":
logging.warning(
f"Waiting to register the env due to status {data['status']}"
)
await asyncio.sleep(1)
continue
self.env_id = data["env_id"]

View file

@@ -32,7 +32,9 @@ class DatasetEnvConfig(BaseEnvConfig):
None, description="Field in dataset containing canonical correct answer"
)
system_prompt: Optional[str] = Field(None, description="System prompt to use")
prefill: Optional[str] = Field(None, description="Text to prefill the completion with (e.g. '<think>')")
prefill: Optional[str] = Field(
None, description="Text to prefill the completion with (e.g. '<think>')"
)
shuffle_dataset: bool = Field(True, description="Whether to shuffle the dataset")
max_generations_per_prompt: int = Field(
1, description="Number of generations per prompt for collection"
@@ -176,12 +178,18 @@ class DatasetEnv(BaseEnv):
# Process completions
for completion in completions.choices:
# Get the completion text
completion_text = completion.text if hasattr(completion, "text") else completion.message.content
completion_text = (
completion.text
if hasattr(completion, "text")
else completion.message.content
)
# Build full message sequence for scoring
full_messages = []
if self.config.system_prompt:
full_messages.append({"role": "system", "content": self.config.system_prompt})
full_messages.append(
{"role": "system", "content": self.config.system_prompt}
)
full_messages.append({"role": "user", "content": user_content})
@@ -193,9 +201,7 @@ class DatasetEnv(BaseEnv):
full_messages.append({"role": "assistant", "content": response_content})
# Add to scoring list with answer and ground truth
to_score.append(
(full_messages, answer, item[2] if len(item) > 2 else None)
)
to_score.append((full_messages, answer, item[2] if len(item) > 2 else None))
return to_score, to_backlog
@@ -239,12 +245,18 @@ class DatasetEnv(BaseEnv):
trajectories = []
for completion in completions.choices:
# Get the completion text
completion_text = completion.text if hasattr(completion, "text") else completion.message.content
completion_text = (
completion.text
if hasattr(completion, "text")
else completion.message.content
)
# Build complete message sequence
full_messages = []
if self.config.system_prompt:
full_messages.append({"role": "system", "content": self.config.system_prompt})
full_messages.append(
{"role": "system", "content": self.config.system_prompt}
)
full_messages.append({"role": "user", "content": user_content})
@@ -402,6 +414,7 @@ class DatasetEnv(BaseEnv):
await super().wandb_log(metrics)
if __name__ == "__main__":
# Launch the DatasetEnv via the BaseEnv CLI (serve or process)
DatasetEnv.cli()

View file

@@ -8,7 +8,8 @@ from dotenv import load_dotenv
from atroposlib.envs.base import OpenaiConfig
from atroposlib.envs.reward_fns import registry
from atroposlib.utils.config_handler import ConfigHandler
# from atroposlib.utils.config_handler import ConfigHandler
from environments.dataset_environment.dataset_env import DatasetEnv, DatasetEnvConfig
load_dotenv()
@@ -23,7 +24,8 @@ def parse_arguments():
"--config",
type=str,
default="dataset_local",
help="Configuration file name (without .yaml extension) relative to environments/dataset_environment/configs/, or full path to a YAML file.",
help="Configuration file name (without .yaml extension) relative to environments/dataset_environment/configs/,"
" or full path to a YAML file.",
)
return parser.parse_args()
@@ -35,7 +37,7 @@ async def main():
args = parse_arguments()
# Initialize config handler
config_handler = ConfigHandler()
# config_handler = ConfigHandler()
# Determine config path
if (

View file

@@ -14,16 +14,16 @@ Requirements:
- Run from project root so example_trainer is on PYTHONPATH
- example_trainer/ is a valid Python package (with __init__.py)
"""
import os
import sys
import subprocess
import time
import atexit
import os
import signal
import subprocess
import sys
import time
import traceback
# Ensure project root is on PYTHONPATH
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
if project_root not in sys.path:
sys.path.insert(0, project_root)
@@ -32,56 +32,73 @@ try:
from example_trainer.grpo import TrainingConfig, train
except ImportError as e:
print(f"Error importing example_trainer.grpo: {e}")
print("Ensure you're running from project root and that example_trainer/ is a package.")
print(
"Ensure you're running from project root and that example_trainer/ is a package."
)
sys.exit(1)
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
API_HOST = '127.0.0.1'
API_HOST = "127.0.0.1"
API_PORT = 8000
VLLM_HOST = '127.0.0.1'
VLLM_HOST = "127.0.0.1"
VLLM_PORT = 9001
MODEL_NAME = 'Qwen/Qwen2.5-1.5B-Instruct'
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
TOKENIZER_NAME = MODEL_NAME
TRAINER_CONFIG = {
'model_name': MODEL_NAME,
'training_steps': 20,
'batch_size': 2,
'gradient_accumulation_steps': 2,
'seq_len': 512,
'vllm_port': VLLM_PORT,
'vllm_restart_interval': 10,
'use_wandb': False,
'wandb_project': '',
'wandb_group': '',
'save_path': './trained_model_checkpoints_local_test',
"model_name": MODEL_NAME,
"training_steps": 20,
"batch_size": 2,
"gradient_accumulation_steps": 2,
"seq_len": 512,
"vllm_port": VLLM_PORT,
"vllm_restart_interval": 10,
"use_wandb": False,
"wandb_project": "",
"wandb_group": "",
"save_path": "./trained_model_checkpoints_local_test",
}
# Flags for launching DatasetEnv serve
DATASET_FLAGS = [
'--group_size', '4',
'--max_num_workers', '2',
'--rollout_server_url', f"http://{API_HOST}:{API_PORT}",
'--tokenizer_name', TOKENIZER_NAME,
'--use_wandb',
'--wandb_name', 'dataset_env_local_test',
'--max_token_length', str(TRAINER_CONFIG['seq_len']),
'--ensure_scores_are_not_same',
'--dataset_name', 'HuggingFaceH4/testing_self_instruct_process_essays',
'--split', 'train[:100]',
'--prompt_field', 'prompt',
'--answer_field', 'answer',
'--reward_functions', 'length',
'--max_tokens', '128',
'--temperature', '0.7',
'--model_name', MODEL_NAME,
'--base_url', f"http://{VLLM_HOST}:{VLLM_PORT}",
'--slurm',
'--testing',
"--group_size",
"4",
"--max_num_workers",
"2",
"--rollout_server_url",
f"http://{API_HOST}:{API_PORT}",
"--tokenizer_name",
TOKENIZER_NAME,
"--use_wandb",
"--wandb_name",
"dataset_env_local_test",
"--max_token_length",
str(TRAINER_CONFIG["seq_len"]),
"--ensure_scores_are_not_same",
"--dataset_name",
"HuggingFaceH4/testing_self_instruct_process_essays",
"--split",
"train[:100]",
"--prompt_field",
"prompt",
"--answer_field",
"answer",
"--reward_functions",
"length",
"--max_tokens",
"128",
"--temperature",
"0.7",
"--model_name",
MODEL_NAME,
"--base_url",
f"http://{VLLM_HOST}:{VLLM_PORT}",
"--slurm",
"--testing",
]
# Track background processes for cleanup
@@ -106,6 +123,7 @@ def cleanup_processes():
print(f"PID {p.pid} already exited.")
print("Cleanup complete.")
atexit.register(cleanup_processes)
@@ -113,6 +131,7 @@ def handle_signal(sig, frame):
print(f"\nSignal {sig} received; exiting.")
sys.exit(0)
signal.signal(signal.SIGINT, handle_signal)
signal.signal(signal.SIGTERM, handle_signal)
@@ -121,10 +140,12 @@ def main():
# 1) Start the API server
print("--- Starting Trajectory Handler API Server ---")
api_cmd = [
'uvicorn',
'atroposlib.api.server:app',
'--host', API_HOST,
'--port', str(API_PORT),
"uvicorn",
"atroposlib.api.server:app",
"--host",
API_HOST,
"--port",
str(API_PORT),
]
print(f"$ {' '.join(api_cmd)}")
api_proc = subprocess.Popen(api_cmd)
@@ -133,7 +154,12 @@ def main():
# 2) Start the dataset environment
print("\n--- Starting Dataset Environment ---")
env_cmd = ['python', '-m', 'environments.dataset_environment.dataset_env', 'serve'] + DATASET_FLAGS
env_cmd = [
"python",
"-m",
"environments.dataset_environment.dataset_env",
"serve",
] + DATASET_FLAGS
print(f"$ {' '.join(env_cmd)}")
env_proc = subprocess.Popen(env_cmd)
processes.append(env_proc)
@@ -150,5 +176,5 @@ def main():
print("--- Training complete ---")
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@@ -1,10 +1,10 @@
import random
import re
from typing import List, Optional, Tuple, Union, Dict
from typing import Dict, List, Optional, Tuple, Union
import wandb
from datasets import load_dataset
from tqdm.asyncio import tqdm_asyncio
import wandb
from atroposlib.envs.base import (
BaseEnv,
@@ -508,23 +508,32 @@ class FundamentalPredictionEnv(BaseEnv):
# Calculate and log training direction accuracy
try:
direction_accuracy = sum(self.percent_correct_buffer) / len(self.percent_correct_buffer)
direction_accuracy = sum(self.percent_correct_buffer) / len(
self.percent_correct_buffer
)
wandb_metrics["train/direction_accuracy"] = direction_accuracy
except ZeroDivisionError:
pass # Skip if buffer is empty
# Calculate and log training magnitude accuracy
try:
magnitude_accuracy = sum(self.magnitude_accuracy_buffer) / len(self.magnitude_accuracy_buffer)
magnitude_accuracy = sum(self.magnitude_accuracy_buffer) / len(
self.magnitude_accuracy_buffer
)
wandb_metrics["train/magnitude_accuracy"] = magnitude_accuracy
except ZeroDivisionError:
pass # Skip if buffer is empty
# Calculate combined training score (direction + magnitude)
try:
combined_score = direction_accuracy + magnitude_accuracy if 'direction_accuracy' in wandb_metrics else 0
combined_score = (
direction_accuracy + magnitude_accuracy
if "direction_accuracy" in wandb_metrics
else 0
)
wandb_metrics["train/combined_score"] = combined_score
except:
except Exception as e:
print(f"Error calculating combined score: {e}")
pass
# Clear the buffers after logging
@@ -577,13 +586,15 @@ class FundamentalPredictionEnv(BaseEnv):
async def create_rollout_table(self, wandb_metrics):
if hasattr(self, "rollouts_for_wandb") and len(self.rollouts_for_wandb) > 0:
table = wandb.Table(columns=[
table = wandb.Table(
columns=[
"text",
"score",
"expected_direction",
"expected_magnitude",
"fundamental_metric"
])
"fundamental_metric",
]
)
for group in self.rollouts_for_wandb:
for item in group:

View file

@@ -5,6 +5,7 @@ from concurrent.futures import ProcessPoolExecutor
from difflib import SequenceMatcher
from typing import Dict, List, Optional, Tuple
import wandb
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
@@ -12,7 +13,6 @@ from math_verify.errors import TimeoutException
from pydantic import Field
from tqdm.asyncio import tqdm_asyncio
import wandb
from atroposlib.envs.base import (
BaseEnv,
BaseEnvConfig,

View file

@@ -8,7 +8,6 @@ from datasets import load_dataset
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

View file

@@ -13,13 +13,12 @@ import numpy as np
import requests
import torch
import torch.nn.functional as F
import wandb # Added for logging
from pydantic import BaseModel, Field
from tenacity import retry, stop_after_attempt, wait_exponential
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer
import wandb # Added for logging
# Global variable to keep track of the vLLM process
vllm_process = None