From 40b12dae60ede531872e6b9771680a911a1801b9 Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Fri, 9 May 2025 09:54:20 -0500 Subject: [PATCH] run pre-commit on all files --- .env.example | 2 +- .github/pull_request_template.md | 2 +- CODE_OF_CONDUCT.md | 2 +- .../cli/inference_node_wandb_watcher.py | 1 - atroposlib/envs/base.py | 6 +- atroposlib/utils/config_handler.py | 2 +- .../dataset_environment/LOCAL_TESTING.md | 2 +- .../configs/dataset_local.yaml | 2 +- .../dataset_environment/configs/gsm8k.yaml | 2 +- .../configs/gsm8k_debug.yaml | 2 +- .../dataset_environment/dataset_env.py | 85 +++++++------ .../dataset_local_server.py | 8 +- .../launch_local_dataset_run.py | 118 +++++++++++------- .../fundamental_prediction_environment.py | 47 ++++--- environments/math_server.py | 2 +- environments/multimodal_dpo/clevr_complex.py | 1 - example_trainer/grpo.py | 3 +- 17 files changed, 169 insertions(+), 118 deletions(-) diff --git a/.env.example b/.env.example index 9847a1df..e570b8b5 100644 --- a/.env.example +++ b/.env.example @@ -1 +1 @@ -OPENAI_API_KEY= \ No newline at end of file +OPENAI_API_KEY= diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 985423ec..108d8c0c 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -66,4 +66,4 @@ - [ ] My changes generate no new warnings - [ ] New and existing unit tests pass locally with my changes - [ ] Docstrings added for all new public classes / functions -- [ ] If .env vars required, did you add it to the .env.example in repo root? \ No newline at end of file +- [ ] If .env vars required, did you add it to the .env.example in repo root? diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 974722f5..dd1795ab 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -41,4 +41,4 @@ Project maintainers are obligated to respect the privacy and security of the rep This Code of Conduct is adapted from general open source community standards and GitHub's community guidelines. -Remember: Respect each other, collaborate constructively, and focus on making Atropos better for everyone. \ No newline at end of file +Remember: Respect each other, collaborate constructively, and focus on making Atropos better for everyone. diff --git a/atroposlib/cli/inference_node_wandb_watcher.py b/atroposlib/cli/inference_node_wandb_watcher.py index 2ef77520..b5f5fc45 100644 --- a/atroposlib/cli/inference_node_wandb_watcher.py +++ b/atroposlib/cli/inference_node_wandb_watcher.py @@ -2,7 +2,6 @@ import argparse import time import requests - import wandb diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index c8b9d04b..59c2adf3 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -387,8 +387,10 @@ class BaseEnv(ABC): # Now register the env... while True: data = await self._register_env() - if data['status'] != "success": - logging.warning(f"Waiting to register the env due to status {data['status']}") + if data["status"] != "success": + logging.warning( + f"Waiting to register the env due to status {data['status']}" + ) await asyncio.sleep(1) continue self.env_id = data["env_id"] diff --git a/atroposlib/utils/config_handler.py b/atroposlib/utils/config_handler.py index 3cfd1f7b..d7712f62 100644 --- a/atroposlib/utils/config_handler.py +++ b/atroposlib/utils/config_handler.py @@ -181,4 +181,4 @@ class ConfigHandler: # Add slurm flag to config if running in a Slurm environment config["use_slurm"] = "SLURM_JOB_ID" in os.environ - return config \ No newline at end of file + return config diff --git a/environments/dataset_environment/LOCAL_TESTING.md b/environments/dataset_environment/LOCAL_TESTING.md index a5c8eb87..1d430b21 100644 --- a/environments/dataset_environment/LOCAL_TESTING.md +++ b/environments/dataset_environment/LOCAL_TESTING.md @@ -152,4 +152,4 @@ server_configs: If you encounter issues with reward functions, make sure they are properly registered in the registry. -For dataset-related issues, verify that the dataset exists on HuggingFace and that the specified fields exist in the dataset. \ No newline at end of file +For dataset-related issues, verify that the dataset exists on HuggingFace and that the specified fields exist in the dataset. diff --git a/environments/dataset_environment/configs/dataset_local.yaml b/environments/dataset_environment/configs/dataset_local.yaml index 7849de34..d66a01f7 100644 --- a/environments/dataset_environment/configs/dataset_local.yaml +++ b/environments/dataset_environment/configs/dataset_local.yaml @@ -49,4 +49,4 @@ dataset: server_configs: - model_name: "gpt-4.1-nano" api_key: ${OPENAI_API_KEY} - timeout: 600 \ No newline at end of file + timeout: 600 diff --git a/environments/dataset_environment/configs/gsm8k.yaml b/environments/dataset_environment/configs/gsm8k.yaml index 7979fe46..ea19ea76 100644 --- a/environments/dataset_environment/configs/gsm8k.yaml +++ b/environments/dataset_environment/configs/gsm8k.yaml @@ -70,4 +70,4 @@ dataset: eval_dataset_name: "gsm8k" eval_dataset_config: "main" - eval_split: "test" \ No newline at end of file + eval_split: "test" diff --git a/environments/dataset_environment/configs/gsm8k_debug.yaml b/environments/dataset_environment/configs/gsm8k_debug.yaml index f928e9e2..372c5a3e 100644 --- a/environments/dataset_environment/configs/gsm8k_debug.yaml +++ b/environments/dataset_environment/configs/gsm8k_debug.yaml @@ -27,4 +27,4 @@ dataset: max_tokens: 4096 length_warmup_steps: 0 - min_tokens: 200 \ No newline at end of file + min_tokens: 200 diff --git a/environments/dataset_environment/dataset_env.py b/environments/dataset_environment/dataset_env.py index 602cf812..23548c73 100644 --- a/environments/dataset_environment/dataset_env.py +++ b/environments/dataset_environment/dataset_env.py @@ -32,7 +32,9 @@ class DatasetEnvConfig(BaseEnvConfig): None, description="Field in dataset containing canonical correct answer" ) system_prompt: Optional[str] = Field(None, description="System prompt to use") - prefill: Optional[str] = Field(None, description="Text to prefill the completion with (e.g. '')") + prefill: Optional[str] = Field( + None, description="Text to prefill the completion with (e.g. '')" + ) shuffle_dataset: bool = Field(True, description="Whether to shuffle the dataset") max_generations_per_prompt: int = Field( 1, description="Number of generations per prompt for collection" @@ -137,21 +139,21 @@ class DatasetEnv(BaseEnv): # Extract user prompt and answer from item user_content = dict(item[0][0])["content"] answer = item[1] if len(item) > 1 else None - + # Create messages list messages = [] if self.config.system_prompt: messages.append({"role": "system", "content": self.config.system_prompt}) - + messages.append({"role": "user", "content": user_content}) - + # Add prefill as assistant message if configured if self.config.prefill: messages.append({"role": "assistant", "content": self.config.prefill}) - + # Convert messages to a prompt string using the tokenizer prompt = self.tokenizer.apply_chat_template(messages, tokenize=False) - + # Calculate max tokens for generation (with optional warmup) max_tokens = self.config.max_tokens if self.config.length_warmup_steps > 0: @@ -160,7 +162,7 @@ class DatasetEnv(BaseEnv): self.config.min_tokens + warmup_progress * (self.config.max_tokens - self.config.min_tokens) ) - + # Generate completion using completions API completions = await self.server.completion( prompt=prompt, @@ -169,34 +171,38 @@ class DatasetEnv(BaseEnv): temperature=self.config.temperature, top_p=self.config.top_p, ) - + to_score = [] to_backlog = [] - + # Process completions for completion in completions.choices: # Get the completion text - completion_text = completion.text if hasattr(completion, "text") else completion.message.content - + completion_text = ( + completion.text + if hasattr(completion, "text") + else completion.message.content + ) + # Build full message sequence for scoring full_messages = [] if self.config.system_prompt: - full_messages.append({"role": "system", "content": self.config.system_prompt}) - + full_messages.append( + {"role": "system", "content": self.config.system_prompt} + ) + full_messages.append({"role": "user", "content": user_content}) - + # Combine prefill with completion if prefill was used response_content = completion_text if self.config.prefill: response_content = self.config.prefill + completion_text - + full_messages.append({"role": "assistant", "content": response_content}) - + # Add to scoring list with answer and ground truth - to_score.append( - (full_messages, answer, item[2] if len(item) > 2 else None) - ) - + to_score.append((full_messages, answer, item[2] if len(item) > 2 else None)) + return to_score, to_backlog async def postprocess_histories(self, trajectories: List) -> Tuple[List, List]: @@ -204,27 +210,27 @@ class DatasetEnv(BaseEnv): async def collect_trajectories(self, item: Item) -> Tuple[List, List]: self.current_item = item - + # Extract user prompt from item user_content = dict(item[0][0])["content"] - + # Create messages list messages = [] if self.config.system_prompt: messages.append({"role": "system", "content": self.config.system_prompt}) - + messages.append({"role": "user", "content": user_content}) - + # Add prefill as assistant message if configured if self.config.prefill: messages.append({"role": "assistant", "content": self.config.prefill}) - + # Convert messages to a prompt string using the tokenizer prompt = self.tokenizer.apply_chat_template(messages, tokenize=False) - + # Calculate max tokens for generation (with optional warmup) max_tokens = self.config.max_tokens - + # Generate completions completions = await self.server.completion( prompt=prompt, @@ -233,30 +239,36 @@ class DatasetEnv(BaseEnv): temperature=self.config.temperature, top_p=self.config.top_p, ) - + print(f"Completions: {completions}") # Process completions trajectories = [] for completion in completions.choices: # Get the completion text - completion_text = completion.text if hasattr(completion, "text") else completion.message.content - + completion_text = ( + completion.text + if hasattr(completion, "text") + else completion.message.content + ) + # Build complete message sequence full_messages = [] if self.config.system_prompt: - full_messages.append({"role": "system", "content": self.config.system_prompt}) - + full_messages.append( + {"role": "system", "content": self.config.system_prompt} + ) + full_messages.append({"role": "user", "content": user_content}) - + # Combine prefill with completion if prefill was used response_content = completion_text if self.config.prefill: response_content = self.config.prefill + completion_text - + full_messages.append({"role": "assistant", "content": response_content}) - + trajectories.append(full_messages) - + return trajectories, [] async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]: @@ -402,6 +414,7 @@ class DatasetEnv(BaseEnv): await super().wandb_log(metrics) + if __name__ == "__main__": # Launch the DatasetEnv via the BaseEnv CLI (serve or process) DatasetEnv.cli() diff --git a/environments/dataset_environment/dataset_local_server.py b/environments/dataset_environment/dataset_local_server.py index 7fdf047a..f31fa3b5 100644 --- a/environments/dataset_environment/dataset_local_server.py +++ b/environments/dataset_environment/dataset_local_server.py @@ -8,7 +8,8 @@ from dotenv import load_dotenv from atroposlib.envs.base import OpenaiConfig from atroposlib.envs.reward_fns import registry -from atroposlib.utils.config_handler import ConfigHandler + +# from atroposlib.utils.config_handler import ConfigHandler from environments.dataset_environment.dataset_env import DatasetEnv, DatasetEnvConfig load_dotenv() @@ -23,7 +24,8 @@ def parse_arguments(): "--config", type=str, default="dataset_local", - help="Configuration file name (without .yaml extension) relative to environments/dataset_environment/configs/, or full path to a YAML file.", + help="Configuration file name (without .yaml extension) relative to environments/dataset_environment/configs/," + " or full path to a YAML file.", ) return parser.parse_args() @@ -35,7 +37,7 @@ async def main(): args = parse_arguments() # Initialize config handler - config_handler = ConfigHandler() + # config_handler = ConfigHandler() # Determine config path if ( diff --git a/environments/dataset_environment/launch_local_dataset_run.py b/environments/dataset_environment/launch_local_dataset_run.py index 26f6a2ac..2e95d05a 100644 --- a/environments/dataset_environment/launch_local_dataset_run.py +++ b/environments/dataset_environment/launch_local_dataset_run.py @@ -14,16 +14,16 @@ Requirements: - Run from project root so example_trainer is on PYTHONPATH - example_trainer/ is a valid Python package (with __init__.py) """ -import os -import sys -import subprocess -import time import atexit +import os import signal +import subprocess +import sys +import time import traceback # Ensure project root is on PYTHONPATH -project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) if project_root not in sys.path: sys.path.insert(0, project_root) @@ -32,56 +32,73 @@ try: from example_trainer.grpo import TrainingConfig, train except ImportError as e: print(f"Error importing example_trainer.grpo: {e}") - print("Ensure you're running from project root and that example_trainer/ is a package.") + print( + "Ensure you're running from project root and that example_trainer/ is a package." + ) sys.exit(1) # ----------------------------------------------------------------------------- # Configuration # ----------------------------------------------------------------------------- -API_HOST = '127.0.0.1' +API_HOST = "127.0.0.1" API_PORT = 8000 -VLLM_HOST = '127.0.0.1' +VLLM_HOST = "127.0.0.1" VLLM_PORT = 9001 -MODEL_NAME = 'Qwen/Qwen2.5-1.5B-Instruct' +MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" TOKENIZER_NAME = MODEL_NAME TRAINER_CONFIG = { - 'model_name': MODEL_NAME, - 'training_steps': 20, - 'batch_size': 2, - 'gradient_accumulation_steps': 2, - 'seq_len': 512, - 'vllm_port': VLLM_PORT, - 'vllm_restart_interval': 10, - 'use_wandb': False, - 'wandb_project': '', - 'wandb_group': '', - 'save_path': './trained_model_checkpoints_local_test', + "model_name": MODEL_NAME, + "training_steps": 20, + "batch_size": 2, + "gradient_accumulation_steps": 2, + "seq_len": 512, + "vllm_port": VLLM_PORT, + "vllm_restart_interval": 10, + "use_wandb": False, + "wandb_project": "", + "wandb_group": "", + "save_path": "./trained_model_checkpoints_local_test", } # Flags for launching DatasetEnv serve DATASET_FLAGS = [ - '--group_size', '4', - '--max_num_workers', '2', - '--rollout_server_url', f"http://{API_HOST}:{API_PORT}", - '--tokenizer_name', TOKENIZER_NAME, - '--use_wandb', - '--wandb_name', 'dataset_env_local_test', - '--max_token_length', str(TRAINER_CONFIG['seq_len']), - '--ensure_scores_are_not_same', - '--dataset_name', 'HuggingFaceH4/testing_self_instruct_process_essays', - '--split', 'train[:100]', - '--prompt_field', 'prompt', - '--answer_field', 'answer', - '--reward_functions', 'length', - '--max_tokens', '128', - '--temperature', '0.7', - '--model_name', MODEL_NAME, - '--base_url', f"http://{VLLM_HOST}:{VLLM_PORT}", - '--slurm', - '--testing', + "--group_size", + "4", + "--max_num_workers", + "2", + "--rollout_server_url", + f"http://{API_HOST}:{API_PORT}", + "--tokenizer_name", + TOKENIZER_NAME, + "--use_wandb", + "--wandb_name", + "dataset_env_local_test", + "--max_token_length", + str(TRAINER_CONFIG["seq_len"]), + "--ensure_scores_are_not_same", + "--dataset_name", + "HuggingFaceH4/testing_self_instruct_process_essays", + "--split", + "train[:100]", + "--prompt_field", + "prompt", + "--answer_field", + "answer", + "--reward_functions", + "length", + "--max_tokens", + "128", + "--temperature", + "0.7", + "--model_name", + MODEL_NAME, + "--base_url", + f"http://{VLLM_HOST}:{VLLM_PORT}", + "--slurm", + "--testing", ] # Track background processes for cleanup @@ -106,6 +123,7 @@ def cleanup_processes(): print(f"PID {p.pid} already exited.") print("Cleanup complete.") + atexit.register(cleanup_processes) @@ -113,6 +131,7 @@ def handle_signal(sig, frame): print(f"\nSignal {sig} received; exiting.") sys.exit(0) + signal.signal(signal.SIGINT, handle_signal) signal.signal(signal.SIGTERM, handle_signal) @@ -121,10 +140,12 @@ def main(): # 1) Start the API server print("--- Starting Trajectory Handler API Server ---") api_cmd = [ - 'uvicorn', - 'atroposlib.api.server:app', - '--host', API_HOST, - '--port', str(API_PORT), + "uvicorn", + "atroposlib.api.server:app", + "--host", + API_HOST, + "--port", + str(API_PORT), ] print(f"$ {' '.join(api_cmd)}") api_proc = subprocess.Popen(api_cmd) @@ -133,7 +154,12 @@ def main(): # 2) Start the dataset environment print("\n--- Starting Dataset Environment ---") - env_cmd = ['python', '-m', 'environments.dataset_environment.dataset_env', 'serve'] + DATASET_FLAGS + env_cmd = [ + "python", + "-m", + "environments.dataset_environment.dataset_env", + "serve", + ] + DATASET_FLAGS print(f"$ {' '.join(env_cmd)}") env_proc = subprocess.Popen(env_cmd) processes.append(env_proc) @@ -150,5 +176,5 @@ def main(): print("--- Training complete ---") -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/environments/fundamental_prediction_environment.py b/environments/fundamental_prediction_environment.py index 4b283b03..38ad5556 100644 --- a/environments/fundamental_prediction_environment.py +++ b/environments/fundamental_prediction_environment.py @@ -1,10 +1,10 @@ import random import re -from typing import List, Optional, Tuple, Union, Dict +from typing import Dict, List, Optional, Tuple, Union +import wandb from datasets import load_dataset from tqdm.asyncio import tqdm_asyncio -import wandb from atroposlib.envs.base import ( BaseEnv, @@ -508,23 +508,32 @@ class FundamentalPredictionEnv(BaseEnv): # Calculate and log training direction accuracy try: - direction_accuracy = sum(self.percent_correct_buffer) / len(self.percent_correct_buffer) + direction_accuracy = sum(self.percent_correct_buffer) / len( + self.percent_correct_buffer + ) wandb_metrics["train/direction_accuracy"] = direction_accuracy except ZeroDivisionError: pass # Skip if buffer is empty # Calculate and log training magnitude accuracy try: - magnitude_accuracy = sum(self.magnitude_accuracy_buffer) / len(self.magnitude_accuracy_buffer) + magnitude_accuracy = sum(self.magnitude_accuracy_buffer) / len( + self.magnitude_accuracy_buffer + ) wandb_metrics["train/magnitude_accuracy"] = magnitude_accuracy except ZeroDivisionError: pass # Skip if buffer is empty # Calculate combined training score (direction + magnitude) try: - combined_score = direction_accuracy + magnitude_accuracy if 'direction_accuracy' in wandb_metrics else 0 + combined_score = ( + direction_accuracy + magnitude_accuracy + if "direction_accuracy" in wandb_metrics + else 0 + ) wandb_metrics["train/combined_score"] = combined_score - except: + except Exception as e: + print(f"Error calculating combined score: {e}") pass # Clear the buffers after logging @@ -549,7 +558,7 @@ class FundamentalPredictionEnv(BaseEnv): # Get number of examples to keep num_keep = getattr(self.config, "num_rollouts_per_group_for_logging", -1) - + if num_keep == -1: num_keep = self.config.group_size @@ -577,23 +586,25 @@ class FundamentalPredictionEnv(BaseEnv): async def create_rollout_table(self, wandb_metrics): if hasattr(self, "rollouts_for_wandb") and len(self.rollouts_for_wandb) > 0: - table = wandb.Table(columns=[ - "text", - "score", - "expected_direction", - "expected_magnitude", - "fundamental_metric" - ]) - + table = wandb.Table( + columns=[ + "text", + "score", + "expected_direction", + "expected_magnitude", + "fundamental_metric", + ] + ) + for group in self.rollouts_for_wandb: for item in group: table.add_data(item[0], item[1], item[2], item[3], item[4]) - + wandb_metrics["train/rollouts"] = table - + # Clear rollouts after logging self.rollouts_for_wandb = [] - + return wandb_metrics diff --git a/environments/math_server.py b/environments/math_server.py index e5060016..544f95a0 100644 --- a/environments/math_server.py +++ b/environments/math_server.py @@ -5,6 +5,7 @@ from concurrent.futures import ProcessPoolExecutor from difflib import SequenceMatcher from typing import Dict, List, Optional, Tuple +import wandb from datasets import load_dataset from latex2sympy2_extended import NormalizationConfig from math_verify import LatexExtractionConfig, parse, verify @@ -12,7 +13,6 @@ from math_verify.errors import TimeoutException from pydantic import Field from tqdm.asyncio import tqdm_asyncio -import wandb from atroposlib.envs.base import ( BaseEnv, BaseEnvConfig, diff --git a/environments/multimodal_dpo/clevr_complex.py b/environments/multimodal_dpo/clevr_complex.py index 4f68080f..84f7a5a9 100644 --- a/environments/multimodal_dpo/clevr_complex.py +++ b/environments/multimodal_dpo/clevr_complex.py @@ -8,7 +8,6 @@ from datasets import load_dataset from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup from atroposlib.type_definitions import GameHistory, Item - from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py index fa7fca76..38273153 100644 --- a/example_trainer/grpo.py +++ b/example_trainer/grpo.py @@ -13,13 +13,12 @@ import numpy as np import requests import torch import torch.nn.functional as F +import wandb # Added for logging from pydantic import BaseModel, Field from tenacity import retry, stop_after_attempt, wait_exponential from torch.optim import AdamW from transformers import AutoModelForCausalLM, AutoTokenizer -import wandb # Added for logging - # Global variable to keep track of the vLLM process vllm_process = None