diff --git a/.env.example b/.env.example index 9847a1df..545ad9fa 100644 --- a/.env.example +++ b/.env.example @@ -1 +1,2 @@ -OPENAI_API_KEY= \ No newline at end of file +OPENAI_API_KEY= +OPENROUTER_API_KEY= diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 985423ec..108d8c0c 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -66,4 +66,4 @@ - [ ] My changes generate no new warnings - [ ] New and existing unit tests pass locally with my changes - [ ] Docstrings added for all new public classes / functions -- [ ] If .env vars required, did you add it to the .env.example in repo root? \ No newline at end of file +- [ ] If .env vars required, did you add it to the .env.example in repo root? diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 00000000..2b11178b --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,14 @@ +name: pre-commit + +on: + pull_request: + push: + branches: [main] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v3 + - uses: pre-commit/action@v3.0.1 diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 974722f5..dd1795ab 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -41,4 +41,4 @@ Project maintainers are obligated to respect the privacy and security of the rep This Code of Conduct is adapted from general open source community standards and GitHub's community guidelines. -Remember: Respect each other, collaborate constructively, and focus on making Atropos better for everyone. \ No newline at end of file +Remember: Respect each other, collaborate constructively, and focus on making Atropos better for everyone. diff --git a/CONFIG.md b/CONFIG.md index 454d7529..cc02a162 100644 --- a/CONFIG.md +++ b/CONFIG.md @@ -65,3 +65,4 @@ Configuration for individual OpenAI-compatible API servers (including local SGLa | `num_requests_for_eval` | `int` | `64` | Maximum number of concurrent requests for evaluation. | | `model_name` | `str` | `default` | The model name to use. Required for both OpenAI and local models (e.g., `"gpt-4"`, `"NousResearch/..."`). | | `rolling_buffer_length` | `int` | `1000` | Length of the rolling buffer to store server metrics (like request timings, attempts). | +| `n_kwarg_is_ignored` | `bool` | `False` | If the n kwarg is ignored by the API you are using, set this to True. | diff --git a/README.md b/README.md index 8c637661..0a2a5cc3 100644 --- a/README.md +++ b/README.md @@ -22,16 +22,17 @@ -Atropos is a Language Model Reinforcement Learning Environments framework for collecting and evaluating LLM trajectories through diverse environments including: +--- +## What is Atropos? +Atropos is an environment microservice framework for async RL with LLMs. + +Atropos encompasses both environments, which are set up as services, and a trajectory API for the environments to send data to and for the trainer to pull batches from. + +![image](https://github.com/user-attachments/assets/8ce52994-b219-49d6-970c-58a477f36151)
-| Environment Type | Examples | Purpose | -|---------------------------|--------------------------------------------|----------------------------------------------------| -| 📚 Dataset environments | GSM8K, MMLU | Evaluate and improve LLM performance on static data| -| 🎮 Online environments | Crosswords, Hangman | Train LLMs through interactive game-based learning | -| 🤖 RLAIF and RLHF | LLM Judge/Reward Models | Fine-tune LLMs using human feedback and alignment | -| 🔄 Multi-Turn RL | deepresearch, internal tool calling | Train LLMs on complex multi-step interactions | + *Here is a diagram of how Atropos' components can interact with a trainer & inference server to complete the RL loop (trainer & inference engine not included with the atropos package)*
@@ -45,6 +46,19 @@ Atropos is a robust, scalable framework for **Reinforcement Learning Environment The goal: provide a flexible, scalable, and standardized platform to accelerate LLM-based RL research across diverse, interactive settings. +The framework supports collecting, distributing and evaluating LLM trajectories through diverse environments including: + +
+ +| Environment Type | Examples | Purpose | +|---------------------------|--------------------------------------------|----------------------------------------------------| +| 📚 Dataset environments | GSM8K, MMLU | Evaluate and improve LLM performance on static data| +| 🎮 Online environments | Crosswords, Hangman | Train LLMs through interactive game-based learning | +| 🤖 RLAIF and RLHF | LLM Judge/Reward Models | Fine-tune LLMs using human feedback and alignment | +| 🔄 Multi-Turn RL | deepresearch, internal tool calling | Train LLMs on complex multi-step interactions | + +
+ ## 🎉 Upcoming Atropos Hackathon: LLM RL Environments Join us in San Francisco on May 18th, 2025 for an exciting hackathon focused on building and experimenting with LLM RL Environments! This in-person event will bring together researchers and developers interested in advancing the field of LLM reinforcement learning. @@ -164,16 +178,23 @@ pre-commit install 2. **Run an Example Environment** - You should edit the config_init section of the environment file you want ([For example, in GSM8K Environment](https://github.com/NousResearch/atropos/blob/main/environments/gsm8k_server.py#L53)) to point to a running VLLM or SGLang inference server as well as any other configuration changes you'd like to make, such as the group size, then: + You should edit the config_init section of the environment file you want ([For example, in GSM8K Environment](https://github.com/NousResearch/atropos/blob/main/environments/gsm8k_server.py#L53)) to point to a running VLLM or SGLang inference server as well as any other [configuration changes](CONFIG.md) you'd like to make, such as the group size, then: ```bash - # Start the API server and run the GSM8K environment - run-api & python environments/gsm8k_server.py serve \ - --slurm false + # Start the API server + run-api ``` -3. **Query the the API (Optional)** + In a separate terminal, start the GSM8K environment microservice + ```bash + python environments/gsm8k_server.py serve --openai.model_name Qwen/Qwen2.5-1.5B-Instruct --slurm false + # alternatively + # python environments/gsm8k_server.py serve --config environments/configs/example.yaml + # python environments/gsm8k_server.py serve --config environments/configs/example.yaml --env.group_size 8 # cli args override corresponding config settings + ``` +3. **Grabbing Rollouts** - If you want to just query the api, start getting rollouts, and not use a trainer, see [API Docs](https://github.com/NousResearch/atropos/tree/main/atroposlib/api) to explore the REST API interface that this API exposes, if you plan to use a trainer, skip to step 4. + If you want to just start getting rollouts, and not use a trainer, see the [debug section](#testing-and-debugging-tools) + for help getting started with the available tools, we recommend starting with process or view-run 4. **Training Your Model** - Follow our [training example guide](example_trainer/README.md) for detailed instructions @@ -190,7 +211,7 @@ Environments come with detailed logging and reporting support, runs track comple --- -## Debugging Tools +## Testing and Debugging Tools The trajectory-handler provides several debugging tools to help environment developers test and understand their environments locally without requiring the full distributed infrastructure. diff --git a/atroposlib/api/utils.py b/atroposlib/api/utils.py index c2fef67c..bf89e4e6 100644 --- a/atroposlib/api/utils.py +++ b/atroposlib/api/utils.py @@ -22,40 +22,63 @@ def grab_exact_from_heterogeneous_queue( :param batch_size: :return: batch, new_queue """ - # check if we can even potentially grab a batch - if sum(len(item["tokens"]) for item in queue) < batch_size: + + # Pass 1: precompute group sizes, total tokens and early exit if not enough tokens. + total_groups = len(queue) + if total_groups == 0: return None, queue - # Get max batch size - max_group_size = max(len(group["tokens"]) for group in queue) - group_sizes = set(len(group["tokens"]) for group in queue) - group_batching_storage = {i: [] for i in group_sizes} - # pack the groups into [max_group_size // group_size] packs - potential_batch = [] - for i, item in enumerate(queue): - key = len(item["tokens"]) - group_batching_storage[key].append({"group": item, "indx": i}) - if len(group_batching_storage[key]) * key == max_group_size: - potential_batch.extend(group_batching_storage[key]) - group_batching_storage[key] = [] - if ( - sum(len(grouped_items["group"]["tokens"]) for grouped_items in potential_batch) - < batch_size - ): + + group_sizes = [] + lengths = [] + total_tokens = 0 + max_group_size = 0 + + for item in queue: + length = len(item["tokens"]) + lengths.append(length) + group_sizes.append(length) + total_tokens += length + if length > max_group_size: + max_group_size = length + + if total_tokens < batch_size: return None, queue - # we have a batch + + group_sizes_set = set(group_sizes) + group_batching_storage = {size: [] for size in group_sizes_set} + + # Index into the queue and batch related indices into "packs" + potential_batch_indices = [] + for i, group_size in enumerate(group_sizes): + group_batching_storage[group_size].append(i) + if len(group_batching_storage[group_size]) * group_size == max_group_size: + potential_batch_indices.extend(group_batching_storage[group_size]) + group_batching_storage[group_size].clear() # much faster than = [] + + # Calculate total batch tokens only once (avoid repeated sums) + potential_batch_token_total = sum(lengths[i] for i in potential_batch_indices) + if potential_batch_token_total < batch_size: + return None, queue + + # Batch selection batch = [] - indxes_to_remove_from_queue = [] - for item in potential_batch: - group = item["group"] - indx = item["indx"] + batch_indices = [] + running_tokens = 0 + for idx in potential_batch_indices: + group = queue[idx] batch.append(group) - indxes_to_remove_from_queue.append(indx) - if sum(len(item["tokens"]) for item in batch) == batch_size: + batch_indices.append(idx) + running_tokens += lengths[idx] + if running_tokens == batch_size: break - if sum(len(item["tokens"]) for item in batch) != batch_size: + elif running_tokens > batch_size: + # Should never happen due to problem constraints, but sanity check + return None, queue + + if running_tokens != batch_size: return None, queue - # remove the items from the queue - new_queue = [ - item for i, item in enumerate(queue) if i not in indxes_to_remove_from_queue - ] + + # Construct new_queue with a single pass, using a set for O(1) lookup + batch_indices_set = set(batch_indices) + new_queue = [item for i, item in enumerate(queue) if i not in batch_indices_set] return batch, new_queue diff --git a/atroposlib/cli/dpo.py b/atroposlib/cli/dpo.py index bd409c18..15452f24 100644 --- a/atroposlib/cli/dpo.py +++ b/atroposlib/cli/dpo.py @@ -8,6 +8,8 @@ import jsonlines from tqdm.asyncio import tqdm # Import tqdm for async from transformers import AutoTokenizer +from atroposlib.utils.io import parse_http_response + def find_common_prefix(strings): """ @@ -80,7 +82,7 @@ async def check_for_batch(api_url): while True: async with aiohttp.ClientSession() as session: async with session.get(f"{api_url}/batch") as response: - data = await response.json() + data = await parse_http_response(response) if data["batch"] is not None: return data["batch"] await asyncio.sleep(1) # Wait before polling again diff --git a/atroposlib/cli/inference_node_wandb_watcher.py b/atroposlib/cli/inference_node_wandb_watcher.py index 2ef77520..b5f5fc45 100644 --- a/atroposlib/cli/inference_node_wandb_watcher.py +++ b/atroposlib/cli/inference_node_wandb_watcher.py @@ -2,7 +2,6 @@ import argparse import time import requests - import wandb diff --git a/atroposlib/cli/sft.py b/atroposlib/cli/sft.py index b58badb1..c5781b32 100644 --- a/atroposlib/cli/sft.py +++ b/atroposlib/cli/sft.py @@ -7,6 +7,8 @@ import jsonlines from tqdm.asyncio import tqdm # Import tqdm for async from transformers import AutoTokenizer +from atroposlib.utils.io import parse_http_response + def find_common_prefix(strings): """ @@ -79,7 +81,7 @@ async def check_for_batch(api_url): while True: async with aiohttp.ClientSession() as session: async with session.get(f"{api_url}/batch") as response: - data = await response.json() + data = await parse_http_response(response) if data["batch"] is not None: return data["batch"] await asyncio.sleep(1) # Wait before polling again diff --git a/atroposlib/cli/view_run.py b/atroposlib/cli/view_run.py index 42462355..e22bcbc9 100644 --- a/atroposlib/cli/view_run.py +++ b/atroposlib/cli/view_run.py @@ -5,6 +5,8 @@ import aiohttp import gradio as gr from transformers import AutoTokenizer +from atroposlib.utils.io import parse_http_response + def find_common_prefix(strings): if not strings: @@ -46,7 +48,7 @@ async def check_for_batch(): while True: async with aiohttp.ClientSession() as session: async with session.get("http://localhost:8000/batch") as response: - data = await response.json() + data = await parse_http_response(response) print(data) if data["batch"] is not None: return data["batch"] diff --git a/atroposlib/cli/view_run_multimodal.py b/atroposlib/cli/view_run_multimodal.py new file mode 100644 index 00000000..c7ae0592 --- /dev/null +++ b/atroposlib/cli/view_run_multimodal.py @@ -0,0 +1,184 @@ +import argparse +import asyncio +import base64 +import re +from io import BytesIO + +import aiohttp +import gradio as gr +import PIL.Image +from transformers import AutoTokenizer + + +def find_common_prefix(strings): + if not strings: + return "" + + prefix = strings[0] + for s in strings[1:]: + while not s.startswith(prefix): + prefix = prefix[:-1] + if not prefix: + return "" + return prefix + + +async def register_to_api(group_size, max_token_len): + async with aiohttp.ClientSession() as session: + async with session.get("http://localhost:8000/reset_data") as response: + print(await response.text()) + print(group_size) + async with session.post( + "http://localhost:8000/register", + json={ + "wandb_group": "test", + "wandb_project": "test", + "batch_size": group_size + * 8, # * 8 just in case you want to just sample from a large group + "max_token_len": max_token_len, + "checkpoint_dir": "checkpoints", + "save_checkpoint_interval": 10, + "starting_step": 0, + "num_steps": 69, + }, + ) as response: + print("output of register is") + print(await response.text()) + + +async def check_for_batch(): + while True: + async with aiohttp.ClientSession() as session: + async with session.get("http://localhost:8000/batch") as response: + data = await response.json() + print(data) + if data["batch"] is not None: + return data["batch"] + await asyncio.sleep(1) + + +def extract_image_from_chat(chat_text): + # Extract the base64 image data from the chat text + # Support both jpeg and png formats + image_pattern = r'data:image/(jpeg|png);base64,([^"\\]*)' + match = re.search(image_pattern, chat_text) + + if match: + base64_data = match.group(2) + try: + image_data = base64.b64decode(base64_data) + image = PIL.Image.open(BytesIO(image_data)) + return image + except Exception as e: + print(f"Error decoding image: {e}") + return None + + +def extract_text_from_chat(chat_text): + # Try to extract text from JSON format first + # Check if this is JSON multimodal content + if '"type": "text"' in chat_text: + text_pattern = r'"type": "text", "text": "([^"]*)"' + match = re.search(text_pattern, chat_text) + if match: + return match.group(1) + + # If not in JSON format, look for [Image] prefix + if "[Image]" in chat_text: + return chat_text.split("[Image]", 1)[1].strip() + + # Return original text if no pattern is found + return chat_text + + +async def build_interface(group_size, max_token_len, tokenizer, port): + async def grab_batch(): + tok = AutoTokenizer.from_pretrained(tokenizer) + data = await check_for_batch() + print(data) + chats = [tok.decode(chat) for chat in data[0]["tokens"]] + + # Find common prefix + prefix = find_common_prefix(chats) + + # Handle base64 encoded image + try: + if "images" in data[0] and data[0]["images"] and data[0]["images"][0]: + print("Found image data in batch") + # Convert base64 string to image + base64_image = data[0]["images"][0] + + # If it's already a PIL Image, use it directly + if isinstance(base64_image, PIL.Image.Image): + image = base64_image + # If it's a base64 string, decode it + elif isinstance(base64_image, str): + # Remove data:image prefix if present + if base64_image.startswith("data:image"): + # Extract just the base64 part + image_data = base64_image.split(",", 1)[1] + else: + image_data = base64_image + + # Decode base64 to bytes and create image + image_bytes = base64.b64decode(image_data) + image = PIL.Image.open(BytesIO(image_bytes)) + else: + print(f"Image type not recognized: {type(base64_image)}") + image = None + else: + # Try to extract image from chat text as fallback + print("No images field found, trying to extract from chat text") + image = extract_image_from_chat(prefix) + except Exception as e: + print(f"Error processing image: {e}") + image = None + + # Extract text prompt from prefix + text_prompt = extract_text_from_chat(prefix) + + return ( + image, # Image + text_prompt, # Text prompt + *[chat.split(prefix)[1] for chat in chats[:group_size]], # Model outputs + *data[0]["scores"][:group_size], # Scores + ) + + with gr.Blocks() as demo: + image_blk = gr.Image(label="Image", type="pil") + prompt_blk = gr.Textbox(label="Text Prompt") + + with gr.Row(): + score_blks = [gr.Textbox(label=f"Score_{i+1}") for i in range(group_size)] + + with gr.Row(): + outputs_blks = [ + gr.Textbox(label=f"Output_{i+1}") for i in range(group_size) + ] + + with gr.Row(): + grab_next = gr.Button(value="Grab Next Batch") + + grab_next.click( + fn=grab_batch, + outputs=[image_blk, prompt_blk] + outputs_blks + score_blks, + api_name="get_batch", + ) + await register_to_api(group_size, max_token_len) + demo.launch(server_port=port, share=True) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--port", type=int, default=9001) + parser.add_argument("--group-size", type=int, default=2) + parser.add_argument("--max-token-len", type=int, default=2048) + parser.add_argument("--tokenizer", type=str, default="Qwen/Qwen2-VL-2B-Instruct") + args = parser.parse_args() + asyncio.run( + build_interface(args.group_size, args.max_token_len, args.tokenizer, args.port) + ) + + +if __name__ == "__main__": + main() diff --git a/atroposlib/envs/README.md b/atroposlib/envs/README.md index 85c3f75e..c88d9706 100644 --- a/atroposlib/envs/README.md +++ b/atroposlib/envs/README.md @@ -2,6 +2,21 @@ The `BaseEnv` class (located in `trajectoryhandler/envs/base.py`) provides a foundation for creating custom reinforcement learning environments that interact with Atropos. When creating your own environment, you will typically subclass `BaseEnv` and implement several key methods. +## Design philosophy + +Every environment in Atropos is a microservice that generates rollout data async from whatever trainer you attach to it. Environments (possibly many at once) send data to the [Atropos API server](https://github.com/NousResearch/atropos/tree/main/atroposlib/api) which sequesters rollout data. Your trainer of choice grabs batches of data from the API and backpropagates. + +![image](https://github.com/user-attachments/assets/1cc8634a-319d-4add-9db7-c8b3acc272ad) + +Unlike other popular alternatives like Gymnasium which model environments as [MDPs](https://arxiv.org/abs/2412.05265), we think about environments as dataloaders and do not make any assumptions about how a trajectory is produced. For multi-agent for example, this means our design is agnostic to [AEC](https://pettingzoo.farama.org/api/aec/) vs. [POSG](https://pettingzoo.farama.org/api/parallel/) - both are supported out of the box. +To achieve this generality, our environment abstraction deviates from other open source alternatives in several key ways. + +- **Inference-scoring fusion**: A popular design choice in open-source LLM RL trainers is to separate inference and scoring into independent abstractions. While this makes a lot of sense for single-turn environments like one-shot MCQA, we found that this led to awkwardness in multi-turn setups with process rewards. As such, we assume the existence of a single method `collect_trajectories` which is responsible for both inference and scoring. Users are still welcome to call separate inference and scoring methods from within `collect_trajectories`. + +- **Groups as atomic units of data**: A natural choice for a data atom in RL is a single trajectory. However, many popular RL methods for fine-tuning LLMs such as DPO and GRPO involve packing contrastive data into the same batch. As such the most fundamental dataloading method in our abstraction is not `collect_trajectory` (singular) but `collect_trajectories` (plural). We do not enforce any definition of what a "group" is other than a set of rollouts. Although a "group" is most commonly constructed by generated multiple rollouts starting from the same initial state (as in DPO and GRPO), a user could just as easily pack `n` similar-sounding problems with very different solutions into a group. For cases like PPO where advantages don't depend on group statistics a user can simply use group size 1. + +- **Environments return tokens (not messages!)**: One of the most peculiar design choices we made was that at least for text-only environments, environments are responsible for tokenization. This gives us the flexibility to assign token-level rewards and to mix completions-based (e.g. autocomplete suggestion accept/reject) and chat-based (e.g. instruct-model code generation) environments together in the same training run. For cases like multimodal where a OpenAI-formatted message list needs to be passed to a transformers `AutoProcessor`, we support a `list[dict]`-valued `messages` key within our group abstraction [ScoredDataGroup](https://github.com/NousResearch/atropos/blob/a282604baac8dbb3b201f992cfc889ee1e5a0f4a/atroposlib/envs/base.py#L55). + ## Core Methods to Implement These methods **must** be implemented in your subclass: @@ -10,12 +25,13 @@ These methods **must** be implemented in your subclass: * **`async def get_next_item(self) -> Item`**: This method is responsible for generating or retrieving the next piece of data (prompt, state, etc.) that will be used to start a new trajectory collection. If no more items are available or should be generated, it can return `None` to signal the worker to pause. -* **`async def collect_trajectory(self, item: Item) -> Tuple[Any | None, List[Item]]`**: This method defines the logic for a *single* logical trajectory collection step based on the input `item`. \ - * **How it relates to multiple generations**: The `BaseEnv` uses `collect_trajectories` to run this method multiple times in parallel (controlled by `group_size`) to gather a batch of trajectories. \ - * **Your implementation**: You can implement this method to generate *one* response/trajectory per call.\ +* **`async def collect_trajectories(self, item: Item) -> Tuple[Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None]], List[Item]]`**: The default implementation of this method runs `collect_trajectory` (see below) multiple times in parallel (controlled by `group_size`). You can override this if you have a more efficient way to generate the entire group of responses/trajectories at once based on the input `item` (e.g. the `n` parameter in the OpenAI chat completions API) or some desired coupling of rollouts (e.g. via MCTS). It should return the collected group data and a list of backlog items. + +* **`async def collect_trajectory(self, item: Item) -> Tuple[Any | ScoredDataItem | None, List[Item]]`**: If the rollouts for your environment can be sampled independently, the easiest way to implement GRPO-style grouping is to define the `collect_trajectory` method and use the default implementation of `collect_trajectories` which runs `group_size` instances of `collect_trajectory` in parallel. This method defines the logic for a *single* logical trajectory collection step based on the input `item`. * **Return value**: It returns a tuple containing:\ - 1. The collected data for this step (one trajectory). This data can be processed further in `postprocess_histories`, if you require additional filtering right before sending to the API.\ + 1. The ScoredDataItem for this step (one trajectory). This data can be processed further in `postprocess_histories`, if you require additional filtering right before sending to the API. 2. A list of new `Item` objects to be added to the backlog for future processing (e.g., follow-up prompts).\ + * **Should I define `collect_trajectory` or override `collect_trajectories`?** If you've got some way to generate your group more efficiently than a bunch of separate but parallel calls to `collect_trajectory`, or if your rollouts aren't independent as in MCTS, you should override `collect_trajectories`. If simplicity and iteration speed is more valuable than efficiency (e.g. at the start of a development cycle) and your rollouts are independent then `collect_trajectory` is for you. * **`async def evaluate(self, *args, **kwargs)`**: This method is called periodically (controlled by `steps_per_eval` in the config) to perform evaluation runs. You define the evaluation logic here. The base class provides an example using `self.eval_workers` for parallel evaluation tasks, but you can implement any evaluation procedure suitable for your environment. diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 5c170a62..52f8bb84 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -20,15 +20,12 @@ import wandb import yaml from pydantic import BaseModel, Field from pydantic_cli import Cmd, FailedExecutionException, run_and_exit +from rich import print as rprint from tenacity import retry, stop_after_attempt, wait_random_exponential from transformers import AutoTokenizer -from atroposlib.envs.constants import ( - ENV_NAMESPACE, - NAMESPACE_SEP, - OPENAI_NAMESPACE, - SERVER_MANAGER_NAMESPACE, -) +from atroposlib.envs.constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE +from atroposlib.envs.server_handling.openai_server import resolve_openai_configs from atroposlib.frontend.jsonl2html import generate_html from atroposlib.type_definitions import UUID from atroposlib.utils.cli import ( @@ -38,6 +35,7 @@ from atroposlib.utils.cli import ( get_prefixed_pydantic_model, merge_dicts, ) +from atroposlib.utils.io import parse_http_response from atroposlib.utils.metrics import get_std_min_max_avg from ..type_definitions import Item, Message @@ -63,6 +61,17 @@ class ScoredDataGroup(TypedDict): overrides: Optional[List[Dict]] +class ScoredDataItem(TypedDict): + tokens: List[int] + masks: List[int] + scores: float + advantages: Optional[List[float]] + ref_logprobs: Optional[List[float]] + messages: Optional[List[Message]] + group_overrides: Optional[Dict] + overrides: Optional[Dict] + + class EvalHandlingEnum(Enum): """ Enum for handling evals. @@ -237,7 +246,9 @@ class BaseEnv(ABC): """ return cls.env_config_cls(), ServerBaseline(), None - async def collect_trajectory(self, item: Item) -> Tuple[Any | None, List[Item]]: + async def collect_trajectory( + self, item: Item + ) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]: raise NotImplementedError( "Handle env single method must be implemented in subclass " ) @@ -257,13 +268,38 @@ class BaseEnv(ABC): for _ in range(self.config.group_size): tasks.append(self.collect_trajectory(item)) results = await asyncio.gather(*tasks) + if any(not isinstance(result[0], dict) for result in results): + logging.error("something wasn't a ScoredDataItem") + raise ValueError( + "collect_trajectory must return a ScoredDataItem or None to use the default " + "collect_trajectories method" + ) backlog = [] - to_postprocess = [] + to_postprocess = ScoredDataGroup() + to_postprocess["tokens"] = [] + to_postprocess["masks"] = [] + to_postprocess["scores"] = [] + to_postprocess["advantages"] = [] + to_postprocess["ref_logprobs"] = [] + to_postprocess["messages"] = [] + to_postprocess["group_overrides"] = {} + to_postprocess["overrides"] = [] + print("Processing results") for result in results: - if result[0] is not None: - to_postprocess.append(result[0]) + to_postprocess["tokens"].append(result[0]["tokens"]) + to_postprocess["masks"].append(result[0]["masks"]) + to_postprocess["scores"].append(result[0]["scores"]) + if result[0].get("advantages", None) is not None: + to_postprocess["advantages"].append(result[0]["advantages"]) + if result[0].get("ref_logprobs", None) is not None: + to_postprocess["ref_logprobs"].append(result[0]["ref_logprobs"]) + if result[0].get("messages", None) is not None: + to_postprocess["messages"].append(result[0]["messages"]) + if result[0].get("group_overrides", None) is not None: + to_postprocess["group_overrides"].update(result[0]["group_overrides"]) + if result[0].get("overrides", None) is not None: + to_postprocess["overrides"].append(result[0]["overrides"]) backlog.extend(result[1]) - random.shuffle(backlog) return to_postprocess, backlog async def postprocess_histories( @@ -358,7 +394,7 @@ class BaseEnv(ABC): async with session.get( f"{self.config.rollout_server_url}/wandb_info" ) as resp: - data = await resp.json() + data = await parse_http_response(resp, logger) self.wandb_group = data["group"] self.wandb_project = data["project"] if self.wandb_project is None: @@ -386,7 +422,7 @@ class BaseEnv(ABC): "weight": self.config.inference_weight, }, ) as resp: - data = await resp.json() + data = await parse_http_response(resp, logger) return data except Exception as e: logger.error(f"Error registering env: {e}") @@ -427,7 +463,7 @@ class BaseEnv(ABC): """ async with aiohttp.ClientSession() as session: async with session.get(f"{self.config.rollout_server_url}/info") as resp: - data = await resp.json() + data = await parse_http_response(resp, logger) if data["batch_size"] != -1: # update the batch size self.config.batch_size = data["batch_size"] @@ -710,7 +746,7 @@ class BaseEnv(ABC): f"{self.config.rollout_server_url}/status-env", json={"env_id": self.env_id}, ) as resp: - self.status_dict = await resp.json() + self.status_dict = await parse_http_response(resp, logger) new_weight = self.status_dict["env_weight"] max_num_workers = self.config.max_num_workers if max_num_workers == -1: @@ -964,53 +1000,158 @@ class BaseEnv(ABC): Returns: type: The CliServeConfig class for serving commands. """ + # Get the default configurations defined by the specific environment class configs_and_maybe_server_class = cls.config_init() if len(configs_and_maybe_server_class) == 2: - env_config, server_configs = configs_and_maybe_server_class + default_env_config, default_server_configs = configs_and_maybe_server_class server_class = None else: - env_config, server_configs, server_class = configs_and_maybe_server_class + default_env_config, default_server_configs, server_class = ( + configs_and_maybe_server_class + ) + + # Define namespace prefixes for CLI arguments and YAML keys env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}" openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}" + # Define the CLI configuration class dynamically class CliServeConfig( - get_prefixed_pydantic_model(type(env_config), env_full_prefix), - get_prefixed_pydantic_model(APIServerConfig, openai_full_prefix), - ServerManagerConfig, + get_prefixed_pydantic_model(type(default_env_config), env_full_prefix), + get_prefixed_pydantic_model( + APIServerConfig, openai_full_prefix + ), # Use APIServerConfig for CLI args + ServerManagerConfig, # ServerManager args are not namespaced by default Cmd, ): """ Configuration for the serve command. - This combines BaseEnvConfig and APIServerConfig into a single command. + Supports overrides via YAML config file and CLI arguments. + Order of precedence: CLI > YAML > Class Defaults. """ + config: str | None = Field( + default=None, + description="Path to .yaml config file. CLI args override this.", + ) + def run(self) -> None: """The logic to execute for the 'serve' command.""" - # Convert this config into the formats needed by BaseEnv + # Set default wandb name if not provided and class has a name + # Note: This modifies the 'self' instance based on CLI args before full parsing. wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name" - if getattr(self, wandb_name_attr) is None and cls.name is not None: + if ( + getattr(self, wandb_name_attr, None) is None + and cls.name is not None + ): setattr(self, wandb_name_attr, cls.name) - model_dumped = self.model_dump(exclude_unset=True) - server_manager_config = ServerManagerConfig(**model_dumped) - # Create the environment instance - try: - env = cls( - config=env_config, - server_configs=server_configs, - slurm=server_manager_config.slurm, - testing=server_manager_config.testing, - server_class=server_class, + + # Load configuration from YAML file if specified + if self.config is not None: + with open(self.config, "r") as f: + yaml_config = yaml.safe_load(f) + print(f"Loaded config from {self.config}") + else: + yaml_config = {} + + # Get CLI flags passed with double dashes (e.g., --env--foo bar) + cli_passed_flags = get_double_dash_flags() + + # --- Configuration Merging --- + # Priority: CLI > YAML > Class Defaults + + # 1. Environment Configuration + env_config_dict = merge_dicts( + default_env_config.model_dump(), # Class Defaults + yaml_config.get(ENV_NAMESPACE, {}), # YAML config + extract_namespace(cli_passed_flags, env_full_prefix), # CLI args + ) + + # 2. OpenAI Configuration (used for potential overrides) + oai_cli_passed_args = extract_namespace( + cli_passed_flags, openai_full_prefix + ) # CLI args + yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {}) + if isinstance(default_server_configs, ServerBaseline) and ( + oai_cli_passed_args or yaml_oai_config + ): + raise ValueError( + "ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig." # noqa: E501 ) - except TypeError as e: - warnings.warn( - "Not supporting server_class will be deprecated soon, please add that kwarg" - ) - env = cls( - config=env_config, - server_configs=server_configs, - slurm=server_manager_config.slurm, - testing=server_manager_config.testing, + if ( + isinstance(default_server_configs, list) + and len(default_server_configs) == 1 + ): + # can't use the same var name because it shadows the class variable and we get an error + default_openai_config_ = default_server_configs[0] + else: + default_openai_config_ = default_server_configs + if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1: + yaml_oai_config = yaml_oai_config[0] + if isinstance(default_openai_config_, APIServerConfig) and isinstance( + yaml_oai_config, dict + ): + openai_config_dict = merge_dicts( + default_openai_config_.model_dump(), # Default APIServerConfig (or from class init) + yaml_oai_config, + oai_cli_passed_args, ) + else: + openai_config_dict = {} + + # 3. Server Manager Configuration (slurm, testing - not namespaced) + # Extract only relevant CLI flags for ServerManager + server_manager_cli_passed_flags = {} + if "slurm" in cli_passed_flags: + server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"] + if "testing" in cli_passed_flags: + server_manager_cli_passed_flags["testing"] = cli_passed_flags[ + "testing" + ] + + server_manager_yaml_dict = {} + if "slurm" in yaml_config: + server_manager_yaml_dict["slurm"] = yaml_config["slurm"] + if "testing" in yaml_config: + server_manager_yaml_dict["testing"] = yaml_config["testing"] + + server_manager_config_dict = merge_dicts( + ServerManagerConfig().model_dump(), # Base defaults for ServerManager + server_manager_yaml_dict, # YAML config + server_manager_cli_passed_flags, # CLI args + ) + + # --- Instantiate Final Config Objects --- + # Create instances from the merged dictionaries using the original default types where appropriate + + # Instantiate the final environment config using its original type + env_config = type(default_env_config)(**env_config_dict) + + # Instantiate the final server manager config + server_manager_config = ServerManagerConfig( + **server_manager_config_dict + ) + + # Determine the final server_configs, handling single, multiple servers, and overrides. + + openai_configs = resolve_openai_configs( + default_server_configs=default_server_configs, + openai_config_dict=openai_config_dict, + yaml_config=yaml_config, + cli_passed_flags=cli_passed_flags, + logger=logger, + ) + + # --- Create and Run Environment --- + # Create the environment instance using the final, instantiated config objects + env = cls( + config=env_config, + server_configs=openai_configs, + slurm=server_manager_config.slurm, + testing=server_manager_config.testing, + ) + rprint(env_config) + rprint(openai_configs) + # Run the environment asyncio.run(env.env_manager()) @@ -1025,11 +1166,14 @@ class BaseEnv(ABC): type: The CliProcessConfig class for processing commands. """ + # Define specific default configurations for the 'process' mode PROCESS_MODE_ENV_DEFAULT_CONFIG = BaseEnvConfig( group_size=8, total_steps=2, ensure_scores_are_not_same=False, include_messages=True, + data_path_to_save_groups=f"data/{cls.name or 'groups'}.jsonl", + use_wandb=True, ) PROCESS_MODE_OPENAI_DEFAULT_CONFIG = APIServerConfig( model_name="gpt-4.1-nano", @@ -1041,21 +1185,22 @@ class BaseEnv(ABC): testing=False, ) + # Get the base default configurations from the specific environment class configs_and_maybe_server_class = cls.config_init() if len(configs_and_maybe_server_class) == 2: - default_env_config, default_openai_config = configs_and_maybe_server_class + default_env_config, default_server_configs = configs_and_maybe_server_class server_class = None else: - default_env_config, default_openai_config, server_class = ( + default_env_config, default_server_configs, server_class = ( configs_and_maybe_server_class ) - if isinstance(default_openai_config, list): - default_openai_config = default_openai_config[0] - + # Define namespace prefixes env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}" openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}" + # Create Pydantic model classes with the 'process' mode defaults applied. + # These adjusted classes will be used for final instantiation. env_config_cls_new_defaults = adjust_model_defaults( type(default_env_config), PROCESS_MODE_ENV_DEFAULT_CONFIG ) @@ -1077,6 +1222,8 @@ class BaseEnv(ABC): ): """ Configuration for the process command. + Supports overrides via YAML config file and CLI arguments. + Order of precedence: CLI > YAML > Process Mode Defaults > `config_init` defaults. """ config: str | None = Field( @@ -1086,42 +1233,72 @@ class BaseEnv(ABC): def run(self) -> None: """The logic to execute for the 'process' command.""" - # Setup environment configuration + # Set default wandb name if not provided and class has a name wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name" - if getattr(self, wandb_name_attr) is None and cls.name is not None: + if ( + getattr(self, wandb_name_attr, None) is None + and cls.name is not None + ): setattr(self, wandb_name_attr, cls.name) + # Load configuration from YAML file if specified if self.config is not None: with open(self.config, "r") as f: - config = yaml.safe_load(f) + yaml_config = yaml.safe_load(f) print(f"Loaded config from {self.config}") else: - config = {} + yaml_config = {} + # Get CLI flags passed with double dashes cli_passed_flags = get_double_dash_flags() - # cli args overrides config file which overrides class defaults which overrides process mode defaults - env_config = env_config_cls_new_defaults( - **merge_dicts( - default_env_config.model_dump(), - PROCESS_MODE_ENV_DEFAULT_CONFIG.model_dump(), - config.get(ENV_NAMESPACE, {}), - extract_namespace( - cli_passed_flags, env_full_prefix - ), # only extract namespace for cli-passed args - ) - ) - openai_config = openai_config_cls_new_defaults( - **merge_dicts( - default_openai_config.model_dump(), - PROCESS_MODE_OPENAI_DEFAULT_CONFIG.model_dump(), - config.get(OPENAI_NAMESPACE, {}), - extract_namespace( - cli_passed_flags, openai_full_prefix - ), # only extract namespace for cli-passed args - ) + # --- Configuration Merging --- + # Priority: CLI > YAML > Process Mode Defaults > `config_init` defaults + + # 1. Environment Configuration + env_config_dict = merge_dicts( + default_env_config.model_dump(), # Class Defaults + PROCESS_MODE_ENV_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults + yaml_config.get(ENV_NAMESPACE, {}), # YAML config + extract_namespace(cli_passed_flags, env_full_prefix), # CLI args ) + # 2. OpenAI Configuration + oai_cli_passed_args = extract_namespace( + cli_passed_flags, openai_full_prefix + ) # CLI args + yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {}) + if isinstance(default_server_configs, ServerBaseline) and ( + oai_cli_passed_args or yaml_oai_config + ): + raise ValueError( + "ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig." # noqa: E501 + ) + + if ( + isinstance(default_server_configs, list) + and len(default_server_configs) == 1 + ): + # can't use the same var name because it shadows the class variable and we get an error + default_openai_config_ = default_server_configs[0] + else: + default_openai_config_ = default_server_configs + if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1: + yaml_oai_config = yaml_oai_config[0] + if isinstance(default_openai_config_, APIServerConfig) and isinstance( + yaml_oai_config, dict + ): + openai_config_dict = merge_dicts( + default_openai_config_.model_dump(), # Default APIServerConfig (or from class init) + PROCESS_MODE_OPENAI_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults + yaml_oai_config, + oai_cli_passed_args, + ) + else: + openai_config_dict = {} + + # 3. Server Manager Configuration + # Extract only relevant CLI flags server_manager_cli_passed_flags = {} if "slurm" in cli_passed_flags: server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"] @@ -1130,37 +1307,69 @@ class BaseEnv(ABC): "testing" ] - server_manager_config = server_manager_config_cls_new_defaults( - **merge_dicts( - ServerManagerConfig().model_dump(), - PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG.model_dump(), - config.get(SERVER_MANAGER_NAMESPACE, {}), - server_manager_cli_passed_flags, - ) + server_manager_yaml_dict = {} + if "slurm" in yaml_config: + server_manager_yaml_dict["slurm"] = yaml_config["slurm"] + if "testing" in yaml_config: + server_manager_yaml_dict["testing"] = yaml_config["testing"] + + server_manager_config_dict = merge_dicts( + ServerManagerConfig().model_dump(), # Base defaults + PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults + server_manager_yaml_dict, + server_manager_cli_passed_flags, # CLI args ) + # --- Instantiate Final Config Objects --- + # Use the classes with adjusted defaults for instantiation + + env_config = env_config_cls_new_defaults(**env_config_dict) + server_manager_config = server_manager_config_cls_new_defaults( + **server_manager_config_dict + ) + + # Determine the final server_configs, handling single, multiple servers, and overrides. + + openai_configs = resolve_openai_configs( + default_server_configs=default_server_configs, + openai_config_dict=openai_config_dict, + yaml_config=yaml_config, + cli_passed_flags=cli_passed_flags, + logger=logger, + ) + + rprint(env_config) + rprint(openai_configs) + + # --- Create and Run Environment --- # Create the environment instance env = cls( config=env_config, - server_configs=[openai_config], + server_configs=openai_configs, slurm=server_manager_config.slurm, testing=server_manager_config.testing, server_class=server_class, ) - # Set the process mode parameters + # Set specific parameters for process mode on the environment instance env.process_mode = True env.n_groups_to_process = env_config.total_steps env.group_size_to_process = env_config.group_size + # Validate that an output path is set (should have a default from PROCESS_MODE_ENV_DEFAULT_CONFIG) + if env_config.data_path_to_save_groups is None: + # This check might be redundant if the default is always set, but good practice. + raise ValueError( + "data_path_to_save_groups must be set for process mode" + ) + print( f"Processing {env_config.total_steps} groups of " f"{env_config.group_size} responses and " f"writing to {env_config.data_path_to_save_groups}" ) + # Run the environment's asynchronous process manager function asyncio.run(env.process_manager()) - # Actual implementation would go here - return CliProcessConfig diff --git a/atroposlib/envs/reward_fns/cosine_scaled_reward.py b/atroposlib/envs/reward_fns/cosine_scaled_reward.py index 0b620abe..3a34198b 100644 --- a/atroposlib/envs/reward_fns/cosine_scaled_reward.py +++ b/atroposlib/envs/reward_fns/cosine_scaled_reward.py @@ -4,7 +4,17 @@ import logging from typing import Any, List, Optional, Union import scipy -import torch + +try: + import torch +except ImportError as e: + logger = logging.getLogger(__name__) + logger.warning( + "torch not installed, please install atroposlib[rewardfns] to use this reward function" + ) + raise e + + from transformers import AutoModel, AutoTokenizer from .registry import registry diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index 29dff341..6ebf9ec4 100644 --- a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -1,6 +1,7 @@ import asyncio import collections import time +import warnings from asyncio import exceptions from typing import Optional @@ -9,9 +10,10 @@ import numpy as np import openai from openai.types.chat.chat_completion import ChatCompletion from openai.types.completion import Completion -from pydantic import BaseModel, Field +from pydantic_cli import FailedExecutionException from tenacity import retry, stop_after_attempt, wait_random_exponential +from atroposlib.envs.constants import NAMESPACE_SEP, OPENAI_NAMESPACE from atroposlib.envs.server_handling.server_baseline import APIServerConfig @@ -159,6 +161,42 @@ class OpenAIServer: ) return metrics_dict + async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion: + if self.config.n_kwarg_is_ignored: + n = kwargs.pop("n", 1) + completion_list = await asyncio.gather( + *[self.openai.chat.completions.create(**kwargs) for _ in range(n)] + ) + completions = completion_list[0] + if n > 1: + for c in completion_list[1:]: + completions.choices.extend(c.choices) + else: + completions = await self.openai.chat.completions.create(**kwargs) + else: + if "n" in kwargs: + n = kwargs["n"] + else: + n = 1 + completions = await self.openai.chat.completions.create(**kwargs) + if len(completions.choices) != n: + if len(completions.choices) != 1: + raise ValueError( + f"Expected 1 or {n} completions, got {len(completions.choices)}!" + ) + else: + warnings.warn("n kwarg is ignored by the API, setting to True") + self.config.n_kwarg_is_ignored = True + completion_list = await asyncio.gather( + *[ + self.openai.chat.completions.create(**kwargs) + for _ in range(1, n) + ] + ) + for c in completion_list: + completions.choices.extend(c.choices) + return completions + @retry( stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) ) @@ -169,7 +207,7 @@ class OpenAIServer: if stat_dict.get("start", None) is None: stat_dict["start"] = time.time() stat_dict["attempts"] += 1 - completions = await self.openai.chat.completions.create(**kwargs) + completions = await self._chat_completion_wrapper(**kwargs) stat_dict["end"] = time.time() return completions @@ -183,7 +221,7 @@ class OpenAIServer: if stat_dict.get("start", None) is None: stat_dict["start"] = time.time() stat_dict["attempts"] += 1 - completions = await self.openai.chat.completions.create(**kwargs) + completions = await self._chat_completion_wrapper(**kwargs) stat_dict["end"] = time.time() return completions @@ -214,6 +252,36 @@ class OpenAIServer: self.eval_attempts_list.append(stat_dict["attempts"]) return ret_data + async def _completion_wrapper(self, **kwargs) -> Completion: + if self.config.n_kwarg_is_ignored: + n = kwargs.pop("n", 1) + completion_list = await asyncio.gather( + *[self.openai.completions.create(**kwargs) for _ in range(n)] + ) + completions = completion_list[0] + if n > 1: + for c in completion_list[1:]: + completions.choices.extend(c.choices) + else: + if "n" in kwargs: + n = kwargs["n"] + else: + n = 1 + completions = await self.openai.completions.create(**kwargs) + if len(completions.choices) != n: + if len(completions.choices) != 1: + raise ValueError( + f"Expected 1 or {n} completions, got {len(completions.choices)}!" + ) + else: + warnings.warn("n kwarg is ignored by the API, setting to True") + self.config.n_kwarg_is_ignored = True + completion_list = await asyncio.gather( + *[self.openai.completions.create(**kwargs) for _ in range(1, n)] + ) + for c in completion_list: + completions.choices.extend(c.choices) + @retry( stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) ) @@ -224,7 +292,7 @@ class OpenAIServer: if stat_dict.get("start", None) is None: stat_dict["start"] = time.time() stat_dict["attempts"] += 1 - completions = await self.openai.completions.create(**kwargs) + completions = await self._completion_wrapper(**kwargs) stat_dict["end"] = time.time() return completions @@ -238,7 +306,7 @@ class OpenAIServer: if stat_dict.get("start", None) is None: stat_dict["start"] = time.time() stat_dict["attempts"] += 1 - completions = await self.openai.completions.create(**kwargs) + completions = await self._completion_wrapper(**kwargs) stat_dict["end"] = time.time() return completions @@ -265,3 +333,79 @@ class OpenAIServer: self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"]) self.eval_attempts_list.append(stat_dict["attempts"]) return ret_data + + +def resolve_openai_configs( + default_server_configs, + openai_config_dict, + yaml_config, + cli_passed_flags, + logger, +): + """ + Helper to resolve the final server_configs, handling single, multiple servers, and overrides. + """ + from atroposlib.envs.server_handling.server_manager import ServerBaseline + + openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}" + openai_yaml_config = yaml_config.get(OPENAI_NAMESPACE, None) + openai_cli_config = { + k: v for k, v in cli_passed_flags.items() if k.startswith(openai_full_prefix) + } + + is_multi_server_yaml = ( + isinstance(openai_yaml_config, list) and len(openai_yaml_config) >= 2 + ) + is_multi_server_default = ( + (not is_multi_server_yaml) + and isinstance(default_server_configs, list) + and len(default_server_configs) >= 2 + ) + + if (is_multi_server_yaml or is_multi_server_default) and openai_cli_config: + raise FailedExecutionException( + f"CLI overrides for OpenAI settings (--{openai_full_prefix}*) are not supported " + f"when multiple servers are defined (either via YAML list under '{OPENAI_NAMESPACE}' " + "or a default list with length >= 2)." + ) + + if is_multi_server_yaml: + logger.info( + f"Using multi-server configuration defined in YAML under '{OPENAI_NAMESPACE}'." + ) + try: + server_configs = [APIServerConfig(**cfg) for cfg in openai_yaml_config] + except Exception as e: + raise FailedExecutionException( + f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}" + ) from e + elif isinstance(default_server_configs, ServerBaseline): + logger.info("Using ServerBaseline configuration.") + server_configs = default_server_configs + elif is_multi_server_default: + logger.info("Using default multi-server configuration (length >= 2).") + server_configs = default_server_configs + else: + logger.info( + "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)." + ) + try: + final_openai_config = APIServerConfig(**openai_config_dict) + except Exception as e: + raise FailedExecutionException( + f"Error creating final OpenAI configuration from merged settings: {e}\n" + f"Merged Dict: {openai_config_dict}" + ) from e + + if isinstance(default_server_configs, APIServerConfig): + server_configs = final_openai_config + elif isinstance(default_server_configs, list): + server_configs = [final_openai_config] + else: + logger.warning( + f"Unexpected type for default_server_configs: {type(default_server_configs)}. " + f"Proceeding with single OpenAI server configuration based on merged settings." + ) + server_configs = [final_openai_config] + + return server_configs diff --git a/atroposlib/envs/server_handling/server_baseline.py b/atroposlib/envs/server_handling/server_baseline.py index e8429605..62f9df8b 100644 --- a/atroposlib/envs/server_handling/server_baseline.py +++ b/atroposlib/envs/server_handling/server_baseline.py @@ -38,3 +38,6 @@ class APIServerConfig(ServerBaseline): server_type: Literal["openai"] = Field( default="openai", description="Type of server to use, openai or trl" ) + n_kwarg_is_ignored: bool = Field( + default=False, description="Whether the n kwarg is ignored by this API server." + ) diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py index 14c42168..d81dd338 100644 --- a/atroposlib/envs/server_handling/server_manager.py +++ b/atroposlib/envs/server_handling/server_manager.py @@ -86,8 +86,8 @@ class ServerManager: ) ) self.servers = [OpenAIServer(config) for config in openai_configs] - if not slurm: - self.servers = [server_class(config) for config in configs] + elif not slurm: + self.servers = [OpenAIServer(config) for config in configs] else: nodelist = ( os.popen(f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}') diff --git a/atroposlib/tests/conftest.py b/atroposlib/tests/conftest.py new file mode 100644 index 00000000..d122e39d --- /dev/null +++ b/atroposlib/tests/conftest.py @@ -0,0 +1,23 @@ +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--runproviders", action="store_true", default=False, help="run provider tests" + ) + + +def pytest_configure(config): + config.addinivalue_line( + "markers", "providers: mark test as requires providers api keys to run" + ) + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--runproviders"): + # --runproviders given in cli: do not skip slow tests + return + skip_providers = pytest.mark.skip(reason="need --runproviders option to run") + for item in items: + if "providers" in item.keywords: + item.add_marker(skip_providers) diff --git a/atroposlib/tests/test_advantages.py b/atroposlib/tests/test_advantages.py index 151ebd2b..2643f580 100644 --- a/atroposlib/tests/test_advantages.py +++ b/atroposlib/tests/test_advantages.py @@ -1,7 +1,7 @@ import math +import numpy as np import pytest -import torch # Adjust the import below if your functions are in a different module. from atroposlib.utils.advantages import ( @@ -23,9 +23,9 @@ def test_allclose_to_first_vector(): """Test that return_vector=True returns a tensor of booleans.""" values = [1.0, 1.000000001, 1.000000002] result = allclose_to_first(values, return_vector=True) - assert isinstance(result, torch.Tensor) + assert isinstance(result, np.ndarray) # All comparisons should be True. - assert torch.all(result) + assert np.all(result) def test_allclose_to_first_not_close(): @@ -74,15 +74,15 @@ def test_compute_stats_jagged(): def test_compute_discounted_returns(): """Test compute_discounted_returns with a tensor input.""" - rewards = torch.tensor([1.0, 1.0, 1.0]) + rewards = np.array([1.0, 1.0, 1.0]) gamma = 0.9 returns = compute_discounted_returns(rewards, gamma) # For a 3-element vector: # t=2: 1.0 # t=1: 1.0 + 0.9*1.0 = 1.9 # t=0: 1.0 + 0.9*1.9 = 2.71 - expected = torch.tensor([2.71, 1.9, 1.0]) - assert torch.allclose(returns, expected, rtol=1e-5, atol=1e-8) + expected = np.array([2.71, 1.9, 1.0]) + assert np.allclose(returns, expected, rtol=1e-5, atol=1e-8) def test_compute_discounted_returns_list_input(): @@ -90,8 +90,8 @@ def test_compute_discounted_returns_list_input(): rewards = [1, 1, 1] gamma = 0.0 # With gamma=0, the returns should equal the rewards. returns = compute_discounted_returns(rewards, gamma) - expected = torch.tensor([1.0, 1.0, 1.0]) - assert torch.allclose(returns, expected, rtol=1e-5, atol=1e-8) + expected = np.array([1.0, 1.0, 1.0]) + assert np.allclose(returns, expected, rtol=1e-5, atol=1e-8) def test_compute_grpo_process_supervision_advantages_cumsum(): diff --git a/atroposlib/tests/test_openai_api_workarounds.py b/atroposlib/tests/test_openai_api_workarounds.py new file mode 100644 index 00000000..8ea91911 --- /dev/null +++ b/atroposlib/tests/test_openai_api_workarounds.py @@ -0,0 +1,110 @@ +import asyncio +import os + +import dotenv +import pytest + +from atroposlib.envs.server_handling.openai_server import APIServerConfig, OpenAIServer + + +@pytest.mark.providers +def test_openai_api_n_kwarg_ignore_discovery(): + dotenv.load_dotenv() + openrouter_api_key = os.getenv("OPENROUTER_API_KEY") + if not openrouter_api_key: + pytest.skip("OPENROUTER_API_KEY not set") + config = APIServerConfig( + api_key=openrouter_api_key, + base_url="https://openrouter.ai/api/v1", + model_name="openai/gpt-4.1-nano", + timeout=1200, + num_max_requests_at_once=512, + num_requests_for_eval=64, + rolling_buffer_length=1024, + ) + assert not config.n_kwarg_is_ignored, "n kwarg is not ignored by default" + n = 4 + server = OpenAIServer( + config=config, + ) + response = asyncio.run( + server.chat_completion( + messages=[ + {"role": "user", "content": "Hello, how are you?"}, + ], + n=n, + ) + ) + assert server.config.n_kwarg_is_ignored, "n kwarg is should be set after discovery" + print(len(response.choices), n) + assert ( + len(response.choices) == n + ), f"Expected {n} responses, got {len(response.choices)}" + + +@pytest.mark.providers +def test_openai_api_n_kwarg_ignore_use(): + dotenv.load_dotenv() + openrouter_api_key = os.getenv("OPENROUTER_API_KEY") + if not openrouter_api_key: + pytest.skip("OPENROUTER_API_KEY not set") + config = APIServerConfig( + api_key=openrouter_api_key, + base_url="https://openrouter.ai/api/v1", + model_name="openai/gpt-4.1-nano", + timeout=1200, + num_max_requests_at_once=512, + num_requests_for_eval=64, + rolling_buffer_length=1024, + n_kwarg_is_ignored=True, + ) + server = OpenAIServer( + config=config, + ) + n = 4 + response = asyncio.run( + server.chat_completion( + messages=[ + {"role": "user", "content": "Hello, how are you?"}, + ], + n=n, + ) + ) + assert server.config.n_kwarg_is_ignored, "n kwarg is should be set after discovery" + assert ( + len(response.choices) == n + ), f"Expected {n} responses, got {len(response.choices)}" + + +@pytest.mark.providers +def test_openai_api_n_kwarg_supported(): + dotenv.load_dotenv() + openai_api_key = os.getenv("OPENAI_API_KEY") + if not openai_api_key: + pytest.skip("OPENAI_API_KEY not set") + config = APIServerConfig( + model_name="gpt-4.1-nano", + timeout=1200, + num_max_requests_at_once=512, + num_requests_for_eval=64, + rolling_buffer_length=1024, + n_kwarg_is_ignored=False, + ) + server = OpenAIServer( + config=config, + ) + n = 4 + response = asyncio.run( + server.chat_completion( + messages=[ + {"role": "user", "content": "Hello, how are you?"}, + ], + n=n, + ) + ) + assert ( + not server.config.n_kwarg_is_ignored + ), "n kwarg should be used with supported models" + assert ( + len(response.choices) == n + ), f"Expected {n} responses, got {len(response.choices)}" diff --git a/atroposlib/utils/advantages.py b/atroposlib/utils/advantages.py index dcb31b60..93ec0575 100644 --- a/atroposlib/utils/advantages.py +++ b/atroposlib/utils/advantages.py @@ -1,31 +1,32 @@ from typing import Sequence -import torch +import numpy as np from atroposlib.type_definitions import number -TensorLike = torch.Tensor | Sequence[torch.Tensor] | Sequence[Sequence] +NumpyArrayLike = np.ndarray | Sequence[np.ndarray] | Sequence[Sequence] # Type alias for vector of bools -BoolVector = torch.Tensor +BoolVector = np.ndarray def allclose_to_first( - values: TensorLike, + # values: TensorLike, + values: NumpyArrayLike, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False, return_vector: bool = False, ) -> BoolVector | bool: """ - Check if all tensors in `values` are close to the first tensor `values[0]` using a vectorized approach. + Check if all arrays in `values` are close to the first array `values[0]` using a vectorized approach. If `return_vector` is False (default), returns a single boolean indicating whether - every tensor is close to the first tensor. If `return_vector` is True, returns a list - of booleans where each element corresponds to whether the respective tensor in - `values` is close to the first tensor. The first element is always True. + every array is close to the first array. If `return_vector` is True, returns a list + of booleans where each element corresponds to whether the respective array in + `values` is close to the first array. The first element is always True. Args: - values (torch.Tensor | Sequence[torch.Tensor] | Sequence[Sequence]): + values (np.ndarray | Sequence[np.ndarray] | Sequence[Sequence]): Nested list of values to compare. Must be rectangular, but not necessarily 2D. rtol (float, optional): Relative tolerance. Defaults to 1e-05. atol (float, optional): Absolute tolerance. Defaults to 1e-08. @@ -35,24 +36,22 @@ def allclose_to_first( Returns: bool or BoolVector: - - If `return_vector` is False, returns True if all tensors are close to the first tensor; + - If `return_vector` is False, returns True if all arrays are close to the first array; otherwise, returns False. - - If `return_vector` is True, returns a 1D tensor of bools where the first element is True - (as the reference tensor is trivially close to itself), and each subsequent element indicates - whether the corresponding tensor is close to the first tensor. + - If `return_vector` is True, returns a 1D array of bools where the first element is True + (as the reference array is trivially close to itself), and each subsequent element indicates + whether the corresponding array is close to the first array. """ - if not isinstance(values, torch.Tensor): - values = torch.tensor(values) + if not isinstance(values, np.ndarray): + values = np.array(values) reference = values[0] - is_close = torch.isclose( - values, reference, rtol=rtol, atol=atol, equal_nan=equal_nan - ) + is_close = np.isclose(values, reference, rtol=rtol, atol=atol, equal_nan=equal_nan) # flatten dimensions after first - result_vector = torch.all(is_close.view(is_close.size(0), -1), dim=1) + result_vector = np.all(is_close.reshape(is_close.shape[0], -1), axis=1) - return result_vector if return_vector else bool(torch.all(result_vector)) + return result_vector if return_vector else bool(np.all(result_vector)) def compute_stats(data: Sequence[number | Sequence]) -> dict[str, float]: @@ -104,23 +103,23 @@ def compute_stats(data: Sequence[number | Sequence]) -> dict[str, float]: return {"mean": mean, "var": variance} -def compute_discounted_returns(rewards: torch.Tensor, gamma: float) -> torch.Tensor: +def compute_discounted_returns(rewards: np.ndarray, gamma: float) -> np.ndarray: """Compute discounted returns from a 1D vector of rewards. - Given a list or torch tensor of rewards and a discount factor, this function computes + Given a list or numpy array of rewards and a discount factor, this function computes the discounted return at each timestep. The discounted return at time t is defined as: G_t = rewards[t] + gamma * rewards[t+1] + gamma^2 * rewards[t+2] + ... Args: - rewards (list[float] or torch.Tensor): A 1D list or tensor of rewards. + rewards (list[float] or np.ndarray): A 1D list or array of rewards. gamma (float): The discount factor (should be between 0 and 1). Returns: list[float]: A list containing the discounted returns for each timestep. """ - if not isinstance(rewards, torch.Tensor): - rewards = torch.tensor(rewards, dtype=torch.float) - discounted_returns = torch.empty_like(rewards) + if not isinstance(rewards, np.ndarray): + rewards = np.array(rewards, dtype=np.float32) # Use float32 for numpy default + discounted_returns = np.empty_like(rewards) running_return = 0.0 for t in reversed(range(len(rewards))): @@ -132,7 +131,7 @@ def compute_discounted_returns(rewards: torch.Tensor, gamma: float) -> torch.Ten def compute_grpo_process_supervision_advantages( rewards: Sequence[Sequence[number]], gamma: float = None, std_tol: float = 1e-8 -) -> list[torch.Tensor]: +) -> list[np.ndarray]: """ Given a (possibly jagged) list of list of rewards, compute advantages for GRPO. @@ -144,7 +143,7 @@ def compute_grpo_process_supervision_advantages( std_tol (float): The tolerance for the standard deviation. Returns: - A list of tensors of advantages. + A list of arrays of advantages. Raises: ValueError: If the standard deviation of the flattened rewards is smaller than the tolerance. @@ -155,13 +154,11 @@ def compute_grpo_process_supervision_advantages( if std < std_tol: raise ValueError(f"`std` is smaller than tolerance of {std_tol}.") - normalized_rewards = [ - (torch.tensor(trajectory) - mean) / std for trajectory in rewards - ] + normalized_rewards = [(np.array(trajectory) - mean) / std for trajectory in rewards] if gamma is None: advantages = [ - trajectory.flip(dims=[0]).cumsum(dim=0).flip(dims=[0]) + np.flip(np.cumsum(np.flip(trajectory, axis=0), axis=0), axis=0) for trajectory in normalized_rewards ] else: diff --git a/atroposlib/utils/cli.py b/atroposlib/utils/cli.py index a20b6d6d..1ca6082f 100644 --- a/atroposlib/utils/cli.py +++ b/atroposlib/utils/cli.py @@ -156,24 +156,41 @@ def get_double_dash_flags() -> Dict[str, Any]: # Remove '--' prefix key_part = arg[2:] + key = "" + value_str = ( + None # Variable to hold the string value before potential conversion + ) # Check for '--key=value' format if "=" in key_part: - key, value = key_part.split("=", 1) - if key: # Ensure key is not empty (e.g. --=value) - flags_dict[key] = value + key, value_str = key_part.split("=", 1) + if not key: # Ensure key is not empty (e.g. --=value) + i += 1 + continue # Skip if key is empty + + # Process value: Convert "None" string to None object + if value_str == "None": + flags_dict[key] = None + else: + flags_dict[key] = value_str i += 1 # Check if next argument exists and is a value (doesn't start with '-') elif i + 1 < len(args) and not args[i + 1].startswith("-"): key = key_part - value = args[i + 1] - flags_dict[key] = value + value_str = args[i + 1] + + # Process value: Convert "None" string to None object + if value_str == "None": + flags_dict[key] = None + else: + flags_dict[key] = value_str # Skip the next argument since we've consumed it as a value i += 2 # Otherwise, treat as a boolean flag else: key = key_part - flags_dict[key] = True + if key: # Ensure key is not empty (e.g. just '--') + flags_dict[key] = True i += 1 return flags_dict diff --git a/atroposlib/utils/config_handler.py b/atroposlib/utils/config_handler.py index 3cfd1f7b..d7712f62 100644 --- a/atroposlib/utils/config_handler.py +++ b/atroposlib/utils/config_handler.py @@ -181,4 +181,4 @@ class ConfigHandler: # Add slurm flag to config if running in a Slurm environment config["use_slurm"] = "SLURM_JOB_ID" in os.environ - return config \ No newline at end of file + return config diff --git a/atroposlib/utils/io.py b/atroposlib/utils/io.py new file mode 100644 index 00000000..d8ae55bd --- /dev/null +++ b/atroposlib/utils/io.py @@ -0,0 +1,39 @@ +import logging +from typing import Any, Optional + +logger = logging.getLogger(__name__) + + +async def parse_http_response( + resp: Any, logger: Optional[logging.Logger] = None +) -> Any: + """ + Parse an HTTP response with proper error handling and logging. + + Args: + resp: The HTTP response object (must have raise_for_status() and json() methods) + logger: Optional logger instance. If not provided, uses the default module logger. + + Returns: + The parsed JSON response + + Raises: + Exception: Re-raises any exceptions that occur during parsing + """ + if logger is None: + logger = logging.getLogger(__name__) + + try: + # Raise an exception for bad status codes (4xx or 5xx) + resp.raise_for_status() + # Attempt to parse the response as JSON + return await resp.json() + except Exception as e: + # Handle HTTP errors (raised by raise_for_status) + error_text = await resp.text() # Read the response text for logging + logger.error( + f"Error fetching from server. Status: {getattr(e, 'status', 'unknown')}, " + f"Message: {getattr(e, 'message', str(e))}. Response: {error_text}" + ) + # Re-raise the exception to allow retry decorators to handle it + raise diff --git a/atroposlib/utils/tokenize_for_trainer.py b/atroposlib/utils/tokenize_for_trainer.py index 8d9ea3dc..b5d16095 100644 --- a/atroposlib/utils/tokenize_for_trainer.py +++ b/atroposlib/utils/tokenize_for_trainer.py @@ -1,4 +1,4 @@ -import torch +import numpy as np from transformers import PreTrainedTokenizer from atroposlib.type_definitions import Message @@ -39,7 +39,7 @@ def tokenize_for_trainer( # (e.g. current date). e.g. consider a system prompt that depends on the current date and a run that crosses # midnight from 3/9 to 3/10 under a tokenizer that tokenizes 3/9 and 3/10 with a different number of tokens. - masks = torch.ones(len(tokens), dtype=torch.long) * -100 + masks = np.ones(len(tokens), dtype=np.int64) * -100 for i, msg in enumerate(chat): if msg["role"] in UNMASKED_ROLES: @@ -51,7 +51,7 @@ def tokenize_for_trainer( ) start_idx = len(prefix_tokens) end_idx = len(unmasked_tokens) - masks[start_idx:end_idx] = torch.tensor(unmasked_tokens[start_idx:]) + masks[start_idx:end_idx] = np.array(unmasked_tokens[start_idx:]) masks = masks.tolist() if finish_reason == "length": diff --git a/environments/configs/example.yaml b/environments/configs/example.yaml new file mode 100644 index 00000000..680458cf --- /dev/null +++ b/environments/configs/example.yaml @@ -0,0 +1,21 @@ +# Environment configuration +env: + group_size: 4 + max_batches_offpolicy: 3 + tokenizer_name: "Qwen/Qwen2.5-1.5B-Instruct" + use_wandb: true + rollout_server_url: "http://localhost:8000" + wandb_name: "example_env" + ensure_scores_are_not_same: true + data_path_to_save_groups: null + include_messages: true # if data_path_to_save_groups is set this will add the messages to the saved .jsonl + +# OpenAI server configurations +openai: + - model_name: "Qwen/Qwen2.5-1.5B-Instruct" + base_url: "http://localhost:9001/v1" + api_key: "x" + weight: 1.0 + +slurm: false +testing: false diff --git a/environments/dataset_environment/LOCAL_TESTING.md b/environments/dataset_environment/LOCAL_TESTING.md index a5c8eb87..1d430b21 100644 --- a/environments/dataset_environment/LOCAL_TESTING.md +++ b/environments/dataset_environment/LOCAL_TESTING.md @@ -152,4 +152,4 @@ server_configs: If you encounter issues with reward functions, make sure they are properly registered in the registry. -For dataset-related issues, verify that the dataset exists on HuggingFace and that the specified fields exist in the dataset. \ No newline at end of file +For dataset-related issues, verify that the dataset exists on HuggingFace and that the specified fields exist in the dataset. diff --git a/environments/dataset_environment/configs/dataset_local.yaml b/environments/dataset_environment/configs/dataset_local.yaml index 7849de34..d66a01f7 100644 --- a/environments/dataset_environment/configs/dataset_local.yaml +++ b/environments/dataset_environment/configs/dataset_local.yaml @@ -49,4 +49,4 @@ dataset: server_configs: - model_name: "gpt-4.1-nano" api_key: ${OPENAI_API_KEY} - timeout: 600 \ No newline at end of file + timeout: 600 diff --git a/environments/dataset_environment/configs/gsm8k.yaml b/environments/dataset_environment/configs/gsm8k.yaml index 7979fe46..ea19ea76 100644 --- a/environments/dataset_environment/configs/gsm8k.yaml +++ b/environments/dataset_environment/configs/gsm8k.yaml @@ -70,4 +70,4 @@ dataset: eval_dataset_name: "gsm8k" eval_dataset_config: "main" - eval_split: "test" \ No newline at end of file + eval_split: "test" diff --git a/environments/dataset_environment/configs/gsm8k_debug.yaml b/environments/dataset_environment/configs/gsm8k_debug.yaml index f928e9e2..372c5a3e 100644 --- a/environments/dataset_environment/configs/gsm8k_debug.yaml +++ b/environments/dataset_environment/configs/gsm8k_debug.yaml @@ -27,4 +27,4 @@ dataset: max_tokens: 4096 length_warmup_steps: 0 - min_tokens: 200 \ No newline at end of file + min_tokens: 200 diff --git a/environments/dataset_environment/dataset_env.py b/environments/dataset_environment/dataset_env.py index 602cf812..23548c73 100644 --- a/environments/dataset_environment/dataset_env.py +++ b/environments/dataset_environment/dataset_env.py @@ -32,7 +32,9 @@ class DatasetEnvConfig(BaseEnvConfig): None, description="Field in dataset containing canonical correct answer" ) system_prompt: Optional[str] = Field(None, description="System prompt to use") - prefill: Optional[str] = Field(None, description="Text to prefill the completion with (e.g. '')") + prefill: Optional[str] = Field( + None, description="Text to prefill the completion with (e.g. '')" + ) shuffle_dataset: bool = Field(True, description="Whether to shuffle the dataset") max_generations_per_prompt: int = Field( 1, description="Number of generations per prompt for collection" @@ -137,21 +139,21 @@ class DatasetEnv(BaseEnv): # Extract user prompt and answer from item user_content = dict(item[0][0])["content"] answer = item[1] if len(item) > 1 else None - + # Create messages list messages = [] if self.config.system_prompt: messages.append({"role": "system", "content": self.config.system_prompt}) - + messages.append({"role": "user", "content": user_content}) - + # Add prefill as assistant message if configured if self.config.prefill: messages.append({"role": "assistant", "content": self.config.prefill}) - + # Convert messages to a prompt string using the tokenizer prompt = self.tokenizer.apply_chat_template(messages, tokenize=False) - + # Calculate max tokens for generation (with optional warmup) max_tokens = self.config.max_tokens if self.config.length_warmup_steps > 0: @@ -160,7 +162,7 @@ class DatasetEnv(BaseEnv): self.config.min_tokens + warmup_progress * (self.config.max_tokens - self.config.min_tokens) ) - + # Generate completion using completions API completions = await self.server.completion( prompt=prompt, @@ -169,34 +171,38 @@ class DatasetEnv(BaseEnv): temperature=self.config.temperature, top_p=self.config.top_p, ) - + to_score = [] to_backlog = [] - + # Process completions for completion in completions.choices: # Get the completion text - completion_text = completion.text if hasattr(completion, "text") else completion.message.content - + completion_text = ( + completion.text + if hasattr(completion, "text") + else completion.message.content + ) + # Build full message sequence for scoring full_messages = [] if self.config.system_prompt: - full_messages.append({"role": "system", "content": self.config.system_prompt}) - + full_messages.append( + {"role": "system", "content": self.config.system_prompt} + ) + full_messages.append({"role": "user", "content": user_content}) - + # Combine prefill with completion if prefill was used response_content = completion_text if self.config.prefill: response_content = self.config.prefill + completion_text - + full_messages.append({"role": "assistant", "content": response_content}) - + # Add to scoring list with answer and ground truth - to_score.append( - (full_messages, answer, item[2] if len(item) > 2 else None) - ) - + to_score.append((full_messages, answer, item[2] if len(item) > 2 else None)) + return to_score, to_backlog async def postprocess_histories(self, trajectories: List) -> Tuple[List, List]: @@ -204,27 +210,27 @@ class DatasetEnv(BaseEnv): async def collect_trajectories(self, item: Item) -> Tuple[List, List]: self.current_item = item - + # Extract user prompt from item user_content = dict(item[0][0])["content"] - + # Create messages list messages = [] if self.config.system_prompt: messages.append({"role": "system", "content": self.config.system_prompt}) - + messages.append({"role": "user", "content": user_content}) - + # Add prefill as assistant message if configured if self.config.prefill: messages.append({"role": "assistant", "content": self.config.prefill}) - + # Convert messages to a prompt string using the tokenizer prompt = self.tokenizer.apply_chat_template(messages, tokenize=False) - + # Calculate max tokens for generation (with optional warmup) max_tokens = self.config.max_tokens - + # Generate completions completions = await self.server.completion( prompt=prompt, @@ -233,30 +239,36 @@ class DatasetEnv(BaseEnv): temperature=self.config.temperature, top_p=self.config.top_p, ) - + print(f"Completions: {completions}") # Process completions trajectories = [] for completion in completions.choices: # Get the completion text - completion_text = completion.text if hasattr(completion, "text") else completion.message.content - + completion_text = ( + completion.text + if hasattr(completion, "text") + else completion.message.content + ) + # Build complete message sequence full_messages = [] if self.config.system_prompt: - full_messages.append({"role": "system", "content": self.config.system_prompt}) - + full_messages.append( + {"role": "system", "content": self.config.system_prompt} + ) + full_messages.append({"role": "user", "content": user_content}) - + # Combine prefill with completion if prefill was used response_content = completion_text if self.config.prefill: response_content = self.config.prefill + completion_text - + full_messages.append({"role": "assistant", "content": response_content}) - + trajectories.append(full_messages) - + return trajectories, [] async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]: @@ -402,6 +414,7 @@ class DatasetEnv(BaseEnv): await super().wandb_log(metrics) + if __name__ == "__main__": # Launch the DatasetEnv via the BaseEnv CLI (serve or process) DatasetEnv.cli() diff --git a/environments/dataset_environment/dataset_local_server.py b/environments/dataset_environment/dataset_local_server.py index 7812bf23..6e4e5608 100644 --- a/environments/dataset_environment/dataset_local_server.py +++ b/environments/dataset_environment/dataset_local_server.py @@ -8,7 +8,8 @@ from dotenv import load_dotenv from atroposlib.envs.base import APIServerConfig from atroposlib.envs.reward_fns import registry -from atroposlib.utils.config_handler import ConfigHandler + +# from atroposlib.utils.config_handler import ConfigHandler from environments.dataset_environment.dataset_env import DatasetEnv, DatasetEnvConfig load_dotenv() @@ -23,7 +24,8 @@ def parse_arguments(): "--config", type=str, default="dataset_local", - help="Configuration file name (without .yaml extension) relative to environments/dataset_environment/configs/, or full path to a YAML file.", + help="Configuration file name (without .yaml extension) relative to environments/dataset_environment/configs/," + " or full path to a YAML file.", ) return parser.parse_args() @@ -35,7 +37,7 @@ async def main(): args = parse_arguments() # Initialize config handler - config_handler = ConfigHandler() + # config_handler = ConfigHandler() # Determine config path if ( diff --git a/environments/dataset_environment/launch_local_dataset_run.py b/environments/dataset_environment/launch_local_dataset_run.py index 26f6a2ac..2e95d05a 100644 --- a/environments/dataset_environment/launch_local_dataset_run.py +++ b/environments/dataset_environment/launch_local_dataset_run.py @@ -14,16 +14,16 @@ Requirements: - Run from project root so example_trainer is on PYTHONPATH - example_trainer/ is a valid Python package (with __init__.py) """ -import os -import sys -import subprocess -import time import atexit +import os import signal +import subprocess +import sys +import time import traceback # Ensure project root is on PYTHONPATH -project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) if project_root not in sys.path: sys.path.insert(0, project_root) @@ -32,56 +32,73 @@ try: from example_trainer.grpo import TrainingConfig, train except ImportError as e: print(f"Error importing example_trainer.grpo: {e}") - print("Ensure you're running from project root and that example_trainer/ is a package.") + print( + "Ensure you're running from project root and that example_trainer/ is a package." + ) sys.exit(1) # ----------------------------------------------------------------------------- # Configuration # ----------------------------------------------------------------------------- -API_HOST = '127.0.0.1' +API_HOST = "127.0.0.1" API_PORT = 8000 -VLLM_HOST = '127.0.0.1' +VLLM_HOST = "127.0.0.1" VLLM_PORT = 9001 -MODEL_NAME = 'Qwen/Qwen2.5-1.5B-Instruct' +MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" TOKENIZER_NAME = MODEL_NAME TRAINER_CONFIG = { - 'model_name': MODEL_NAME, - 'training_steps': 20, - 'batch_size': 2, - 'gradient_accumulation_steps': 2, - 'seq_len': 512, - 'vllm_port': VLLM_PORT, - 'vllm_restart_interval': 10, - 'use_wandb': False, - 'wandb_project': '', - 'wandb_group': '', - 'save_path': './trained_model_checkpoints_local_test', + "model_name": MODEL_NAME, + "training_steps": 20, + "batch_size": 2, + "gradient_accumulation_steps": 2, + "seq_len": 512, + "vllm_port": VLLM_PORT, + "vllm_restart_interval": 10, + "use_wandb": False, + "wandb_project": "", + "wandb_group": "", + "save_path": "./trained_model_checkpoints_local_test", } # Flags for launching DatasetEnv serve DATASET_FLAGS = [ - '--group_size', '4', - '--max_num_workers', '2', - '--rollout_server_url', f"http://{API_HOST}:{API_PORT}", - '--tokenizer_name', TOKENIZER_NAME, - '--use_wandb', - '--wandb_name', 'dataset_env_local_test', - '--max_token_length', str(TRAINER_CONFIG['seq_len']), - '--ensure_scores_are_not_same', - '--dataset_name', 'HuggingFaceH4/testing_self_instruct_process_essays', - '--split', 'train[:100]', - '--prompt_field', 'prompt', - '--answer_field', 'answer', - '--reward_functions', 'length', - '--max_tokens', '128', - '--temperature', '0.7', - '--model_name', MODEL_NAME, - '--base_url', f"http://{VLLM_HOST}:{VLLM_PORT}", - '--slurm', - '--testing', + "--group_size", + "4", + "--max_num_workers", + "2", + "--rollout_server_url", + f"http://{API_HOST}:{API_PORT}", + "--tokenizer_name", + TOKENIZER_NAME, + "--use_wandb", + "--wandb_name", + "dataset_env_local_test", + "--max_token_length", + str(TRAINER_CONFIG["seq_len"]), + "--ensure_scores_are_not_same", + "--dataset_name", + "HuggingFaceH4/testing_self_instruct_process_essays", + "--split", + "train[:100]", + "--prompt_field", + "prompt", + "--answer_field", + "answer", + "--reward_functions", + "length", + "--max_tokens", + "128", + "--temperature", + "0.7", + "--model_name", + MODEL_NAME, + "--base_url", + f"http://{VLLM_HOST}:{VLLM_PORT}", + "--slurm", + "--testing", ] # Track background processes for cleanup @@ -106,6 +123,7 @@ def cleanup_processes(): print(f"PID {p.pid} already exited.") print("Cleanup complete.") + atexit.register(cleanup_processes) @@ -113,6 +131,7 @@ def handle_signal(sig, frame): print(f"\nSignal {sig} received; exiting.") sys.exit(0) + signal.signal(signal.SIGINT, handle_signal) signal.signal(signal.SIGTERM, handle_signal) @@ -121,10 +140,12 @@ def main(): # 1) Start the API server print("--- Starting Trajectory Handler API Server ---") api_cmd = [ - 'uvicorn', - 'atroposlib.api.server:app', - '--host', API_HOST, - '--port', str(API_PORT), + "uvicorn", + "atroposlib.api.server:app", + "--host", + API_HOST, + "--port", + str(API_PORT), ] print(f"$ {' '.join(api_cmd)}") api_proc = subprocess.Popen(api_cmd) @@ -133,7 +154,12 @@ def main(): # 2) Start the dataset environment print("\n--- Starting Dataset Environment ---") - env_cmd = ['python', '-m', 'environments.dataset_environment.dataset_env', 'serve'] + DATASET_FLAGS + env_cmd = [ + "python", + "-m", + "environments.dataset_environment.dataset_env", + "serve", + ] + DATASET_FLAGS print(f"$ {' '.join(env_cmd)}") env_proc = subprocess.Popen(env_cmd) processes.append(env_proc) @@ -150,5 +176,5 @@ def main(): print("--- Training complete ---") -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/environments/fundamental_prediction_environment.py b/environments/fundamental_prediction_environment.py index 4b9bd130..7eda4110 100644 --- a/environments/fundamental_prediction_environment.py +++ b/environments/fundamental_prediction_environment.py @@ -532,7 +532,8 @@ class FundamentalPredictionEnv(BaseEnv): else 0 ) wandb_metrics["train/combined_score"] = combined_score - except: + except Exception as e: + print(f"Error calculating combined score: {e}") pass # Clear the buffers after logging diff --git a/environments/gym_taxi.py b/environments/gym_taxi.py new file mode 100644 index 00000000..382959bc --- /dev/null +++ b/environments/gym_taxi.py @@ -0,0 +1,331 @@ +from typing import Dict, List, Optional, Tuple + +import gymnasium as gym + +from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataItem +from atroposlib.type_definitions import Item + +start_msg = """### Description +There are four designated locations in the grid world indicated by R(ed), +G(reen), Y(ellow), and B(lue). When the episode starts, the taxi starts off +at a random square and the passenger is at a random location. The taxi +drives to the passenger's location, picks up the passenger, drives to the +passenger's destination (another one of the four specified locations), and +then drops off the passenger. Once the passenger is dropped off, the episode ends. + +Map: + + +---------+ + |R: | : :G| + | : | : : | + | : : : : | + | | : | : | + |Y| : |B: | + +---------+ + +### Actions +There are 6 discrete deterministic actions: +- 0: move south (increases row index) +- 1: move north (decreases row index) +- 2: move east (increases column index) +- 3: move west (decreases column index) +- 4: pickup passenger (IF on a letter location, AND passenger is located at the same location, pickup passenger) +- 5: drop off passenger + +### Observations + +Passenger locations: +- 0: R(ed) +- 1: G(reen) +- 2: Y(ellow) +- 3: B(lue) +- 4: in taxi + +Destinations: +- 0: R(ed) (Row 0, Col 0) +- 1: G(reen) (Row 4, Col 4) +- 2: Y(ellow) (Row 0, Col 4) +- 3: B(lue) (Row 3, Col 3) + +### Instructions +Please perform the actions that will let you pick up and/or drop off the passenger. +Please respond with the action number only. +You cannot move the taxi into walls, which are displayed as | in the map. : means you are free to move through that column. + + +For an example, if the passenger is at R, and the destination is G, and the taxi is at (2, 2), then here are the following actions to solve this in the correct order: + +3 (move west) +3 (move west) +1 (move north) +1 (move north) +4 (pickup passenger) +0 (move south) +0 (move south) +2 (move east) +2 (move east) +2 (move east) +2 (move east) +0 (move south) +0 (move south) +5 (drop off passenger) + +If you are stuck, try moving to row idx 2, as there are no walls there. + +Submit your response as a number between 0 and 5 only to perform the discrete action. +Each turn we will give you the current state of the environment, and you will need to respond with the action number only from the available actions.""" # noqa: E501 + + +def decode(i): + out = [] + out.append(i % 4) + i = i // 4 + out.append(i % 5) + i = i // 5 + out.append(i % 5) + i = i // 5 + out.append(i) + assert 0 <= i < 5 + x = reversed(out) + # Making it explicit so I don't have to look into gym code + taxi_row, taxi_col, pass_idx, dest_idx = x + return taxi_row, taxi_col, pass_idx, dest_idx + + +# Note: Works for both the passenger and the destination +TO_LOC_MAP = { + 0: "R(Row 0, Col 0)", + 1: "G (Row 4, Col 4)", + 2: "Y (Row 0, Col 4)", + 3: "B (Row 3, Col 3)", + 4: "in taxi", +} +MAP_LOC = {0: (0, 0), 1: (4, 4), 2: (0, 4), 3: (3, 3)} +TO_ACTION_MAP = { + 0: "south", + 1: "north", + 2: "east", + 3: "west", + 4: "pickup", + 5: "dropoff", +} + + +def state_render_to_user_msg(last_state, state, action_mask, render): + taxi_row, taxi_col, pass_idx, dest_idx = decode(state) + if last_state is not None: + last_taxi_row, last_taxi_col, last_pass_idx, last_dest_idx = decode(last_state) + available_actions = "\n".join( + [ + f"- {i}: {TO_ACTION_MAP[i]}" + for i in range(6) + if (action_mask[i] == 1) + and ( + (i != 5) + or ( + (i == 5) + and (taxi_row == MAP_LOC[dest_idx][0]) + and (taxi_col == MAP_LOC[dest_idx][1]) + ) + ) + ] + ) + if last_state is not None: + ret_str = ( + f"Previous Taxi Location: Row: {last_taxi_row}, Col: {last_taxi_col}\n" + ) + else: + ret_str = "" + ret_str += ( + f"Current state:\nTaxi: Row: {taxi_row}, Col: {taxi_col}\nPassenger: {TO_LOC_MAP[pass_idx]}\n" + f"Destination: {TO_LOC_MAP[dest_idx]}\n\n" + f"Map:\n{render}\n\n" + f"Available actions:\n{available_actions}" + ) + if ( + (pass_idx == 4) + and (taxi_row == MAP_LOC[dest_idx][0]) + and (taxi_col == MAP_LOC[dest_idx][1]) + ): + ret_str += "\n\nPlease drop off the passenger." + elif pass_idx == 4: + ret_str += f"\n\nPlease move the taxi to {TO_LOC_MAP[dest_idx]} to drop off the passenger." + elif (taxi_row == MAP_LOC[pass_idx][0]) and (taxi_col == MAP_LOC[pass_idx][1]): + ret_str += "\n\nPlease pick up the passenger." + else: + ret_str += f"\n\nPlease move the taxi to {TO_LOC_MAP[pass_idx]} to pick up the passenger." + return ret_str + + +class GymTaxiEnv(BaseEnv): + + name = "gym_taxi" + + def __init__( + self, + config: BaseEnvConfig, + server_configs: List[APIServerConfig], + slurm=True, + testing=False, + ): + super().__init__(config, server_configs, slurm, testing) + self.percent_correct_buffer = list() + self.percent_picked_up_passenger_buffer = list() + self.eval_metrics = list() + # Add tracking for wandb visualizations + self.rollouts_for_wandb = [] + self.completion_lengths = [] + self.print_this_env = False + + @classmethod + def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: + env_config = BaseEnvConfig( + tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + group_size=32, + use_wandb=True, + rollout_server_url="http://localhost:8000", + max_token_length=8192, + wandb_name="gym_taxi", + ) + server_configs = [ + APIServerConfig( + model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + base_url="http://localhost:9001/v1", + api_key="x", + num_requests_for_eval=256, + ), + ] + + return env_config, server_configs + + async def wandb_log(self, wandb_metrics: Optional[Dict] = None): + if wandb_metrics is None: + wandb_metrics = {} + + # Try to calculate percent_correct, pass if there's a division by zero + try: + wandb_metrics["train/percent_correct"] = sum( + self.percent_correct_buffer + ) / len(self.percent_correct_buffer) + except ZeroDivisionError: + # Skip if buffer is empty + pass + try: + wandb_metrics["train/percent_picked_up_passenger"] = sum( + self.percent_picked_up_passenger_buffer + ) / len(self.percent_picked_up_passenger_buffer) + except ZeroDivisionError: + # Skip if buffer is empty + pass + + self.percent_correct_buffer = list() + self.percent_picked_up_passenger_buffer = list() + for item in self.eval_metrics: + wandb_metrics[item[0]] = item[1] + self.eval_metrics = list() + # Call the parent method to handle the server metrics + await super().wandb_log(wandb_metrics) + + async def setup(self): + self.iter = 0 + + async def evaluate(self, *args, **kwargs): + pass + + async def collect_trajectory( + self, item: Item + ) -> Tuple[Optional[ScoredDataItem], List[Item]]: + # Grab a dedicated llm server to take advantage of caching + async with self.server.dedicated_server() as server: + env = gym.make("Taxi-v3", render_mode="ansi") + state, info = env.reset(seed=item["seed"]) + last_state = None + taxi_row, taxi_col, pass_idx, dest_idx = decode(state) + init_msg = f"{start_msg}\n\n" + state_render_to_user_msg( + last_state, state, info["action_mask"], env.render() + ) + messages = [{"role": "user", "content": init_msg}] + score = -1 + while True: + if ( + len(self.tokenizer.apply_chat_template(messages)) + > self.config.max_token_length - 10 + ): + break + max_tokens = self.config.max_token_length - len( + self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True + ) + ) + chat_completions = await server.chat_completion( + messages=messages, + n=1, + max_tokens=max_tokens, + ) + choice = ( + chat_completions.choices[0] + .message.content.strip() + .replace(".", "")[-1] + ) + messages.append( + { + "role": "assistant", + "content": chat_completions.choices[0].message.content, + } + ) + if choice.isdigit() and 0 <= int(choice) <= 5: + action = int(choice) + else: + break + if info["action_mask"][action] == 0: + break + if action == 3: + # picked up passenger + score = 0 + next_state, reward, terminated, truncated, info = env.step(action) + last_state = state + state = next_state + if terminated: + score = 1 + break + messages.append( + { + "role": "user", + "content": state_render_to_user_msg( + last_state, state, info["action_mask"], env.render() + ), + } + ) + self.percent_correct_buffer.append(max(score, 0)) + self.percent_picked_up_passenger_buffer.append(1 if score >= 0 else 0) + tokens = self.tokenizer.apply_chat_template(messages) + masks = [] + for i, msg in enumerate(messages): + if i == len(messages) - 1: + masks.extend(tokens[len(masks) :]) + else: + curr_tokens = self.tokenizer.apply_chat_template( + messages[: i + 1], + add_generation_prompt=messages[i + 1]["role"] == "assistant", + ) + if messages[i]["role"] == "user": + masks.extend([-100] * (len(curr_tokens) - len(masks))) + else: + masks.extend(curr_tokens[len(masks) :]) + scored_data_item = ScoredDataItem( + messages=messages, + finish_reason=score, + tokens=tokens, + masks=masks, + scores=score, + ) + return scored_data_item, [] + + async def get_next_item(self): + next_item = {"seed": self.iter} + self.iter += 1 + return next_item + + +if __name__ == "__main__": + GymTaxiEnv.cli() diff --git a/environments/math_server_zero.py b/environments/math_server_zero.py index d8a93158..49e5db7e 100644 --- a/environments/math_server_zero.py +++ b/environments/math_server_zero.py @@ -18,11 +18,11 @@ from pydantic import Field from tqdm.asyncio import tqdm_asyncio from atroposlib.envs.base import ( - APIServerConfig, BaseEnv, BaseEnvConfig, EvalHandlingEnum, ScoredDataGroup, + ServerBaseline, ) prompt_format = ( @@ -115,7 +115,7 @@ class MathEnv(BaseEnv): def __init__( self, config: RSConfig, - server_configs: List[APIServerConfig], + server_configs: ServerBaseline, slurm=True, testing=False, ): @@ -133,7 +133,7 @@ class MathEnv(BaseEnv): self.iter = 0 @classmethod - def config_init(cls) -> Tuple[RSConfig, List[APIServerConfig]]: + def config_init(cls) -> Tuple[RSConfig, ServerBaseline]: env_config = RSConfig( tokenizer_name="Qwen/Qwen2.5-7B", group_size=8, @@ -147,14 +147,10 @@ class MathEnv(BaseEnv): eval_handling=EvalHandlingEnum.LIMIT_TRAIN, eval_limit_ratio=0.1, ) - server_configs = [ - APIServerConfig( - model_name="default", - base_url="http://localhost:9004/v1", - api_key="x", - num_requests_for_eval=256, # since evaling only on one... - ), - ] + server_configs = ServerBaseline( + model_name="default", + num_requests_for_eval=256, # since evaling only on one... + ) return env_config, server_configs @@ -222,8 +218,8 @@ class MathEnv(BaseEnv): ) ) for name, t_dataset in zip( - ["amc23", "minerva", "olympiad"], - [amc_test_data, minerva_test_data, olympiad_test_data], + ["amc23", "minerva"], + [amc_test_data, minerva_test_data], ): for item in t_dataset: self.test.append( @@ -235,6 +231,17 @@ class MathEnv(BaseEnv): name, ) ) + for name, t_dataset in zip(["olympiad"], [olympiad_test_data]): + for item in t_dataset: + self.test.append( + ( + prompt_format.format( + prompt=problem_format.format(problem=item["question"]) + ), + item["final_answer"][0], + name, + ) + ) return async def rollout_and_score_eval(self, question, answer, subset): diff --git a/environments/multimodal_dpo/clevr_cogen_a_train.py b/environments/multimodal_dpo/clevr_cogen_a_train.py index 8d787646..8aa7f6d1 100644 --- a/environments/multimodal_dpo/clevr_cogen_a_train.py +++ b/environments/multimodal_dpo/clevr_cogen_a_train.py @@ -1,9 +1,7 @@ import base64 import json -import os import random import re -import sys import traceback from typing import List, Optional, Tuple @@ -29,13 +27,11 @@ class MultimodalExampleEnv(BaseEnv): async def collect_trajectories( self, item: Item ) -> Tuple[GameHistory | None, List[Item]]: - print("DEBUG: Starting collect_trajectories") to_score = list() to_backlog = list() # Get the current image if it was stored if hasattr(self, "current_image"): - print("DEBUG: Using current_image for multimodal content") # Convert PIL image to base64 import io @@ -61,14 +57,12 @@ class MultimodalExampleEnv(BaseEnv): if not text_content: text_content = "Please solve this problem and provide your answer as \\boxed{answer}." - except Exception as e: - print(f"DEBUG: Error parsing JSON: {e}") + except Exception: text_content = "Please solve this problem and provide your answer as \\boxed{answer}." else: text_content = user_content # Create messages with the new format - print("DEBUG: Creating multimodal message with new format") messages = [ { "role": "system", @@ -89,7 +83,6 @@ class MultimodalExampleEnv(BaseEnv): ] else: - print("DEBUG: No image available, using text-only message") messages = [ { "role": "system", @@ -98,26 +91,20 @@ class MultimodalExampleEnv(BaseEnv): dict(item[0][0]), ] - print("DEBUG: About to call chat_completion") chat_completions = await self.server.chat_completion( messages=messages, n=self.config.group_size, max_tokens=1024 * 2, timeout=60, # Add timeout to prevent hanging (60 seconds is more reasonable) ) - print("DEBUG: chat_completion call successful") for i, chat_completion in enumerate(chat_completions.choices): - print(f"DEBUG: Processing completion {i+1}/{len(chat_completions.choices)}") messages = ( dict(item[0][0]), {"role": "assistant", "content": chat_completion.message.content}, ) to_score.append((messages, item[1], base64_image)) - print("DEBUG: Finished processing completions") - - print("DEBUG: Returning from collect_trajectories") return to_score, to_backlog async def postprocess_histories( @@ -146,20 +133,12 @@ class MultimodalExampleEnv(BaseEnv): Get the next items to be rolled out, including the image """ try: - print("DEBUG: Starting get_next_item") - # Get next dataset item next_item = self.train[self.iter % len(self.train)] self.iter += 1 - print(f"DEBUG: Retrieved dataset item {self.iter-1}") - - # For debugging, we'll use a simple text-only prompt and store the image separately - # This is because the collect_trajectories method will handle the multimodal formatting - # Store image as a class attribute so collect_trajectories can access it self.current_image = next_item["image"] - print("DEBUG: Stored image in current_image attribute") # Create a simple text prompt - the image will be added in collect_trajectories # This avoids the unhashable type error with lists in frozensets @@ -182,11 +161,9 @@ class MultimodalExampleEnv(BaseEnv): img_byte_arr = img_byte_arr.getvalue() base64_image = base64.b64encode(img_byte_arr).decode("utf-8") - print("DEBUG: Created simple text-only prompt for get_next_item") return (prompt, answer, base64_image) - except Exception as e: - print(f"DEBUG: Error in get_next_item: {str(e)}") + except Exception: traceback.print_exc() # Create a dummy item as fallback @@ -217,9 +194,6 @@ class MultimodalExampleEnv(BaseEnv): model_answer = ( item[0][-1]["content"].split("\\boxed{")[-1].split("}")[0] ) - print( - f"DEBUG: Model answer: {model_answer} and RG data: {rollout_group_data[0][1]}" - ) pattern = r"\s*(\d{1,2})\s*" string = rollout_group_data[0][1] @@ -248,35 +222,25 @@ class MultimodalExampleEnv(BaseEnv): @classmethod def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: - if not os.environ.get("OPENAI_API_KEY"): - print("ERROR: OPENAI_API_KEY environment variable is not set!") - print("Please set it using: export OPENAI_API_KEY=your_api_key") - sys.exit(1) - - print( - f"DEBUG: Using API key starting with: {os.environ.get('OPENAI_API_KEY')[:5]}..." - ) - config = BaseEnvConfig( - wandb_name="clevr_cogen", - tokenizer_name="gpt2", - group_size=2, - use_wandb=False, + wandb_name="clevr_cogen_a_train", + tokenizer_name="Qwen/Qwen2-VL-2B-Instruct", + group_size=8, + use_wandb=True, max_num_workers=2, rollout_server_url="http://localhost:8000", total_steps=1000, - batch_size=1, - steps_per_eval=10, - ensure_scores_are_not_same=False, + batch_size=12, + steps_per_eval=100, + max_token_length=2048, ) - print("DEBUG: Creating OpenAI configuration") server_configs = [ APIServerConfig( - model_name="gpt-4o", # Using GPT-4o which has multimodal capabilities - base_url=None, - api_key=os.environ.get("OPENAI_API_KEY"), - num_requests_for_eval=1, + model_name="Qwen/Qwen2-VL-2B-Instruct", + base_url="http://localhost:9001/v1", + api_key="x", + num_requests_for_eval=256, ), ] diff --git a/environments/multimodal_dpo/clevr_complex.py b/environments/multimodal_dpo/clevr_complex.py index 5f6f8dac..6003fe39 100644 --- a/environments/multimodal_dpo/clevr_complex.py +++ b/environments/multimodal_dpo/clevr_complex.py @@ -1,8 +1,6 @@ import base64 import json -import os import random -import sys import traceback from typing import List, Optional, Tuple @@ -28,13 +26,11 @@ class MultimodalComplexEnv(BaseEnv): async def collect_trajectories( self, item: Item ) -> Tuple[GameHistory | None, List[Item]]: - print("DEBUG: Starting collect_trajectories") to_score = list() to_backlog = list() # Get the current image if it was stored if hasattr(self, "current_image"): - print("DEBUG: Using current_image for multimodal content") # Convert PIL image to base64 import io @@ -60,14 +56,12 @@ class MultimodalComplexEnv(BaseEnv): if not text_content: text_content = "Please solve this problem and provide your answer as \\boxed{answer}." - except Exception as e: - print(f"DEBUG: Error parsing JSON: {e}") + except Exception: text_content = "Please solve this problem and provide your answer as \\boxed{answer}." else: text_content = user_content # Create messages with the new format - print("DEBUG: Creating multimodal message with new format") messages = [ { "role": "system", @@ -88,7 +82,6 @@ class MultimodalComplexEnv(BaseEnv): ] else: - print("DEBUG: No image available, using text-only message") messages = [ { "role": "system", @@ -97,32 +90,23 @@ class MultimodalComplexEnv(BaseEnv): dict(item[0][0]), ] - print("DEBUG: About to call chat_completion") chat_completions = await self.server.chat_completion( messages=messages, n=self.config.group_size, max_tokens=1024 * 2, timeout=60, # Add timeout to prevent hanging (60 seconds is more reasonable) ) - print("DEBUG: chat_completion call successful") for i, chat_completion in enumerate(chat_completions.choices): - print(f"DEBUG: Processing completion {i+1}/{len(chat_completions.choices)}") messages = ( dict(item[0][0]), {"role": "assistant", "content": chat_completion.message.content}, ) to_score.append((messages, item[1], base64_image)) - print("DEBUG: Finished processing completions") + to_postprocess = await self.score(to_score) - print("DEBUG: Returning from collect_trajectories") - return to_score, to_backlog - - async def postprocess_histories( - self, trajectories: List[GameHistory] - ) -> ScoredDataGroup: - pass + return to_postprocess, to_backlog async def evaluate(self, *args, **kwargs): """ @@ -145,20 +129,13 @@ class MultimodalComplexEnv(BaseEnv): Get the next items to be rolled out, including the image """ try: - print("DEBUG: Starting get_next_item") # Get next dataset item next_item = self.train[self.iter % len(self.train)] self.iter += 1 - print(f"DEBUG: Retrieved dataset item {self.iter-1}") - - # For debugging, we'll use a simple text-only prompt and store the image separately - # This is because the collect_trajectories method will handle the multimodal formatting - # Store image as a class attribute so collect_trajectories can access it self.current_image = next_item["image"] - print("DEBUG: Stored image in current_image attribute") # Create a simple text prompt - the image will be added in collect_trajectories # This avoids the unhashable type error with lists in frozensets @@ -178,11 +155,9 @@ class MultimodalComplexEnv(BaseEnv): img_byte_arr = img_byte_arr.getvalue() base64_image = base64.b64encode(img_byte_arr).decode("utf-8") - print("DEBUG: Created simple text-only prompt for get_next_item") return (prompt, answer, base64_image) - except Exception as e: - print(f"DEBUG: Error in get_next_item: {str(e)}") + except Exception: traceback.print_exc() # Create a dummy item as fallback @@ -213,9 +188,6 @@ class MultimodalComplexEnv(BaseEnv): model_answer = ( item[0][-1]["content"].split("\\boxed{")[-1].split("}")[0] ) - print( - f"DEBUG: Model answer: {model_answer} and RG data: {rollout_group_data[0][1]}" - ) # Handle both numeric and yes/no answers gold_answer = rollout_group_data[0][1] @@ -249,35 +221,25 @@ class MultimodalComplexEnv(BaseEnv): @classmethod def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: - if not os.environ.get("OPENAI_API_KEY"): - print("ERROR: OPENAI_API_KEY environment variable is not set!") - print("Please set it using: export OPENAI_API_KEY=your_api_key") - sys.exit(1) - - print( - f"DEBUG: Using API key starting with: {os.environ.get('OPENAI_API_KEY')[:5]}..." - ) - config = BaseEnvConfig( wandb_name="clevr_complex", - tokenizer_name="gpt2", - group_size=2, - use_wandb=False, + tokenizer_name="Qwen/Qwen2-VL-2B-Instruct", + group_size=8, + use_wandb=True, max_num_workers=2, rollout_server_url="http://localhost:8000", total_steps=1000, - batch_size=1, - steps_per_eval=10, - ensure_scores_are_not_same=False, + batch_size=12, + steps_per_eval=100, + max_token_length=2048, ) - print("DEBUG: Creating OpenAI configuration") server_configs = [ APIServerConfig( - model_name="gpt-4o", # Using GPT-4o which has multimodal capabilities - base_url=None, - api_key=os.environ.get("OPENAI_API_KEY"), - num_requests_for_eval=1, + model_name="Qwen/Qwen2-VL-2B-Instruct", + base_url="http://localhost:9001/v1", + api_key="x", + num_requests_for_eval=256, ), ] diff --git a/environments/multimodal_dpo/ocr_vqa.py b/environments/multimodal_dpo/ocr_vqa.py index bd60a111..6820c052 100644 --- a/environments/multimodal_dpo/ocr_vqa.py +++ b/environments/multimodal_dpo/ocr_vqa.py @@ -1,9 +1,7 @@ import base64 import io -import os import random import re -import sys import traceback from typing import List, Optional, Tuple @@ -73,13 +71,8 @@ class OcrVqaEnv(BaseEnv): history: GameHistory = (user_hist, assistant_hist) to_score.append((history, gold, base64_image)) - return to_score, to_backlog - - async def postprocess_histories( - self, trajectories: List[GameHistory] - ) -> ScoredDataGroup: - # No additional post-processing needed - pass + to_postprocess = await self.score(to_score) + return to_postprocess, to_backlog async def evaluate(self, *args, **kwargs): # No custom evaluation @@ -172,29 +165,25 @@ class OcrVqaEnv(BaseEnv): @classmethod def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: - if not os.environ.get("OPENAI_API_KEY"): - print("ERROR: OPENAI_API_KEY environment variable is not set!") - sys.exit(1) - config = BaseEnvConfig( wandb_name="ocr_vqa", - tokenizer_name="gpt2", - group_size=2, - use_wandb=False, + tokenizer_name="Qwen/Qwen2-VL-2B-Instruct", + group_size=8, + use_wandb=True, max_num_workers=2, rollout_server_url="http://localhost:8000", total_steps=1000, - batch_size=1, - steps_per_eval=10, - ensure_scores_are_not_same=False, + batch_size=12, + steps_per_eval=100, + max_token_length=2048, ) server_configs = [ APIServerConfig( - model_name="gpt-4o", - base_url=None, - api_key=os.environ.get("OPENAI_API_KEY"), - num_requests_for_eval=1, + model_name="Qwen/Qwen2-VL-2B-Instruct", + base_url="http://localhost:9001/v1", + api_key="x", + num_requests_for_eval=256, ), ] diff --git a/environments/multimodal_dpo/pixmo_clocks.py b/environments/multimodal_dpo/pixmo_clocks.py index 15f54ab6..02932178 100644 --- a/environments/multimodal_dpo/pixmo_clocks.py +++ b/environments/multimodal_dpo/pixmo_clocks.py @@ -1,9 +1,7 @@ import base64 import io -import os import random import re -import sys import traceback from typing import List, Optional, Tuple @@ -73,13 +71,9 @@ class ClockDatasetEnv(BaseEnv): history: GameHistory = (user_msg, assistant_msg) to_score.append((history, item[1], base64_image)) - return to_score, to_backlog + to_postprocess = await self.score(to_score) - async def postprocess_histories( - self, trajectories: List[GameHistory] - ) -> ScoredDataGroup: - # No custom post-processing - pass + return to_postprocess, to_backlog async def evaluate(self, *args, **kwargs): # No custom evaluation @@ -92,6 +86,7 @@ class ClockDatasetEnv(BaseEnv): self.iter = 0 async def get_next_item(self) -> Item: + try: entry = self.train[self.iter % len(self.train)] self.iter += 1 @@ -133,6 +128,7 @@ class ClockDatasetEnv(BaseEnv): scores["scores"] = [] scores["images"] = [] random.shuffle(rollout_group_data) + for item in rollout_group_data: out = tokenize_for_trainer(self.tokenizer, item[0]) tokens = out["tokens"] @@ -174,29 +170,25 @@ class ClockDatasetEnv(BaseEnv): @classmethod def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: - if not os.environ.get("OPENAI_API_KEY"): - print("ERROR: OPENAI_API_KEY environment variable is not set!") - sys.exit(1) - config = BaseEnvConfig( - wandb_name="clocks", - tokenizer_name="gpt2", - group_size=2, - use_wandb=False, + wandb_name="pixmo_clocks", + tokenizer_name="Qwen/Qwen2-VL-2B-Instruct", + group_size=8, + use_wandb=True, max_num_workers=2, rollout_server_url="http://localhost:8000", total_steps=1000, - batch_size=1, - steps_per_eval=10, - ensure_scores_are_not_same=False, + batch_size=12, + steps_per_eval=100, + max_token_length=2048, ) server_configs = [ APIServerConfig( - model_name="gpt-4o", - base_url=None, - api_key=os.environ.get("OPENAI_API_KEY"), - num_requests_for_eval=1, + model_name="Qwen/Qwen2-VL-2B-Instruct", + base_url="http://localhost:9001/v1", + api_key="x", + num_requests_for_eval=256, ), ] diff --git a/environments/multimodal_dpo/pixmo_count.py b/environments/multimodal_dpo/pixmo_count.py index 50f4d5c0..26c1928d 100644 --- a/environments/multimodal_dpo/pixmo_count.py +++ b/environments/multimodal_dpo/pixmo_count.py @@ -1,9 +1,7 @@ import base64 import io -import os import random import re -import sys import traceback from typing import List, Optional, Tuple @@ -108,13 +106,9 @@ class PixmoCountEnv(BaseEnv): history: GameHistory = (user_hist, assistant_hist) to_score.append((history, gold, base64_image)) - return to_score, to_backlog + to_postprocess = await self.score(to_score) - async def postprocess_histories( - self, trajectories: List[GameHistory] - ) -> ScoredDataGroup: - # No custom post-processing - pass + return to_postprocess, to_backlog async def evaluate(self, *args, **kwargs): # No custom evaluation @@ -163,29 +157,26 @@ class PixmoCountEnv(BaseEnv): @classmethod def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: - if not os.environ.get("OPENAI_API_KEY"): - print("ERROR: OPENAI_API_KEY environment variable is not set!") - sys.exit(1) config = BaseEnvConfig( wandb_name="pixmo_count", - tokenizer_name="gpt2", - group_size=2, - use_wandb=False, + tokenizer_name="Qwen/Qwen2-VL-2B-Instruct", + group_size=8, + use_wandb=True, max_num_workers=2, rollout_server_url="http://localhost:8000", total_steps=1000, - batch_size=1, - steps_per_eval=10, - ensure_scores_are_not_same=False, + batch_size=12, + steps_per_eval=100, + max_token_length=2048, ) server_configs = [ APIServerConfig( - model_name="gpt-4o", - base_url=None, - api_key=os.environ.get("OPENAI_API_KEY"), - num_requests_for_eval=1, + model_name="Qwen/Qwen2-VL-2B-Instruct", + base_url="http://localhost:9001/v1", + api_key="x", + num_requests_for_eval=256, ), ] diff --git a/environments/multimodal_dpo/pixmo_point_explanations.py b/environments/multimodal_dpo/pixmo_point_explanations.py index f44ac565..f1e68e06 100644 --- a/environments/multimodal_dpo/pixmo_point_explanations.py +++ b/environments/multimodal_dpo/pixmo_point_explanations.py @@ -1,9 +1,7 @@ import base64 import io -import os import random import re -import sys import traceback from typing import List, Optional, Tuple @@ -55,8 +53,7 @@ class PixmoPointExplanationsEnv(BaseEnv): img.save(buf, format="PNG") img_bytes = buf.getvalue() base64_image = base64.b64encode(img_bytes).decode("utf-8") - except Exception as e: - print(f"Error loading image from URL: {e}") + except Exception: base64_image = None return (prompt, gold_answer, base64_image) @@ -118,13 +115,9 @@ class PixmoPointExplanationsEnv(BaseEnv): history: GameHistory = (user_hist, assistant_hist) to_score.append((history, gold, base64_image)) - return to_score, to_backlog + to_postprocess = await self.score(to_score) - async def postprocess_histories( - self, trajectories: List[GameHistory] - ) -> ScoredDataGroup: - # No custom post-processing needed - pass + return to_postprocess, to_backlog async def evaluate(self, *args, **kwargs): # No custom evaluation @@ -174,29 +167,26 @@ class PixmoPointExplanationsEnv(BaseEnv): @classmethod def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: - if not os.environ.get("OPENAI_API_KEY"): - print("ERROR: OPENAI_API_KEY environment variable is not set!") - sys.exit(1) config = BaseEnvConfig( wandb_name="pixmo_point_explanations", - tokenizer_name="gpt2", - group_size=2, - use_wandb=False, + tokenizer_name="Qwen/Qwen2-VL-2B-Instruct", + group_size=8, + use_wandb=True, max_num_workers=2, rollout_server_url="http://localhost:8000", total_steps=1000, - batch_size=1, - steps_per_eval=10, - ensure_scores_are_not_same=False, + batch_size=12, + steps_per_eval=100, + max_token_length=2048, ) server_configs = [ APIServerConfig( - model_name="gpt-4o", - base_url=None, - api_key=os.environ.get("OPENAI_API_KEY"), - num_requests_for_eval=1, + model_name="Qwen/Qwen2-VL-2B-Instruct", + base_url="http://localhost:9001/v1", + api_key="x", + num_requests_for_eval=256, ), ] diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py index fa7fca76..38273153 100644 --- a/example_trainer/grpo.py +++ b/example_trainer/grpo.py @@ -13,13 +13,12 @@ import numpy as np import requests import torch import torch.nn.functional as F +import wandb # Added for logging from pydantic import BaseModel, Field from tenacity import retry, stop_after_attempt, wait_exponential from torch.optim import AdamW from transformers import AutoModelForCausalLM, AutoTokenizer -import wandb # Added for logging - # Global variable to keep track of the vLLM process vllm_process = None diff --git a/pyproject.toml b/pyproject.toml index 50ad2c59..1547aee7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,13 +16,14 @@ dependencies = [ "markdown", "numpy", "wandb", + "gymnasium", "math-verify==0.7.0", "jinja2", "nltk", + "rich", "polars", "aiofiles", "jsonlines", - "torch", "pydantic-cli", "hf_transfer", ] @@ -31,6 +32,7 @@ dependencies = [ run-api = "atroposlib.cli.run_api:main" inference-node-wandb-watcher = "atroposlib.cli.inference_node_wandb_watcher:main" view-run = "atroposlib.cli.view_run:main" +view-run-multimodal = "atroposlib.cli.view_run_multimodal:main" atropos-sft-gen = "atroposlib.cli.sft:main" atropos-dpo-gen = "atroposlib.cli.dpo:main" @@ -38,6 +40,9 @@ atropos-dpo-gen = "atroposlib.cli.dpo:main" all = [ "atroposlib[dev,examples]" ] +rewardfns = [ + "torch" +] dev = [ "pytest", "pytest-asyncio", @@ -46,10 +51,11 @@ dev = [ "flake8", "isort", "mypy", - 'rich', + "rich", ] examples = [ - "gradio" + "gradio", + "atroposlib[rewardfns]" ] [build-system]