atropos/llm.txt
2025-04-29 12:10:10 -07:00


# Atropos Library Documentation (for LLM Context)
This document provides comprehensive information about the Atropos library, Nous Research's LLM RL Gym. It covers its purpose, features, usage, components, configuration, and contribution guidelines.
---
## 1. Introduction: Atropos - Nous Research's LLM RL Gym
Atropos is an LLM Reinforcement Learning Environments framework designed for collecting and evaluating LLM trajectories through diverse environments.
**Supported Environment Types:**
<div align="center">
| Environment Type | Examples | Purpose |
|---------------------------|--------------------------------------------|----------------------------------------------------|
| 📚 Dataset environments | GSM8K, MMLU | Evaluate and improve LLM performance on static data|
| 🎮 Online environments | Crosswords, Hangman | Train LLMs through interactive game-based learning |
| 🤖 RLAIF and RLHF | LLM Judge/Reward Models | Fine-tune LLMs using human or AI feedback for alignment |
| 🔄 Multi-Turn RL | deepresearch, internal tool calling | Train LLMs on complex multi-step interactions |
</div>
Atropos provides a robust, scalable framework for **Reinforcement Learning Environments with LLMs**.
**Key Features:**
* **Multi-Turn & Asynchronous RL:** Efficiently supports complex, multi-turn, and asynchronous interactions, decoupling environment steps from policy updates.
* **Inference Agnostic:** Integrates with standard inference APIs (e.g., OpenAI, vLLM, sgLang), enabling easy switching between LLM providers and frameworks.
* **Trainer Independent:** Offers a standardized training interface for experimenting with different RL algorithms and frameworks without major code changes.
* **Scalable & Decentralized:** Easily scale by launching more environment instances (locally or across decentralized resources) that contribute rollouts to a central service.
* **Diverse Environment Integration:** Manages many varied environment types concurrently for heterogeneous, multi-modal training.
**Goal:** Provide a flexible, scalable, and standardized platform to accelerate LLM-based RL research across diverse, interactive settings.
---
## 5. Navigating the Repo
| Category | Description |
|-------------------------------|--------------------------------------------------|
| 📁 [`atroposlib/`](atroposlib/) | Core library containing base classes and utilities |
| 🎮 [`environments/`](environments/) | Collection of ready-to-use RL environments |
| 📚 [`example_trainer/`](example_trainer/) | Example training scripts and configurations |
**Key Documents:**
* **Base Environment Class:** `atroposlib/environments/README.md` (Detailed in Section 10 below)
* **Environments Overview:** `environments/README.md` (Detailed in Section 8 below)
* **Full Environment Config Options:** `CONFIG.md` (Detailed in Section 10.2 below)
* **Example Trainer:** `example_trainer/README.md` (Detailed in Section 9 below)
* **Contributing Guide:** `CONTRIBUTING.md` (Detailed in Section 12 below)
* **License:** `LICENSE.md` (Apache 2.0 license details)
---
## 6. Installation
Requires Python 3.10 or later.
```bash
# Core usage
pip install -e .
# Development (includes testing, linting tools)
pip install -e .[dev]
# Running examples (includes dependencies like vLLM, transformers)
pip install -e .[examples]
# Everything
pip install -e .[all]
```
**Important for Developers:** Install pre-commit hooks to ensure code quality:
```bash
pre-commit install
```
---
## 7. Quick Start Guide
1. **Create Your First Environment:**
* Review the [Base Environment Class Documentation](#10-core-library-atroposlib) (Section 10).
* Examine existing environments in [`environments/`](#8-environments) for examples.
2. **Run an Example Environment:**
```bash
# Start the central API server (trajectory handler) in the background
run-api &
# Start an environment server (e.g., GSM8K) connected to the API
python environments/gsm8k_server.py serve \
--tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct" \
--model_name="Qwen/Qwen2.5-1.5B-Instruct" \
--slurm False # Assuming local run, set True for SLURM cluster
```
*Note: The model and tokenizer names are examples.*
3. **Training Your Model:**
* Refer to the [Example Trainer Guide](#9-training-with-the-example-trainer) (Section 9).
* Monitor progress via logging: completion lengths, eval accuracies, and full rollouts with their scores.
* Multiple environments can run concurrently, pointing to the same `run-api` server.
**Logging:** Environments provide detailed logging, tracking completion lengths, eval accuracies, full rollouts, scores, etc. Supports WandB integration.
---
## 8. Environments
The `environments/` directory contains various RL environments.
### 8.1. Common Features Across Environments
1. **Training/Test Split:** Typically 98% training, 2% test, with fixed random shuffling (seed 42).
2. **Metrics Tracking:** Includes a percent-correct buffer, completion lengths, WandB integration, and rollout tracking.
3. **Token Management:** Maximum token length limits, statistics tracking, and optional length penalties.
4. **Evaluation:** Separate evaluation on the test set with comprehensive metrics logging. Supports multiple completions per prompt.
5. **Usage Interface:** Environments generally follow a common interface:
* Initialize with `config` (BaseEnvConfig), `server_configs` (OpenAI API configs), `slurm` (bool), `testing` (bool).
* Key methods: `setup()`, `get_next_item()`, `collect_trajectories()`, `score()` (often part of postprocessing), `evaluate()`, `wandb_log()`.
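
The fixed 98%/2% split with seed 42 can be sketched as follows. This is a minimal illustration using only the standard library; `split_train_test` is a hypothetical helper, and the actual environments may shuffle and split through their dataset library instead.

```python
import random


def split_train_test(items, test_fraction=0.02, seed=42):
    """Shuffle deterministically, then hold out the last `test_fraction` as test."""
    shuffled = list(items)
    # A dedicated Random instance keeps the shuffle reproducible and
    # independent of the global random state.
    random.Random(seed).shuffle(shuffled)
    n_test = max(1, round(len(shuffled) * test_fraction))
    return shuffled[:-n_test], shuffled[-n_test:]


train, test = split_train_test(range(1000))
print(len(train), len(test))  # 980 20
```

Because the seed is fixed, repeated runs produce the same split, which keeps evaluation numbers comparable across runs.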
### 8.2. Available Environments
#### 8.2.1. MCQA Thinking Environment (`mcqa_thinking_env.py`)
Multiple Choice Question Answering (MMLU dataset) requiring systematic thought.
* **Input Format:** MMLU items (`prompt`, `answer` index, `ground_truth` letter, `options` list).
* **System Prompt:**
```
You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem.
```
* **Reward Function:**
* 1.0 for correct letter match.
* 0.0 for incorrect or malformed response (e.g., bad `<think>` tags).
* Length penalty applied *only if all responses in a group are correct*: scales linearly from 1.0 (<=50% max length) down to 0.0 (>=100% max length).
* Returns `None` if all scores in a group are identical (no training signal).
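
The length penalty and the identical-score rule described above can be sketched as below. This is an illustration under stated assumptions, not the environment's actual code: `length_scaled_score` and `score_group` are hypothetical names, and applying the penalty before the identical-score check is an assumed ordering.

```python
from typing import List, Optional


def length_scaled_score(num_tokens: int, max_token_length: int) -> float:
    """1.0 up to 50% of the max length, falling linearly to 0.0 at 100%."""
    threshold = 0.5 * max_token_length
    if num_tokens <= threshold:
        return 1.0
    if num_tokens >= max_token_length:
        return 0.0
    return 1.0 - (num_tokens - threshold) / (max_token_length - threshold)


def score_group(
    correct: List[bool], token_counts: List[int], max_token_length: int
) -> Optional[List[float]]:
    """Score one group; return None when the group carries no training signal."""
    scores = [1.0 if c else 0.0 for c in correct]
    # The length penalty only applies when every response in the group is correct.
    if all(correct):
        scores = [length_scaled_score(n, max_token_length) for n in token_counts]
    # Identical scores give no relative signal, so the group is dropped.
    if all(s == scores[0] for s in scores):
        return None
    return scores


print(score_group([True, True], [1024, 1536], 2048))   # [1.0, 0.5]
print(score_group([True, False], [100, 100], 2048))    # [1.0, 0.0]
print(score_group([False, False], [100, 100], 2048))   # None
```

Under this reading, the penalty is what separates otherwise identical all-correct groups: shorter correct answers rank above longer ones.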
#### 8.2.2. GSM8K Environment (`gsm8k_server.py`)
Mathematical reasoning (GSM8K dataset).
* **Input Format:** GSM8K items (`question`, `answer` number).
* **System Prompt:**
```
You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem.
You are allocated a maximum of 4096 tokens, please strive to use less.
You will then provide your answer like this: \boxed{your answer here}
It is important that you provide your answer in the correct format.
If you do not, you will not receive credit for your answer.
So please end your answer with \boxed{your answer here}
```
* **Reward Function:**
* 1.0 if `\boxed{}` answer matches ground truth (uses LaTeX verification).
* 0.0 if incorrect or ground truth isn't parseable.
* Length penalty applied *only if all responses in a group are correct*: scales linearly from 1.0 (<=50% max length) down to 0.0 (>=100% max length).
* Returns `None` if all scores in a group are identical.
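
A simplified sketch of the correctness check is shown below. The environment uses LaTeX verification; this stand-in swaps in a plain regular expression and numeric comparison, so it only handles simple numeric answers without nested braces. `extract_boxed` and `score_gsm8k` are hypothetical names.

```python
import re
from typing import Optional


def extract_boxed(text: str) -> Optional[str]:
    """Return the contents of the last boxed answer in `text`, or None."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", text)
    return matches[-1].strip() if matches else None


def score_gsm8k(response: str, ground_truth: str) -> float:
    """1.0 for a numerically matching boxed answer, otherwise 0.0."""
    answer = extract_boxed(response)
    if answer is None:
        return 0.0
    try:
        # Strip thousands separators before comparing as numbers.
        matched = float(answer.replace(",", "")) == float(ground_truth.replace(",", ""))
    except ValueError:
        # Unparseable answer or unparseable ground truth both score 0.0.
        return 0.0
    return 1.0 if matched else 0.0


print(score_gsm8k(r"<think>6 * 7</think> \boxed{42}", "42"))  # 1.0
print(score_gsm8k("The answer is 42", "42"))                   # 0.0
```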
#### 8.2.3. Tool Calling Environment (`tool_calling_server.py`)
Training models for structured function/tool calls (ShareGPT-Hermes function call dataset).
* **Input Format:** Conversations (`system`, `human`, `gpt` roles) with expected tool calls (JSON format).
* **System Prompt:** (Same "deep thinking AI" prompt as MCQA)
```
You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem.
```
* **Reward Function:**
* 1.0 if *all* expected tool calls are present and *exactly* match (including nested JSON).
* 0.0 if any calls are missing, incorrect, or malformed.
* Length penalty applied *only if all responses in a group are correct*: scales linearly from 1.0 (<=50% max length) down to 0.0 (>=100% max length).
* Returns `None` if all scores in a group are identical.
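
The exact-match rule can be sketched with the standard `json` module. This is a minimal illustration, assuming each tool call has already been extracted from the response as a JSON string; `score_tool_calls` is a hypothetical name, and the sketch tolerates extra produced calls, which the real environment may not. Python's dictionary equality is deep and ignores key order, so nested arguments are compared by value.

```python
import json
from typing import List


def score_tool_calls(expected: List[str], produced: List[str]) -> float:
    """1.0 if every expected call appears, exactly, among the produced calls."""
    try:
        expected_calls = [json.loads(call) for call in expected]
        produced_calls = [json.loads(call) for call in produced]
    except json.JSONDecodeError:
        # Any malformed JSON counts as a failed response.
        return 0.0
    return 1.0 if all(call in produced_calls for call in expected_calls) else 0.0


expected = ['{"name": "get_weather", "arguments": {"city": "Paris"}}']
print(score_tool_calls(expected, ['{"arguments": {"city": "Paris"}, "name": "get_weather"}']))  # 1.0
print(score_tool_calls(expected, ['{"name": "get_weather", "arguments": {"city": "Rome"}}']))   # 0.0
```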
---
## 9. Training with the Example Trainer
The `example_trainer/` directory provides `grpo.py`, a script demonstrating integration with Atropos using the GRPO algorithm.
**Note:** This is a *reference example* for API integration and basic setup, *not* optimized for large-scale training. It uses `vLLM` for inference (simulated data generation) and `transformers` for training.
### 9.1. Prerequisites
1. Python 3.10 or later (matching the requirement in Section 6).
2. Running Atropos API server (default: `http://localhost:8000`). Accessible via `run-api`.
3. Required Python packages: `torch`, `transformers`, `vllm`, `pydantic`, `numpy`, `requests`, `tenacity`, `wandb` (optional). Install via `pip install -r example_trainer/requirements.txt` or `pip install -e .[examples]`.
4. A running Atropos environment (e.g., `python environments/gsm8k_server.py serve --slurm False`).
### 9.2. Setup
1. Clone the Atropos repository.
2. Install dependencies (see Prerequisites).
3. Start the Atropos API: `run-api`.
4. Start an environment connected to the API (e.g., GSM8K example above).
### 9.3. Configuration (`grpo.py`)
Configuration is managed via the `TrainingConfig` Pydantic model within `grpo.py`.
**Key Parameters:**
* `model_name`: Hugging Face model identifier (e.g., `"Qwen/Qwen2.5-1.5B-Instruct"`).
* `training_steps`: Total optimization steps.
* `batch_size` / `gradient_accumulation_steps`: Control effective batch size.
* `lr`: Learning rate.
* `save_path`: Directory for model checkpoints (default: `./trained_model_checkpoints`).
* `vllm_port`: Port for the script's vLLM inference server instance.
* `vllm_restart_interval`: Steps between saving checkpoints and restarting vLLM with updated weights.
* `use_wandb`: Enable/disable Weights & Biases logging.
* `wandb_project`: W&B project name (required if `use_wandb=True`).
* `wandb_group`: Optional W&B group name.
**API Endpoints:** Assumes API at `http://localhost:8000`. Modify `register_trainer` and `get_batch` functions if different.
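
To make the parameters concrete, here is a sketch of a configuration object with the fields listed above. It uses a standard-library dataclass as a stand-in for the Pydantic model in `grpo.py`; the class name `TrainingConfigSketch`, the example values, the `use_wandb=False` default, and the `effective_batch_size` helper are all assumptions for illustration. Only the `save_path` default comes from the list above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TrainingConfigSketch:
    """Stand-in for the TrainingConfig model; field names follow the list above."""

    model_name: str
    training_steps: int
    batch_size: int
    gradient_accumulation_steps: int
    lr: float
    vllm_port: int
    vllm_restart_interval: int
    save_path: str = "./trained_model_checkpoints"
    use_wandb: bool = False  # assumed default
    wandb_project: Optional[str] = None
    wandb_group: Optional[str] = None

    def __post_init__(self) -> None:
        # wandb_project is required whenever W&B logging is enabled.
        if self.use_wandb and not self.wandb_project:
            raise ValueError("wandb_project is required when use_wandb=True")

    @property
    def effective_batch_size(self) -> int:
        # Samples contributing to each optimizer step.
        return self.batch_size * self.gradient_accumulation_steps


cfg = TrainingConfigSketch(
    model_name="Qwen/Qwen2.5-1.5B-Instruct",
    training_steps=10,               # illustrative value
    batch_size=2,                    # illustrative value
    gradient_accumulation_steps=16,  # illustrative value
    lr=1e-5,                         # illustrative value
    vllm_port=9001,                  # illustrative value
    vllm_restart_interval=3,         # illustrative value
)
print(cfg.effective_batch_size)  # 32
```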
### 9.4. Running the Example
Navigate to the project root and run:
```bash
python example_trainer/grpo.py
```
### 9.5. Output
* **Console Logs:** Training progress (loss, logp), vLLM status.
* **Checkpoints:** Saved periodically in `save_path`. A `final_model` directory is written upon completion.
* **WandB:** Logs sent to W&B if enabled (link printed to console).
* `temp.json`: Raw data from the last fetched batch (for debugging).
---
## 10. Core Library (`atroposlib`)
The `atroposlib/` directory contains the core framework components.
### 10.1. Base Environment (`atroposlib.envs.base.BaseEnv`)
This class provides the foundation for creating custom RL environments. Subclass `BaseEnv` and implement/override methods as needed.
**Core Methods to Implement:**
* **`async def setup(self)`**: Called once at the start. Use for initial setup (loading data, models, etc.).
* **`async def get_next_item(self) -> Item`**: Returns the next data item (prompt, state) for trajectory collection. Return `None` to pause the worker if no items are ready. `Item` is typically a Pydantic model defined by the environment.
* **`async def collect_trajectory(self, item: Item) -> Tuple[Any | None, List[Item]]`**: Defines logic for *one* trajectory collection step based on `item`. The base class runs this in parallel (`group_size` times). Returns a tuple: `(collected_data_for_this_step, list_of_new_backlog_items)`. The collected data can be any type suitable for later processing.
* **`async def evaluate(self, *args, **kwargs)`**: Called periodically (`steps_per_eval`) for evaluation runs. Implement your evaluation logic here. The base class provides `self.eval_workers` for parallel tasks.
**Optional Methods to Override:**
* **`async def collect_trajectories(self, item: Item) -> Tuple[Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None]], List[Item]]`**: Override this *instead* of `collect_trajectory` for custom batch generation logic (generating the whole group at once). `ScoredDataGroup` is a structure usually containing prompts, responses, and scores.
* **`async def postprocess_histories(self, trajectories: Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]`**: Called after `collect_trajectories` and before sending data to the server. Use for final processing, scoring, filtering, or formatting of the collected group data.
* **`async def wandb_log(self, wandb_metrics: Optional[Dict] = None)`**: Called periodically for W&B logging. Add custom metrics to `wandb_metrics`. **Crucially, call `await super().wandb_log(wandb_metrics)`** at the end to include base metrics and rollouts.
* **`save_checkpoint(self, step, data=None)`**: Called automatically by the server based on `checkpoint_interval`. Saves the provided `data` dict (populated with environment state) to JSON. Override to customize *what* or *how* data is saved.
* **`@classmethod config_init(cls) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[OpenaiConfig]]]`**: Used by CLI `serve` command setup. Returns initial `BaseEnvConfig` and server config(s). Override for custom default CLI configurations. Default returns `cls.env_config_cls(), ServerBaseline()`.
* **`async def cleanup(self)`**: Called after each item processing (`handle_env`). Use for per-item cleanup if needed (rarely required).
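
The relationship between `collect_trajectory` and `collect_trajectories` can be sketched without installing the library. In the sketch below, `SketchBaseEnv` is a hypothetical stand-in that models only the group fan-out (running `collect_trajectory` `group_size` times in parallel); it is not the real `BaseEnv`, which also handles server interaction, scoring, and logging. `EchoEnv` and its placeholder responses are likewise invented for illustration.

```python
import asyncio
from typing import Any, List, Optional, Tuple


class SketchBaseEnv:
    """Stand-in for BaseEnv; only the parallel group collection is modeled."""

    def __init__(self, group_size: int = 4):
        self.group_size = group_size

    async def collect_trajectory(self, item: Any) -> Tuple[Optional[Any], List[Any]]:
        raise NotImplementedError

    async def collect_trajectories(self, item: Any) -> Tuple[List[Optional[Any]], List[Any]]:
        # Run collect_trajectory group_size times concurrently for one item.
        results = await asyncio.gather(
            *(self.collect_trajectory(item) for _ in range(self.group_size))
        )
        data = [collected for collected, _ in results]
        backlog = [new_item for _, extra in results for new_item in extra]
        return data, backlog


class EchoEnv(SketchBaseEnv):
    async def setup(self) -> None:
        # Called once at the start: load data, models, etc.
        self.items = ["2 + 2 = ?", "3 * 3 = ?"]
        self.cursor = 0

    async def get_next_item(self) -> Optional[str]:
        # Returning None signals that no item is ready.
        if self.cursor >= len(self.items):
            return None
        item = self.items[self.cursor]
        self.cursor += 1
        return item

    async def collect_trajectory(self, item: str) -> Tuple[Optional[Any], List[Any]]:
        await asyncio.sleep(0)  # placeholder for an inference call
        return {"prompt": item, "response": "placeholder"}, []


async def main() -> Tuple[List[Optional[Any]], List[Any]]:
    env = EchoEnv(group_size=4)
    await env.setup()
    item = await env.get_next_item()
    return await env.collect_trajectories(item)


group, backlog = asyncio.run(main())
print(len(group), backlog)  # 4 []
```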
**Provided Functionality:**
* **Parallel Trajectory Collection:** Base `collect_trajectories` handles running `collect_trajectory` in parallel.
* **Server Interaction:** Handles registration, config fetching, data sending (with retries via `handle_send_to_api`), status updates.
* **WandB Integration:** Setup, logging hook (`wandb_log`), rollout table helpers (`add_rollouts_for_wandb`, `create_rollout_table`).
* **Checkpointing:** Automatic triggering via server (`checkpoint_interval`), `save_checkpoint` method, automatic loading via `load_checkpoint(self)` on startup if `curr_step > 0`.
* **Worker Management:** Asynchronous task management (`add_train_workers`, `handle_env`).
* **Performance Monitoring:** Tracks and logs task durations, worker counts, etc.
* **CLI Integration:** `cli()` class method using `pydantic-cli` for easy `serve` commands. See `get_cli_serve_config_cls` and `get_cli_process_config_cls`.
### 10.2. Configuration Options (`atroposlib`)
Configuration is primarily managed via Pydantic models, often exposed through a CLI (`pydantic-cli`).
#### 10.2.1. Base Environment Config (`atroposlib.envs.base.BaseEnvConfig`)
| Parameter | Type | Default | Description |
| :------------------------------- | :----------------------- | :---------------------------------------------- | :--------------------------------------------------------------------------------------------------------- |
| `group_size` | `int` | `4` | Number of responses grouped for scoring. |
| `max_num_workers` | `int` | `-1` | Max workers. `-1` calculates from `max_num_workers_per_node`. |
| `max_eval_workers` | `int` | `16` | Max workers for evaluation. |
| `max_num_workers_per_node` | `int` | `8` | Max workers per node. |
| `steps_per_eval` | `int` | `100` | Steps between evaluations. |
| `max_token_length` | `int` | `2048` | Max token length for generations. |
| `eval_handling` | `EvalHandlingEnum` | `EvalHandlingEnum.STOP_TRAIN` | How evals affect training workers (`STOP_TRAIN`, `LIMIT_TRAIN`, `NONE`). |
| `eval_limit_ratio` | `float` | `0.5` | Ratio of training workers limited during evals (if `eval_handling` is `LIMIT_TRAIN`). |
| `inference_weight` | `float` | `1.0` | Inference weight (set by trainer/policy). `-1` ignores if handled specially. |
| `batch_size` | `int` | `-1` | Training batch size (usually set by trainer via API). |
| `max_batches_offpolicy` | `int` | `3` | Max number of off-policy batches queued. |
| `tokenizer_name` | `str` | `"NousResearch/DeepHermes-3-Llama-3-1B-Preview"` | Default Hugging Face tokenizer. |
| `use_wandb` | `bool` | `True` | Enable/disable W&B logging. |
| `rollout_server_url` | `str` | `"http://localhost:8000"` | URL of the central rollout server (FastAPI). |
| `total_steps` | `int` | `1000` | Total steps to run (can be overridden by trainer). |
| `wandb_name` | `str \| None` | `None` | W&B run name (often set automatically). |
| `num_rollouts_to_keep` | `int` | `32` | Number of full rollouts to display on W&B table. |
| `num_rollouts_per_group_for_logging` | `int` | `1` | Rollouts per group to keep for logging. `-1` keeps all. |
| `ensure_scores_are_not_same` | `bool` | `True` | Ensure scores in a group aren't identical (reject group if they are). Set `False` if identical scores are valid. |
| `data_path_to_save_groups` | `str \| None` | `None` | If set, save generated/scored groups to this JSONL file path. |
| `min_items_sent_before_logging` | `int` | `2` | Min API sends before logging metrics. `<=0` logs every time. |
#### 10.2.2. Server Manager Config (`atroposlib.envs.server_handling.server_manager.ServerManagerConfig`)
Settings for the `ServerManager` which handles inference server interactions.
| Parameter | Type | Default | Description |
| :-------- | :------ | :------ | :------------------------------------------------ |
| `slurm` | `bool` | `True` | Whether the environment is running on SLURM. |
| `testing` | `bool` | `False` | If `True`, uses mock OpenAI data (for testing). |
#### 10.2.3. Server Baseline Config (`atroposlib.envs.server_handling.server_manager.ServerBaseline`)
Default settings used by `ServerManager` if specific `OpenaiConfig` list isn't provided (e.g., for local/SLURM discovery).
| Parameter | Type | Default | Description |
| :------------------------- | :------ | :-------- | :------------------------------------------------------------------------------------------------------ |
| `timeout` | `int` | `1200` | Request timeout (seconds). |
| `num_max_requests_at_once` | `int` | `512` | Max concurrent requests (training). Divide by generation `n` param. |
| `num_requests_for_eval` | `int` | `64` | Max concurrent requests (evaluation). |
| `model_name` | `str` | `default` | Default model name for inference calls. |
| `rolling_buffer_length` | `int` | `1000` | Buffer length for server metrics (timings, attempts). |
#### 10.2.4. OpenAI Server Config (`atroposlib.envs.server_handling.openai_server.OpenaiConfig`)
Configuration for individual OpenAI-compatible API servers (official OpenAI, local vLLM/SGLang, etc.). A list of these can be passed to the environment.
| Parameter | Type | Default | Description |
| :------------------------- | :----------- | :-------- | :------------------------------------------------------------------------------------------------------ |
| `api_key` | `str \| None` | `None` | API key. Use `"x"` or any non-empty string for local servers without auth. `None` might imply env var. |
| `base_url` | `str \| None` | `None` | API endpoint URL. `None` for official OpenAI. Local: e.g., `http://localhost:9004/v1`. |
| `timeout` | `int` | `1200` | Request timeout (seconds). |
| `num_max_requests_at_once` | `int` | `512` | Max concurrent requests (training). Divide by generation `n`. |
| `num_requests_for_eval` | `int` | `64` | Max concurrent requests (evaluation). |
| `model_name` | `str` | `default` | **Required.** Model name for this server (e.g., `"gpt-4"`, `"NousResearch/..."`). |
| `rolling_buffer_length` | `int` | `1000` | Buffer length for this server's metrics. |
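
As an illustration of passing a list of server configurations, the sketch below mirrors the table's fields and defaults in a standard-library dataclass. `OpenaiConfigSketch` is a hypothetical stand-in, not the library's `OpenaiConfig` class, and the two entries (one official OpenAI endpoint, one local server) are example values only.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class OpenaiConfigSketch:
    """Stand-in mirroring the fields and defaults in the table above."""

    model_name: str = "default"
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    timeout: int = 1200
    num_max_requests_at_once: int = 512
    num_requests_for_eval: int = 64
    rolling_buffer_length: int = 1000


server_configs: List[OpenaiConfigSketch] = [
    # Official OpenAI: base_url stays None.
    OpenaiConfigSketch(model_name="gpt-4"),
    # Local OpenAI-compatible server without auth: any non-empty api_key works.
    OpenaiConfigSketch(
        model_name="Qwen/Qwen2.5-1.5B-Instruct",
        api_key="x",
        base_url="http://localhost:9004/v1",
    ),
]
print([cfg.base_url for cfg in server_configs])  # [None, 'http://localhost:9004/v1']
```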
---
## 11. Debugging Tools
The trajectory-handler provides local debugging tools:
* **Flexible Model Provider Support:** Natively supports any OpenAI API-compliant provider. Provide `base_url` and `api_key` for local testing/running.
* **View Run (`view-run`):** Launch a Gradio UI after starting the API (`run-api`) and an environment (`python environments/gsm8k_server.py serve`). Use `view-run` command to inspect batches of rollouts visually.
* **Offline Data Generation:**
* `atropos-sft-gen`: Collect rollouts and format for Supervised Fine-Tuning (SFT).
* `atropos-dpo-gen`: Collect rollouts and format for Direct Preference Optimization (DPO).
---
## 12. Contributing to Atropos
Contributions are welcome! Follow these guidelines.
### 12.1. How We Develop
* **GitHub:** Used for hosting, issue tracking, and Pull Requests (PRs).
* **GitHub Flow:** Development happens via PRs merged into the `main` branch.
### 12.2. Getting Started
1. **Fork:** Create your copy of the [repository](https://github.com/NousResearch/atropos).
2. **Clone:** `git clone https://github.com/your-username/atropos.git && cd atropos`
3. **Setup Dev Env:**
```bash
python -m venv .venv
source .venv/bin/activate # Windows: .venv\Scripts\activate
pip install -e ".[dev]" # Installs core + dev dependencies
```
4. **Install Pre-commit Hooks:**
```bash
pre-commit install
```
(Runs linters/formatters automatically on commit)
### 12.3. Running Tests
Uses `pytest`.
```bash
pytest
```
Ensure all tests pass before submitting a PR.
### 12.4. How to Contribute
* **Reporting Bugs:** Use the **Bug Report** issue template on GitHub Issues. Provide details: summary, steps to reproduce, expected vs. actual behavior, environment info, error messages/logs.
* **Suggesting Enhancements:** Use the **Feature Request** issue template. Discuss the idea first via an issue.
* **Submitting Changes (Pull Requests):**
1. Create a branch from `main`: `git checkout -b your-branch-name main`
2. Make changes, write code.
3. Add tests if applicable.
4. Update documentation (READMEs, docstrings) if APIs change.
5. Run tests: `pytest`.
6. Ensure code quality (pre-commit hooks run on commit, or run manually: `pre-commit run --all-files`).
7. Commit changes with clear messages: `git commit -m "feat: Describe feature or fix"`
8. Push branch: `git push origin your-branch-name`
9. Open a PR on GitHub from your fork's branch to `NousResearch/atropos:main`.
10. **Use the correct PR template:**
* `environment_pull_request_template.md` for environment changes.
* `non_environment_pull_request_template.md` for other changes.
11. Provide a clear title, description, link relevant issues (e.g., `Closes #123`).
### 12.5. Code Style
* PEP 8 enforced by `black`, `flake8`, `isort` via `pre-commit`.
* Manual check/fix: `pre-commit run --all-files`. Address `flake8` errors manually if needed.
### 12.6. License for Contributions
Contributions are submitted under the Apache License 2.0, the same license as the project.
### 12.7. Environment Contribution Guidelines
* **Legal/GitHub Compliance:** No illegal content. Must comply with GitHub TOS.
* **Explicit Content:** May be considered if clearly labeled and legally compliant.
* **Game Environments:** Welcome, but avoid reverse-engineered commercial games. Ensure rights to assets. Open-source/permissive licenses preferred.
* **Ethical Considerations:** Avoid environments encouraging harm without educational context.
* Discuss potentially controversial environments via an issue first.
### 12.8. Contributor Code of Conduct
Follow the [Contributor Code of Conduct](CODE_OF_CONDUCT.md).
---
## 13. Citation
If Atropos is helpful in your work, please cite:
```latex
@misc{atropos,
title = {{Atropos - An Async First Environment Rollout Controller}},
author = {Dakota Mahan and Roger Jin and Teknium and Shannon Sands and Artem Yatsenko and Jai Suphavadeeprasit and Karan Malhotra and Chen Guang and Joe Li},
url = {https://www.github.com/NousResearch/Atropos},
month = {4},
year = {2025},
version = {0.1},
}
```
*(Note: Year/Version might need updating)*
---
## 14. License
Atropos is licensed under the Apache License 2.0. See the [LICENSE](LICENSE.md) file for details.