first commit

Commit 621d00dd80 by Dakota Nous, 2025-04-29 12:10:10 -07:00
89 changed files with 15315 additions and 0 deletions

.github/ISSUE_TEMPLATE/bug_report.md (new file, 60 lines)
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
## Describe the Issue
<!-- A clear and concise description of what the bug or issue is. -->
## Environment/API Details
- **Environment Class/Name:** [e.g., `atroposlib.envs.MyCustomEnv`, `gsm8k_server.py`]
- **Environment Configuration (`BaseEnvConfig` or subclass):** (Optional)
```yaml
# Paste relevant config values here
group_size: 4
max_token_length: 2048
# ... etc.
```
- **API Endpoint/Method Involved:** (If applicable) [e.g., `/register_env`, `/get_status`, `env.collect_trajectory()`]
## Steps to Reproduce
<!-- Detailed steps to reproduce the behavior: -->
1. Initialize environment `...` with config `...`
2. Call method `get_next_item()` and receive `Item` `...`
3. Call method `collect_trajectory()` or `collect_trajectories()` with `Item` `...`
4. Observe issue `...` (e.g., incorrect `ScoredDataGroup`, error during API call)
## Interaction Details (if applicable)
<!-- Provide details about the specific interaction step where the issue occurs. -->
- **Input `Item` to `collect_trajectory`:**
```python
# Paste relevant Item details here
```
- **Output `ScoredDataGroup` (or error):**
```python
# Paste relevant ScoredDataGroup or traceback here
```
- **Expected `ScoredDataGroup` / Behavior:**
## Setup Details
- **OS:** [e.g. macOS, Windows, Linux]
- **Python Version:** [e.g. 3.10, 3.11]
- **`Atropos` Version:** [e.g. output of `pip show atropos` or commit hash]
- **Relevant Libraries/Versions:** [e.g., `pydantic==2.5.0`, `aiohttp==3.9.0`, `transformers==4.35.0`]
## Additional Context & Logs
<!-- Add any other context about the problem here. Include relevant logs or screenshots. -->
```log
# Paste relevant logs here
```

(new file, 20 lines)
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.

(new file, 44 lines)
<!--
╭───────────────────────────────────────────────────────────╮
│ ✨ RL-ENVIRONMENTS PULL REQUEST TEMPLATE ✨ │
│ Fill out each field → delete guidance placeholders. │
│ Incomplete items slow down review & scoring. │
╰───────────────────────────────────────────────────────────╯
-->
## 🔖 Environment Snapshot
| Field | Your Entry |
|-------|------------|
| **Environment Name** | <!-- e.g. "SudokuVerifier-v0" --> |
| **Short Description** | <!-- One-sentence purpose/goal. --> |
| **Category** | <!-- Select: Verifiable-Reasoning / RLAIF / RLHF / Other --> |
| **Dataset Needed?** | <!-- No / Yes (link & license) --> |
| **External Deps** | <!-- Extra pip packages, system libs, etc. --> |
| **Environment Variables** | <!-- variable name(s) --> |
| **Expected Episode Length** | <!-- e.g. 128 timesteps --> |
| **Compute Footprint Estimate** | <!-- "<1 GB RAM, <1 min CPU verification" or similar --> |
---
## 🧪 Zero-Training Test Results
<details>
**W&B Link:**
**Examples of the environment scoring a good and a bad response:**
</details>
## ✅ Developer & Reviewer Checklist
- [ ] Code follows project style (black, isort, flake8 pass with pre-commit).
- [ ] I have performed a self-review of my own code
- [ ] Docstrings added for all new public classes / functions.
- [ ] Any required `.env` variables are documented in `.env.example` at the repo root.
- [ ] Automatic rollout script (`scripts/run_smoke_test.py`) runs without training and reproduces the metrics above.
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation to include my environment
- [ ] My changes generate no new warnings
- [ ] New and existing unit tests pass locally with my changes
---

(new file, 30 lines)
## Description
<!-- Briefly describe the changes introduced by this pull request. -->
## Related Issues
<!-- Link any relevant issues here. Use "Closes #issue_number" to automatically close issues. -->
## Type of Change
<!-- Please delete options that are not relevant. -->
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] This change requires a documentation update
- [ ] Code refactor (no functional changes)
- [ ] Build/CI/CD related changes
- [ ] Other (please describe):
## Checklist
- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged and published in downstream modules

.gitignore (new file, 184 lines)
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# PyPI configuration file
.pypirc
# data
*.json
*.jsonl
# wandb
wandb
# Allow MCP configuration
!configs/mcp.json
# gradio
.gradio/

.pre-commit-config.yaml (new file, 25 lines)
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/psf/black
rev: 24.1.1
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]
- repo: https://github.com/pycqa/flake8
rev: 7.0.0
hooks:
- id: flake8
args: ["--max-line-length=120", "--extend-ignore=E203,W503"]

CODE_OF_CONDUCT.md (new file, 44 lines)
# Code of Conduct
## Our Pledge
We, as contributors and maintainers of the Atropos project, pledge to foster an open and welcoming environment. We aim to make participation in our project a positive, respectful, and harassment-free experience for everyone, regardless of background or experience level.
## Our Standards
Examples of behavior that contributes to creating a positive environment include:
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community and the project
* Showing empathy towards other community members
Examples of unacceptable behavior include:
* Trolling, insulting or derogatory comments, and personal attacks
* Public or private harassment
* Publishing others' private information without explicit permission
* Other conduct which could reasonably be considered inappropriate in a professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, issues, and other contributions that are not aligned with this Code of Conduct.
## Scope
This Code of Conduct applies within all project spaces, including GitHub repositories, issue trackers, and related communication channels. It also applies when an individual is representing the project or its community in public spaces.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project maintainers. All complaints will be reviewed and investigated promptly and fairly.
Project maintainers are obligated to respect the privacy and security of the reporter of any incident.
## Attribution
This Code of Conduct is adapted from general open source community standards and GitHub's community guidelines.
Remember: Respect each other, collaborate constructively, and focus on making Atropos better for everyone.

CONFIG.md (new file, 67 lines)
# AtroposLib Configuration
This document outlines the configuration options available for the `atroposlib` library, primarily defined using Pydantic models.
These configurations are often managed via a command-line interface built using `pydantic-cli`, especially when using the `serve` command provided by environment classes inheriting from `BaseEnv`.
## Base Environment Configuration (`atroposlib.envs.base.BaseEnvConfig`)
Basic environment configuration settings.
| Parameter | Type | Default | Description |
| :------------------------------- | :----------------------- | :---------------------------------------------- | :--------------------------------------------------------------------------------------------------------- |
| `group_size` | `int` | `4` | How many responses are grouped together for scoring. |
| `max_num_workers` | `int` | `-1` | Maximum number of workers to use. `-1` calculates from `max_num_workers_per_node`. |
| `max_eval_workers` | `int` | `16` | Maximum number of workers to use for evaluation. |
| `max_num_workers_per_node` | `int` | `8` | Maximum number of workers to use per node. |
| `steps_per_eval` | `int` | `100` | Number of steps to take before evaluating. |
| `max_token_length` | `int` | `2048` | Maximum token length used in generations. |
| `eval_handling` | `EvalHandlingEnum` | `EvalHandlingEnum.STOP_TRAIN` | How to handle evaluations (`STOP_TRAIN`, `LIMIT_TRAIN`, `NONE`). |
| `eval_limit_ratio` | `float` | `0.5` | Ratio of training workers to limit during evals (used if `eval_handling` is `LIMIT_TRAIN`). |
| `inference_weight` | `float` | `1.0` | Inference weight. Set to `-1` to ignore if doing something special. |
| `batch_size` | `int` | `-1` | Batch size for training. Usually set by the trainer via the API. |
| `max_batches_offpolicy` | `int` | `3` | Maximum number of off-policy batches to have in the queue. |
| `tokenizer_name` | `str` | `"NousResearch/DeepHermes-3-Llama-3-1B-Preview"` | Hugging Face tokenizer to use. |
| `use_wandb` | `bool` | `True` | Whether to use Weights & Biases for logging. |
| `rollout_server_url` | `str` | `"http://localhost:8000"` | URL of the rollout server (FastAPI interface). |
| `total_steps` | `int` | `1000` | Total number of steps to run. |
| `wandb_name` | `str \| None` | `None` | Name to be grouped by in WandB. |
| `num_rollouts_to_keep` | `int` | `32` | Number of rollouts to display on WandB. |
| `num_rollouts_per_group_for_logging` | `int` | `1` | Number of rollouts per group to keep for logging. `-1` keeps all. |
| `ensure_scores_are_not_same` | `bool` | `True` | Ensure that scores within a group are not identical (usually `True`). |
| `data_path_to_save_groups` | `str \| None` | `None` | Path to save generated groups as a JSONL file. If set, groups will be written here. |
| `min_items_sent_before_logging` | `int` | `2` | Minimum number of items sent to the API before logging metrics. `0` or less logs every time. |
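As an illustration of how these fields fit together, the defaults above can be mirrored in a small stdlib dataclass (the real class is a Pydantic model in `atroposlib.envs.base`; the shape below is an assumption covering only a subset of the documented fields):

```python
from dataclasses import dataclass


@dataclass
class ExampleEnvConfig:
    # Illustrative subset of the documented BaseEnvConfig defaults.
    group_size: int = 4
    max_token_length: int = 2048
    steps_per_eval: int = 100
    total_steps: int = 1000
    tokenizer_name: str = "NousResearch/DeepHermes-3-Llama-3-1B-Preview"
    rollout_server_url: str = "http://localhost:8000"


# Overriding defaults, as a CLI flag or YAML file would via pydantic-cli.
cfg = ExampleEnvConfig(group_size=8, max_token_length=4096)
```

With `pydantic-cli`, each field becomes a `--group-size`-style flag on the environment's `serve` command; the dataclass here only sketches the field/default pairing.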
## Server Manager Configuration (`atroposlib.envs.server_handling.server_manager.ServerManagerConfig`)
Settings for the `ServerManager`.
| Parameter | Type | Default | Description |
| :-------- | :------ | :------ | :------------------------------------------------ |
| `slurm` | `bool` | `True` | Whether the environment is running on SLURM. |
| `testing` | `bool` | `False` | If `True`, uses mock OpenAI data for testing. |
## Server Baseline Configuration (`atroposlib.envs.server_handling.server_manager.ServerBaseline`)
Baseline configuration used by `ServerManager` if a list of `OpenaiConfig` is not provided, particularly for setting up local or SLURM-based server discovery.
| Parameter | Type | Default | Description |
| :------------------------- | :------ | :-------- | :------------------------------------------------------------------------------------------------------ |
| `timeout` | `int` | `1200` | Timeout for the request in seconds. |
| `num_max_requests_at_once` | `int` | `512` | Maximum number of concurrent requests (training). Divide this by the generation `n` parameter. |
| `num_requests_for_eval` | `int` | `64` | Maximum number of concurrent requests for evaluation. |
| `model_name` | `str` | `default` | Model name to use when calling inference servers. |
| `rolling_buffer_length` | `int` | `1000` | Length of the rolling buffer to store server metrics (like request timings, attempts). |
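The note on `num_max_requests_at_once` means effective request concurrency shrinks with the sampling parameter `n`: a quick sketch of the arithmetic, using the default cap of 512 (the helper name is illustrative, not part of the library):

```python
# Default cap on concurrent training requests, from the table above.
NUM_MAX_REQUESTS_AT_ONCE = 512


def effective_concurrency(n: int) -> int:
    """Requests that can be in flight when each one samples n completions."""
    return NUM_MAX_REQUESTS_AT_ONCE // n


print(effective_concurrency(1))   # 512
print(effective_concurrency(4))   # 128
print(effective_concurrency(16))  # 32
```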
## OpenAI Server Configuration (`atroposlib.envs.server_handling.openai_server.OpenaiConfig`)
Configuration for individual OpenAI-compatible API servers (including local SGLang/vLLM instances).
| Parameter | Type | Default | Description |
| :------------------------- | :----------- | :-------- | :------------------------------------------------------------------------------------------------------ |
| `api_key` | `str \| None` | `None` | API key for OpenAI API. Use `"x"` or any non-empty string for local servers that don't require auth. |
| `base_url` | `str \| None` | `None` | URL of the API endpoint. `None` for official OpenAI API, otherwise the local server URL (e.g., `http://localhost:9004/v1`). |
| `timeout` | `int` | `1200` | Timeout for the request in seconds. |
| `num_max_requests_at_once` | `int` | `512` | Maximum number of concurrent requests (training). Divide this by the generation `n` parameter. |
| `num_requests_for_eval` | `int` | `64` | Maximum number of concurrent requests for evaluation. |
| `model_name` | `str` | `default` | The model name to use. Required for both OpenAI and local models (e.g., `"gpt-4"`, `"NousResearch/..."`). |
| `rolling_buffer_length` | `int` | `1000` | Length of the rolling buffer to store server metrics (like request timings, attempts). |
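To make the two common cases concrete, here are illustrative settings for a local vLLM/SGLang server versus the official OpenAI API, written as plain dicts whose keys mirror the table above (the values, including the placeholder key, are examples only):

```python
# Local inference server: auth is not checked, so any non-empty api_key works,
# and base_url points at the server's OpenAI-compatible /v1 endpoint.
local_server = {
    "api_key": "x",
    "base_url": "http://localhost:9004/v1",
    "model_name": "NousResearch/DeepHermes-3-Llama-3-1B-Preview",
    "timeout": 1200,
}

# Official OpenAI API: base_url stays None and a real key is required.
official_openai = {
    "api_key": "sk-...",  # placeholder; supply a real key
    "base_url": None,
    "model_name": "gpt-4",
    "timeout": 1200,
}
```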

CONTRIBUTING.md (new file, 150 lines)
# Contributing to Atropos
First off, thank you for considering contributing to Atropos! It's people like you that make open source projects such great tools.
We welcome any type of contribution, not just code. You can help with:
* **Reporting a bug**
* **Discussing the current state of the code**
* **Submitting a fix**
* **Proposing new features**
* **Becoming a maintainer**
## We Develop with GitHub
We use GitHub to host the code, track issues and feature requests, and accept pull requests.
## We Use GitHub Flow
We follow the [GitHub Flow](https://guides.github.com/introduction/flow/index.html) development workflow. All code changes happen through Pull Requests.
## Getting Started
### Project Setup
1. **Fork the repository:** Click the "Fork" button on the top right of the [repository page](https://github.com/NousResearch/atropos). This creates your own copy of the project.
2. **Clone your fork:**
```bash
git clone https://github.com/your-username/atropos.git
cd atropos
```
3. **Set up the development environment:** This project uses standard Python `venv` for environment creation and `pip` for dependency management.
```bash
# Ensure you have Python 3.10+ installed
# Create and activate a virtual environment
python -m venv .venv
source .venv/bin/activate # On Windows use `.venv\Scripts\activate`
# Install dependencies, including development dependencies
pip install -e ".[dev]"
```
4. **Install pre-commit hooks:** This project uses `pre-commit` for code quality checks. The hooks will run automatically when you commit changes.
```bash
pre-commit install
```
### Running Tests
We use `pytest` for running tests. To run the test suite:
```bash
pytest
```
Ensure all tests pass before submitting a pull request.
## How to Contribute
### Reporting Bugs
We use GitHub issues to track public bugs. Report a bug by [opening a new issue](https://github.com/NousResearch/atropos/issues).
When opening a bug report, please use the **Bug Report** issue template. This template is designed to gather the information we need to efficiently understand and resolve the issue.
**Great Bug Reports** tend to have:
* A quick summary and/or background.
* Steps to reproduce the bug:
* Be specific!
* Provide the exact commands run or a minimal code snippet if possible.
* What you expected to happen.
* What actually happened (including any error messages or logs).
* Your environment details (OS, Python version, relevant package versions).
* Notes (possibly including why you think this might be happening, or stuff you tried that didn't work).
Thorough bug reports help us address issues faster!
### Suggesting Enhancements
If you have an idea for a new feature or an improvement to an existing one, please open an issue first to discuss it. This allows us to coordinate efforts and ensure the suggestion aligns with the project's goals.
When suggesting an enhancement, please use the **Feature Request** issue template. This helps structure your request and provides context for maintainers and the community to better understand your suggestion.
### Submitting Changes (Pull Requests)
Pull requests are the best way to propose changes to the codebase. We actively welcome your pull requests:
1. **Fork the repo** and create your branch from `main`.
```bash
git checkout -b your-feature-or-fix-branch main
```
2. **Make your changes:** Write your code.
3. **Add tests:** If you've added code that should be tested, add tests.
4. **Update documentation:** If you've changed APIs or added features, update relevant documentation (README, docstrings, etc.).
5. **Ensure tests pass:** Run `pytest`.
6. **Ensure code lints and formats:** The pre-commit hooks will run automatically on commit. You can also run them manually: `pre-commit run --all-files`.
7. **Commit your changes:** Use clear and descriptive commit messages that explain the purpose of the changes.
```bash
git add .
git commit -m "Clearly describe the changes made in this commit"
```
8. **Push your branch:**
```bash
git push origin your-feature-or-fix-branch
```
9. **Open a Pull Request (PR):** Go to the original repository on GitHub and open a PR from your fork's branch to the `main` branch.
* Provide a clear title and description for your PR.
* Link any relevant issues (e.g., "Closes #123").
* Explain the changes you've made and why.
* **Follow the PR template**: We have two PR templates:
- For environment-related changes, use the `environment_pull_request_template.md`
- For all other changes, use the `non_environment_pull_request_template.md`
Please fill out the appropriate template thoroughly to help reviewers understand your changes.
## Code Style
This project uses standard Python code style (PEP 8) enforced by `black`, `flake8`, and `isort` via `pre-commit`. Please ensure your code adheres to these standards. The pre-commit hooks should help automate formatting and linting.
You can manually run the checks on all files using:
```bash
pre-commit run --all-files
```
This command will automatically fix formatting issues found by `black` and `isort`. However, you may need to manually address any linting errors reported by `flake8`.
## License for Contributions
Any contributions you make will be under the Apache License 2.0. In short, when you submit code changes, your submissions are understood to be under the same [Apache License 2.0](LICENSE) that covers the project. Feel free to contact the maintainers if that's a concern.
## Environment Contribution Guidelines
Since Atropos is focused on reinforcement learning environments, we encourage contributions of new training environments. However, please adhere to the following guidelines:
* **Legal compliance**: Do not submit environments that involve illegal activities or content.
* **GitHub compliance**: All contributions must comply with [GitHub's Terms of Service and Community Guidelines](https://docs.github.com/en/site-policy/github-terms/github-terms-of-service).
* **Explicit content**: Explicit environments may be considered, but must be:
* Clearly labeled as such
* Compliant with all legal requirements
* **Game environments**: Game-based environments are welcome, but:
* Do not submit reverse-engineered commercial game environments that could lead to copyright or intellectual property issues
* Ensure you have the appropriate rights to any assets used
* Open-source games or games with permissive licenses are preferred
* **Ethical considerations**: Consider the ethical implications of your environment. Environments that encourage harmful behaviors without educational context may be rejected.
When in doubt about the appropriateness of an environment, please open an issue to discuss it before investing significant development effort.
## Code of Conduct
Please note that this project is released with a [Contributor Code of Conduct](CODE_OF_CONDUCT.md). By participating in this project you agree to abide by its terms.
Thank you again for your contribution!

LICENSE (new file, 21 lines)
MIT License
Copyright (c) 2025 Nous Research
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md (new file, 224 lines)
# Atropos - Nous Research's LLM RL Gym
![newatr-02](https://github.com/user-attachments/assets/e9b64e10-340e-48f2-835c-ae28fa14730a)
<div align="center">
*In Greek mythology, Atropos was the eldest of the three Fates. While her sisters spun and measured the threads of mortal lives, Atropos alone held the shears that would cut these threads, determining the final destiny of each soul. Just as Atropos guided souls to their ultimate fate, this system guides language models toward their optimal potential through reinforcement learning.*
</div>
<div align="center">
</div>
<div id="badges" align="center">
<a href="https://huggingface.co/NousResearch">
<img src="https://img.shields.io/badge/NousResearch-orange?style=for-the-badge&logo=huggingface&logoColor=white" alt="HuggingFace"/>
</a>
<a href="https://nousresearch.com">
<img src="https://img.shields.io/badge/NousResearch.com-white?style=for-the-badge&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAACQAAAAlCAYAAAAqXEs9AAAAIGNIUk0AAHomAACAhAAA+gAAAIDoAAB1MAAA6mAAADqYAAAXcJy6UTwAAAAJcEhZcwAAFiUAABYlAUlSJPAAAAAGYktHRAD/AP8A/6C9p5MAAAAldEVYdGRhdGU6Y3JlYXRlADIwMjUtMDQtMjlUMTU6NDI6MjcrMDA6MDAUtMrgAAAAJXRFWHRkYXRlOm1vZGlmeQAyMDI1LTA0LTI5VDE1OjQyOjI3KzAwOjAwZelyXAAAACh0RVh0ZGF0ZTp0aW1lc3RhbXAAMjAyNS0wNC0yOVQxNTo0MjoyNyswMDowMDL8U4MAAAhJSURBVFhHzVhZTJVXEB4RNxBBwBXiwqIo7juKorjjFn1QaVKrUROLRkxKjKKNItWkJrYuscYXouVNrCa0LO5bArgHUFxRqIJAFRAQUJbpfMP9/96LF7RvfMnJ/Zfzn/OdmW/mzLntiIiltRk4WH7bDEwLxcXFkb+/P7Vrh0etw+hTX19PtbW19OnTJ33m5OREXbp0kWsHamioJ+YvGx99njx5QuvWrdN7k9CVK1do+vTpuGwVjY2N9OFDtbQqys/Ppxs3btC1a9eopKSEevfuTSNGjKCQkOk0fvx4cnXtZvmqZYDQpUuXaPbs2ZYnTYRYCLFMZrcBZWVlfPr0aY74PoKXLl3Kq1Z9x0eOHOGHD3M4PT2DV6xYwR06dNCxxEo8bdo0TkhI4Nraj/q9vXHRGhoa+MKFC/qdpbVMCBB3cHJyMs+cOZM7depk/SE7ODiwWIIvXrzIVVVVvG3bdpMUmrOzM2/dupXfv6/QsZqPj/bVhIDS0lLeEb2D3dzcrD8wm+hGm6+vL6empiqpVd+uMt+DcLdu3fjAgQMserMZ32jNCdmNMgi0pOQf2r59O/3y6y+ilw/k4uJCjo6OpqA7duxIPj4+NGvWLNq580eSQamoqIiid0STENT3aJMnT6bOnTtTTk6O+W1rsEsIUZOU9BdlZ2eLMF1pwoQJ5OHhQQEBARQcHEwDBgygPn360JYtWygqKoqGDx9OeXl59ODBAxo8eDCJzmjNmjW0ePFiioyMpFGjR5Pojd69e/dFUo6WXxP4ACudOHEi9erVi8QVNGnSJCFZRzU11TRw4EASV1L37t2lzyRKSDhFERERVF5erikAhOfNn0d37t7RtLBv3z7y9vKijJs3acGCBbRkyRLLTPbxGSEMevbsWQlHIk9PD6qoqKCTJ09KXmnQSdFEXtS3b19tCPeXL19qOrh69SrdlIm9vb3p6bNnlC9Wq6mpoYLXr6myspISE/+k+fPnkwjfMpt9qJggauDp06c8btw4FuuomCFMNPRp3749i7vMe/yKrvTauA8LW8AhISEcGBio4S+5iSVhsrieBw0axI8fP9Z5vlrUyJqPHj0yrYGVy/f6Dlbq2bOn6gnAO7jFAO5zch7SM7EOngtBFTa++/jxI70WS927d8/S2z4+IwR9VFdX6wAGevToIQM3mVkSJI0ePUavAUxqAPp78+aNuvHFixc6Tl1dHUk+0l+4738TQlQhvK2BMA4P/4bElSRupNCZobRo0SIKDQ1VkQ4bNky/gSWxEFgHBHAP7SBlwEq4z8rKUqKtRZv6ztBQXl4+Bw4NZDG16kMm1qQng7CsmvfExrLsebx5cyRnZmbK8xrVhaQEUwdiUe7Xr59+D93Jps2O8ot3SKKyB+pcX9SQ9KH+/fvRrt27KCgoSPMOcPv2bV0hQn7unLkkGVlckktpaWnqMuhKthXtKwQ0F40aNUpTA76DtZzEbQBSCtzZGpSZ9dYB1iJMPnXqFE+dOpU3bdrE5bK5AmVl5ZyYmMgpKan89u1bfVZYWMhCwlx
l8+bu7s6SCsz7gwcP6ndftJAB+NfPz0/ziZOzE23cuJFcRTsyhpYUCxcupHnz5prRBqEjIlsCdAQLGrq5e/euWs4e7BIygIjz8/UnWaFGiDVADg149eqVDSG4zTow4DIQMNwKYWNse8JulZC4RAbpqCuEVlqKDBG1TZoAUcMCRi7CggxC2PeQq+yhVUJFRcW6mvSMDDpx4gRly+aJ5NccSKQAJsfOjj6G9WAplLU1EuqwHN5je8qQMe2hVULFxUV069YtHWzkyJHkLlFjDVgMK8/NzdV77G2oDKxhuK9WLAirISchr6H0NWpxa9gQgq9LS8v0GsmtoKBAN9fLly9rrewlu7Z1ZgbwXqJMryUiaezYsXoNYDKQwcQYDw2kkFKwRSGjt0gIK0E9A/OjE3INfA0kJSXR+fMX9BrvMCjIAxAzogxo7gpDd8jMcCGI4Xfo0KFa9EF7FRWVNqTMUMDRBYQgzuDgKSS7vroCnSHqqKgfpLy4on0RVSghYmL2mOQwuZtbd62H7t+/r+QcHNrpryFwTZByVML2BEIYRw4BYsX2+h4wCTU2Nu3Ihw4dEu0U67EIK4K/QQwN76yB9yjOQBoEx4wZI24oVtcA9fXYv/6rBgAvyW25uS/MNJGVlSm7Q3+9BmwE4enpSV27dqWf9+9XC4WHh+vKsSJ7RRVcefjwYXUBFgMrokI0CBmRZg0nibgzZ/5QiaA9f/7c7A/YEAoIGKKFO4hs2LBBrQKzDhky5LPBEb7Yr5YvX06xsT9JmRukpAz32AMsiaiFoLELIBjgOiReAyYh5I6AACnQUR+LSI8dO6YnhWXLlql1rFcB+Pj4kpzDdPDS0ncUH/+7nFhDNOcgGmFt1OT+/oM0QrExwyIQOIBnKSkpNGfOHE0FBiBvXTo0gwGxCgjx6NHfxDp/y0mhlJKTkzTbYqc2AFdGR0eru0AWWV33LOl3U6wAcc+YMYNcurrQe0kNGRnpSuDcuXNaBUg5q5rbu3evWg3HKQDy3o2L1atX6/EGgEW8vb3o+vUblJefR3JMptfiOmtCAPIU9IV9CZNjgqTkZAoLC6P169dronT3cJffPipc5B00fLN27Vo9RkGzKEfi4+MtozZZyCzQrCF6kKN0Hcupg0Uz2g9FF4p3FPErV67ktLR0Fp1xVlY2x8TEcFxcnJYV9iCaZNEmS6K1PGkqQazLDzPskQTT09Mtd02AWwoKCtWscCM0gPMaNkkkQ5SwcOXx48dVBxDplClT9Cgkc1lGsQXew8VIoEYf/ItiwNRQW4HtxtQG0MYIEf0L1N75qS9kGwUAAAAASUVORK5CYII=" alt="Website"/>
</a>
<a href="https://x.com/NousResearch">
<img src="https://img.shields.io/badge/@NousResearch-black?style=for-the-badge&logo=X&logoColor=white" alt="@NousResearch"/>
</a>
</div>
Atropos is a framework of Language Model Reinforcement Learning environments for collecting and evaluating LLM trajectories across diverse settings, including:
<div align="center">
| Environment Type | Examples | Purpose |
|---------------------------|--------------------------------------------|----------------------------------------------------|
| 📚 Dataset environments | GSM8K, MMLU | Evaluate and improve LLM performance on static data|
| 🎮 Online environments | Crosswords, Hangman | Train LLMs through interactive game-based learning |
| 🤖 RLAIF and RLHF | LLM Judge/Reward Models | Fine-tune LLMs using human feedback and alignment |
| 🔄 Multi-Turn RL | deepresearch, internal tool calling | Train LLMs on complex multi-step interactions |
</div>
Atropos is a robust, scalable framework for **Reinforcement Learning Environments with LLMs**. Key features:
- **Multi-Turn & Asynchronous RL:** Efficiently supports complex, multi-turn, and asynchronous interactions, decoupling environment steps from policy updates.
- **Inference Agnostic:** Integrates with standard inference APIs (e.g., OpenAI, vLLM, SGLang), enabling easy switching between LLM providers and frameworks.
- **Trainer Independent:** Offers a standardized training interface for experimenting with different RL algorithms and frameworks without major code changes.
- **Scalable & Decentralized:** Easily scale by launching more environment instances (locally or across decentralized resources) that contribute rollouts to a central service.
- **Diverse Environment Integration:** Manages many varied environment types concurrently for heterogeneous, multi-modal training.
The goal: provide a flexible, scalable, and standardized platform to accelerate LLM-based RL research across diverse, interactive settings.
## 🎉 Upcoming Atropos Hackathon: LLM RL Environments
Join us in San Francisco on May 18th, 2025 for an exciting hackathon focused on building and experimenting with LLM RL Environments! This in-person event will bring together researchers and developers interested in advancing the field of LLM reinforcement learning.
More details coming soon! Follow us on Twitter [@NousResearch](https://x.com/NousResearch) to stay updated.
---
## Experimental results from models trained using Atropos' environments
We have achieved significant improvements on specific domains and tasks with Atropos. Below are some of the results.
**Tool Calling Environment Results:**
<div align="center">
| Berkeley Function Calling Benchmark Type | Base Model | With Atropos RL | Improvement |
|---------------|------------|-----------------|-------------|
| Parallel Tasks| 10% | 46% | **4.6x** ⬆️ |
| Simple Tasks | 21% | 51.75% | **2.5x** ⬆️ |
</div>
Model Artifact:
https://huggingface.co/NousResearch/DeepHermes-ToolCalling-Specialist-Atropos
Environment Used:
https://github.com/NousResearch/Atropos/environments/tool_calling_server.py
---
**Financial Fundamentals Prediction Environment Results**:
<div align="center">
| Metric | Initial Accuracy | With Atropos RL | Improvement |
|--------|-----------------|-----------------|-------------|
| Directional Prediction Eval Accuracy | 20% | 50% | **2.5x** 📈 |
</div>
Model Artifact:
https://huggingface.co/NousResearch/DeepHermes-Financial-Fundamentals-Prediction-Specialist-Atropos
Environment Used:
https://github.com/NousResearch/Atropos/environments/fundamental_prediction_environment.py
---
## RLAIF Experiment Artifacts
Using the RLAIF environment to change the personality of the model, we have produced several model artifacts with interesting and unusual personalities.
**DeepHermes Egregore v1 and v2 8B:**
https://huggingface.co/NousResearch/DeepHermes-Egregore-v1-RLAIF-8b-Atropos
https://huggingface.co/NousResearch/DeepHermes-Egregore-v2-RLAIF-8b-Atropos
**DeepHermes Ascension Maze 8B:**
https://huggingface.co/NousResearch/DeepHermes-AscensionMaze-RLAIF-8b-Atropos
---
## Navigating the Repo
| Category | Description |
|----------|------------|
| 📁 [`atroposlib/`](atroposlib/) | Core library containing base classes and utilities |
| 🎮 [`environments/`](environments/) | Collection of ready-to-use RL environments |
| 📚 [`example_trainer/`](example_trainer/) | Example training scripts and configurations |
Key Documents:
- [Base Environment Class](atroposlib/envs/README.md) - Documentation for creating custom environments
- [Environments Overview](environments/README.md) - Documentation for existing environments
- [Full Environment Config Options](CONFIG.md) - Reference for all environment configuration options
- [Example Trainer](example_trainer/README.md) - Getting started with training
- [Slurm Guide](SLURM.md) - Guide for using Atropos with Slurm for distributed inference
- [Contributing Guide](CONTRIBUTING.md) - Guidelines for contributors
- [License](LICENSE.md) - Apache 2.0 license details
---
## Installation
Get your Python 3.10 (or later) environment ready, then simply pip install:
```bash
pip install atroposlib
```
If you're looking to develop the repo or use the environments, install in editable mode:
```bash
pip install -e . # for using
pip install -e .[dev] # for development
pip install -e .[examples] # for running examples
pip install -e .[all] # for everything
```
**Important:** If you're committing to the repository, please install the pre-commit hooks:
```bash
pre-commit install
```
---
### Quick Start Guide
1. **Create Your First Environment**
- Review our [Base Class Documentation](atroposlib/envs/README.md) to understand the core concepts
- Check out existing environments in the [`environments/`](environments) directory for examples
2. **Run an Example Environment**
```bash
# Start the API server and run the GSM8K environment
run-api & python environments/gsm8k_server.py serve \
--tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct" \
--model_name="Qwen/Qwen2.5-1.5B-Instruct"
```
3. **Training Your Model**
- Follow our [training example guide](example_trainer/README.md) for detailed instructions
- Monitor progress through our built-in logging and reporting system:
- Completion lengths
- Evaluation accuracies
- Full rollouts and scores
You can use multiple environments at once; just point them all at the same server.
Environments come with detailed logging and reporting support: runs track completion lengths, evaluation accuracies, full rollouts, scores, and more:
![image](https://github.com/user-attachments/assets/153a2932-191a-42e3-8da9-25a1b05abb8e)
---
## Debugging Tools
The trajectory-handler provides several debugging tools to help environment developers test and understand their environments locally without requiring the full distributed infrastructure.
* **Flexible Model Provider Support:** Atropos natively supports any model provider that adheres to the OpenAI API standard. Simply provide the provider's base URL and your API key, and Atropos can integrate with their models seamlessly for testing or running environments locally.
After launching the API and your selected environments (e.g., `run-api & python environments/gsm8k_server.py serve`), you can inspect runs for a quick look or prepare datasets for offline training:
* **View Run (`view-run`):** Launch a Gradio UI to inspect batches of rollouts generated by your environment runs. This is useful for visually debugging the interactions and data flow.
* **Offline Data Generation:** Use `atropos-sft-gen` and `atropos-dpo-gen` to collect rollouts from environments and convert them into formats suitable for Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO).
---
## Citation
If you have found the library helpful in your work, you can cite this repository as:
```latex
@misc{atropos,
title = {{Atropos - An Async First Environment Rollout Controller}},
  author = {Dakota Mahan and Roger Jin and Teknium and Shannon Sands and Artem Yatsenko and Jai Suphavadeeprasit and Karan Malhotra and Chen Guang and Joe Li},
url = {https://www.github.com/NousResearch/Atropos},
month = {4},
year = {2025},
version = {0.1},
}
```
---
## Contributing
Atropos is built by the open-source AI community, and relies on our amazing contributors! Please see our [contributing](CONTRIBUTING.md) guide for more details on our code formatting, testing, etc.
Please follow the [Code of Conduct](CODE_OF_CONDUCT.md).
---
## License
Atropos is licensed under Apache 2.0; see the [LICENSE](LICENSE.md) file for more information.
---

# SLURM.md
## Using `ServerManager` with Slurm
The `ServerManager` class in `atroposlib` provides built-in support for discovering and managing inference servers distributed across nodes allocated by Slurm. Here's how to use it:
**Core Concept:**
The setup assumes you have a Slurm job allocation where:
1. One or more nodes are designated for your main "training" or orchestrator process (the script that initializes `ServerManager`).
2. The remaining nodes in the allocation are dedicated to running the LLM inference servers (e.g., SGLang, TGI, vLLM, etc., accessible via an OpenAI-compatible API).
**How `ServerManager` Detects Servers:**
When you initialize `ServerManager` with `slurm=True`:
1. It reads the `SLURM_JOB_NODELIST` environment variable to get the hostnames of all allocated nodes. It uses the `scontrol show hostnames` command internally.
2. It reads the `NUM_TRAINING_NODES` environment variable. This crucial variable tells the manager how many nodes *at the beginning* of the nodelist are *reserved for the training/orchestrator process* and should **not** be treated as inference server nodes.
3. It iterates through the hostnames *after* the first `NUM_TRAINING_NODES`. These are assumed to be the inference nodes.
4. For each inference node, it constructs potential server URLs. By default, it assumes:
* Servers run on ports starting from `9000` (`9000`, `9001`, `9002`, ...).
* The number of server instances per node is determined by `8 // INFER_TP` (where `INFER_TP` is another environment variable, defaulting to 1 if not set, implying 8 servers per node). You should set `INFER_TP` according to your inference server's tensor parallelism configuration if applicable.
* The URL format is `http://{node_hostname}:{port}/v1`.
5. It uses the *first* configuration object you pass in the `configs` list as a template (for settings like `timeout`, `num_max_requests_at_once`, etc.) and creates specific `OpenaiConfig` objects for each discovered URL.
6. The `ServerManager` then load-balances requests across these automatically configured `OpenAIServer` instances.
**Setup Steps:**
1. **Launch Inference Servers:** In your Slurm submission script (`sbatch`), launch your inference server instances on the designated inference nodes.
* Ensure they listen on the correct hostname and the expected ports (9000, 9001, ...).
* The number of instances per node should match the `8 // INFER_TP` logic. Adjust the port range or `INFER_TP` environment variable accordingly if your setup differs.
* You might use `srun` to launch these processes on specific nodes.
2. **Set Environment Variables:** In the part of your Slurm script that launches your *main application* (the one using `ServerManager`):
* `export NUM_TRAINING_NODES=<number_of_non_inference_nodes>` (e.g., `export NUM_TRAINING_NODES=1` if only the first node runs the main script).
* `export INFER_TP=<your_tensor_parallel_size>` (Optional, defaults to 1. Set this if your inference servers use tensor parallelism and you run fewer than 8 instances per node).
3. **Initialize `ServerManager`:** In your Python script:
```python
from atroposlib.envs.server_handling.server_manager import ServerManager, ServerBaseline, OpenaiConfig
# Provide at least one config object. It will be used as a template
# for Slurm-discovered servers if slurm=True.
# If you pass ServerBaseline, ensure NUM_TRAINING_NODES and potentially INFER_TP are set.
# If you pass a list of OpenaiConfig, the first one is used as the template.
base_config = ServerBaseline(
timeout=1200,
# other baseline settings...
)
# OR
# base_config = OpenaiConfig(
# base_url="http://dummy", # This URL is ignored when slurm=True finds nodes
# api_key="dummy",
# timeout=1200,
# # other config settings...
# )
server_manager = ServerManager(
configs=base_config, # Or [base_config] if using OpenaiConfig
slurm=True
)
# Now use server_manager.chat_completion(...) or server_manager.completion(...)
```
4. **Submit Slurm Job:** Submit your job ensuring the necessary nodes and resources (like GPUs for inference) are requested.
**Example Conceptual Slurm Script:**
```bash
#!/bin/bash
#SBATCH --nodes=5 # 1 trainer node + 4 inference nodes
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-node=8 # Assuming 8 GPUs/node for inference
#SBATCH --job-name=atropos-rl
# Get allocated node hostnames
nodes=$(scontrol show hostnames $SLURM_JOB_NODELIST)
nodes_array=($nodes)
training_node=${nodes_array[0]}
inference_nodes=${nodes_array[@]:1} # Nodes from index 1 onwards
echo "Training Node: $training_node"
echo "Inference Nodes: ${inference_nodes[@]}"
# --- Launch Inference Servers (Example using srun, adapt for your server type) ---
TP_SIZE=1 # Example: Tensor Parallelism = 1
INSTANCES_PER_NODE=$((8 / TP_SIZE))
echo "Launching $INSTANCES_PER_NODE inference servers per node..."
for node in ${inference_nodes[@]}; do
for i in $(seq 0 $((INSTANCES_PER_NODE - 1))); do
port=$((9000 + i))
gpu_id=$i # Basic GPU assignment, might need refinement
echo "Starting server on $node:$port (GPU $gpu_id)"
srun --nodes=1 --ntasks=1 --gpus-per-task=1 --gpu-bind=map_gpu:$gpu_id --nodelist=$node \
your_inference_server_launch_cmd --host 0.0.0.0 --port $port --tp $TP_SIZE [other_args] &
done
done
echo "Waiting for servers to start..."
sleep 60 # Simple wait, consider a more robust check
# --- Launch W&B Watcher on each Inference Node ---
echo "Launching W&B watchers..."
# Assume the main API server runs on the training_node at default port 8000
TRAINER_API_ADDR="http://${training_node}:8000"
inference_node_index=0 # Start index for node_num
for node in ${inference_nodes[@]}; do
echo "Starting watcher on $node (Node Index $inference_node_index)"
srun --nodes=1 --ntasks=1 --nodelist=$node \
python atroposlib/cli/inference_node_wandb_watcher.py \
--api_addr $TRAINER_API_ADDR \
--tp $TP_SIZE \
--node_num $inference_node_index &
inference_node_index=$((inference_node_index + 1))
done
# --- Launch Main Application on the Training Node ---
export NUM_TRAINING_NODES=1
export INFER_TP=$TP_SIZE
echo "Starting main application on $training_node..."
srun --nodes=1 --ntasks=1 --nodelist=$training_node \
python your_main_atropos_script.py --some_arg=value
echo "Job finished."
wait # Wait for background server processes launched with '&'
```
**Important Notes:**
* This setup relies on the `scontrol` command being available in the environment where `ServerManager` is initialized.
* Ensure network connectivity and firewall rules allow the training node(s) to reach the inference nodes on ports 9000+.
* The logic assumes a specific port assignment (9000+) and server count based on `INFER_TP`. If your inference server setup differs (e.g., different ports, different discovery mechanism), you would need to modify `server_manager.py` or manually provide the correct list of `OpenaiConfig` objects instead of relying on `slurm=True`.
## Monitoring Inference Nodes with Weights & Biases
Atropos includes a utility script, `inference-node-wandb-watcher`, located in `atroposlib/cli/`, designed to run on each inference node alongside the inference servers.
**Purpose:**
* **Health Monitoring:** Periodically checks the `/health_generate` endpoint of each local inference server instance (assuming ports 9000+).
* **W&B Logging:** Logs the health status (1 for healthy, 0 for unhealthy) of each server instance to a shared Weights & Biases run group. This allows you to visualize server uptime and availability directly in your W&B dashboard alongside your training metrics.
* **Step Synchronization:** It fetches the current training step from the main Atropos API server (`run-api`) to ensure W&B logs are correctly associated with training progress.
**Integration into Slurm Script:**
You can launch this watcher on each inference node using `srun`, similarly to how the inference servers are launched. This section appears in the example Slurm script above, after launching the inference servers and before launching the main application:
```bash
# --- Launch W&B Watcher on each Inference Node ---
echo "Launching W&B watchers..."
# Assume the main API server runs on the training_node at default port 8000
TRAINER_API_ADDR="http://${training_node}:8000"
inference_node_index=0 # Start index for node_num
for node in ${inference_nodes[@]}; do
echo "Starting watcher on $node (Node Index $inference_node_index)"
srun --nodes=1 --ntasks=1 --nodelist=$node \
python atroposlib/cli/inference_node_wandb_watcher.py \
--api_addr $TRAINER_API_ADDR \
--tp $TP_SIZE \
--node_num $inference_node_index &
inference_node_index=$((inference_node_index + 1))
done
```
**Explanation of Arguments:**
* `--api_addr`: This is the address of the main Atropos API server (usually started with `run-api`). The script needs this to fetch W&B project/group info and the current training step. In the example, we construct it assuming the API runs on the `training_node` (first node in the allocation) at port `8000` (the default for `run-api`). **Ensure this port is correct and accessible from the inference nodes.**
* `--tp`: This should be the same tensor parallelism size (`TP_SIZE`) used when launching the inference servers. It tells the watcher how many server instances (ports 9000 to 9000 + `8 // TP_SIZE` - 1) to monitor on the local node.
* `--node_num`: A unique integer identifying this specific inference node within the Slurm job. This helps distinguish the metrics from different nodes in W&B (e.g., `server/server_heath_0_0`, `server/server_heath_1_0`). The example script assigns sequential indices starting from 0.
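Under the defaults described above, the set of ports the watcher monitors on a node can be sketched as follows (an illustrative helper, not part of the codebase):

```python
def watched_ports(tp_size, base_port=9000, gpus_per_node=8):
    # The watcher checks one port per server instance:
    # gpus_per_node // tp_size instances per node, on consecutive
    # ports starting at base_port.
    return [base_port + i for i in range(gpus_per_node // tp_size)]
```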
**Important Notes:**
* Ensure the `run-api` server is running and accessible from the inference nodes.
* The `inference-node-wandb-watcher` script should be executable and accessible from the inference nodes.
* The script assumes the default port for the `run-api` server (8000). If your setup uses a different port, you may need to modify the script or the port in the `TRAINER_API_ADDR` construction.
---

# atroposlib/__init__.py

*(empty file)*

---

# atroposlib/api/README.md
# Trajectory Handler API
## Overview
The AtroposLib API is a FastAPI application designed to act as a central buffer and aggregator for reinforcement learning (RL) experience data. Its primary purpose is to decouple RL data generation (by "Rollout Handlers" or "Environments") from RL data consumption (by one or more "Trainers"), particularly in distributed online RL settings.
This service specifically handles the **experience data pathway**:
* Rollout Handlers connect and push trajectories (tokens, masks, scores, etc.).
* The API buffers this data in a queue.
* Trainers connect and pull processed batches of experience data for training updates.
**Important:** This service does *not* handle the distribution of updated policies from the Trainer back to the Rollout Handlers/Inference Servers. That part of the online RL loop is assumed to be handled by a separate mechanism.
## Features
* Centralized, in-memory queue for RL trajectory data.
* Registration endpoints for Trainers and Rollout Handlers.
* Serves batches of aggregated experience data to Trainers.
* Supports heterogeneous environments with weighting (via `/register-env` weight and internal batching).
* Provides status endpoints for monitoring queue size and training step count.
* Basic integration with Weights & Biases (W&B) project/group info.
* Endpoints for Rollout Handlers to disconnect gracefully.
* Debug endpoint to retrieve the latest submitted data sample.
## Architecture Context
This API typically sits within a larger RL system:
1. **Rollout Handlers:** Instances simulating the environment. They interact with Inference Servers to get actions based on the current policy and send resulting trajectory data (`ScoredData`) to this AtroposLib API (`/scored_data`).
2. **Inference Servers (External):** Serve the current policy (e.g., via an OpenAI-compatible API). Receive policy updates directly from the Trainer. *Not part of this service.*
3. **AtroposLib API (This Service):** Buffers and batches experience data received from Rollout Handlers.
4. **Trainer(s):** Pull batches of experience data from this API (`/batch`), compute gradients, update the policy, and push updated policies directly to the Inference Servers.
## Running the Server
With the repository installed, use the provided helper script to run the server:
```bash
run-api
```
If you need more control over the server, you can run it directly with:
```bash
uvicorn atroposlib.api.server:app --host 0.0.0.0 --port 8000 --reload
```
* `--host 0.0.0.0`: Makes the server accessible on your network.
* `--port 8000`: Specifies the port (change if needed).
* `--reload`: Enables auto-reloading on code changes (for development). Remove for production.
The API documentation (Swagger UI) will be available at `http://<your-server-ip>:8000/docs`.
## API Endpoints
### General
* `GET /`
* **Description:** Root endpoint for basic health check.
* **Response:** `{"message": "AtroposLib API"}`
### Trainer Registration & Info
* `POST /register`
* **Description:** Called once by the Trainer process to initialize the server state for a training run. Resets state if called again.
* **Request Body:** `Registration` model
```python
class Registration(BaseModel):
wandb_group: str
wandb_project: str
batch_size: int
max_token_len: int # Max token length expected in trajectories
checkpoint_dir: str # Shared location for checkpoints
save_checkpoint_interval: int
starting_step: int
num_steps: int # Total expected training steps
```
* **Response:** `{"uuid": <generated_uuid_int>}`
* `GET /wandb_info`
* **Description:** Retrieve W&B group and project info set during registration.
* **Response:** `{"group": <group_name_or_null>, "project": <project_name_or_null>}`
* `GET /info`
* **Description:** Retrieve batch size and max token length set during registration.
* **Response:** `{"batch_size": <size_or_-1>, "max_token_len": <len_or_-1>}`
* `GET /status`
* **Description:** Get the current training step (based on batches served) and queue size.
* **Response:** `{"current_step": <step_count>, "queue_size": <queue_length>}`
### Rollout Handler Registration & Info
* `POST /register-env`
* **Description:** Called by each Rollout Handler instance to register itself.
* **Request Body:** `RegisterEnv` model
```python
class RegisterEnv(BaseModel):
max_token_length: int # Max length this env produces
desired_name: str # Base name for identification/logging
weight: float # Weight for sampling/batching (e.g., 1.0)
```
* **Response:** Provides assigned ID, unique W&B name, checkpoint info.
```json
{
"status": "success",
"env_id": <assigned_env_id_int>,
"wandb_name": <generated_unique_name>,
"checkpoint_dir": <checkpoint_dir_from_registration>,
"starting_step": <current_server_step>,
"checkpoint_interval": <interval_from_registration>,
"num_steps": <num_steps_from_registration>
}
```
* `POST /disconnect-env`
* **Description:** Allows a Rollout Handler to signal it's disconnecting gracefully.
* **Request Body:** `EnvIdentifier` model `{"env_id": <registered_env_id_int>}`
* **Response:** `{"status": "success"}` or `{"status": "failure", "error": ...}`
* `GET /status-env`
* **Description:** Called by a Rollout Handler to get general status plus its calculated sampling weight relative to other connected environments.
* **Query Parameter:** Requires an `EnvIdentifier` (e.g., `?env_id=0`; check the FastAPI docs for query parameter models). **Note:** The code declares `env: EnvIdentifier` as a body parameter on a GET request, which is non-standard; this may need adjustment or testing.
* **Response:** `{"current_step": <step>, "queue_size": <size>, "env_weight": <calculated_weight_float>}`
### Data Handling
* `POST /scored_data`
* **Description:** Endpoint for Rollout Handlers to push a single chunk of trajectory data.
* **Request Body:** `ScoredData` model
```python
class ScoredData(BaseModel):
tokens: List[List[int]]
masks: List[List[int]]
scores: List[float]
ref_logprobs: Optional[List[List[float]]] = None
overrides: Optional[List[dict]] = None # Per-item logging overrides
group_overrides: Optional[dict] = None # Group logging overrides
```
* **Response:** `{"status": "received"}`
* `POST /scored_data_list`
* **Description:** Endpoint for Rollout Handlers to push a list of `ScoredData` chunks.
* **Request Body:** `List[ScoredData]`
* **Response:** `{"status": "received", "groups_processed": <count>}`
* `GET /batch`
* **Description:** Called by the Trainer to request a batch of data for training. The server uses internal logic (`grab_exact_from_heterogeneous_queue`) to form a batch of the configured size from the available data in the queue, potentially respecting environment weights. The server increments its internal step counter when a batch is successfully formed and returned.
* **Response:**
* Success: `{"batch": [<data_item_1>, ..., <data_item_N>]}` where each `data_item` matches the structure pushed via `/scored_data`.
* Not enough data: `{"batch": null}`
* `GET /latest_example`
* **Description:** Debug endpoint to retrieve the most recently added `ScoredData` item.
* **Response:** The last `ScoredData` dictionary pushed, or empty lists if none yet.
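As a quick illustration of the data-handling endpoints, a minimal `/scored_data` payload matching the `ScoredData` model could be posted like this. This is a sketch using only the standard library; the host and port are assumptions (the `run-api` default).

```python
import json
from urllib import request

# A minimal ScoredData-shaped payload: parallel lists of token ids,
# trainable-position masks, and one score per trajectory.
payload = {
    "tokens": [[1, 2, 3], [1, 2, 4]],
    "masks": [[0, 1, 1], [0, 1, 1]],
    "scores": [1.0, -0.5],
}

req = request.Request(
    "http://localhost:8000/scored_data",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
    method="POST",
)
# response = request.urlopen(req)  # server replies {"status": "received"}
```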
### Debugging
* `GET /reset_data`
* **Description:** **Warning:** Resets all server state, including the queue, configuration, registered environments, and step count. Use with caution during development/debugging.
* **Response:** Plain text `Reset successful` with HTTP status 200.
## Common Workflow Example
1. **Start Server:** Launch the `AtroposLib` API server.
2. **Trainer Initialization:** The main Trainer process sends a `POST /register` request with run parameters.
3. **Rollout Handler Initialization:** Each Rollout Handler starts and sends `POST /register-env`.
4. **Data Generation:** Handlers run simulations, collect data, and send `POST /scored_data` or `POST /scored_data_list` periodically.
5. **Training Loop:**
* The Trainer (e.g., Rank 0 in distributed setup) enters a loop:
* Calls `GET /batch`.
* If `batch` is not `null`:
* (Distribute batch to other ranks if applicable).
* Perform training step.
* Optionally call `GET /status` for monitoring.
* If `batch` is `null`:
* Wait briefly (`time.sleep`) and retry `GET /batch`.
* A mermaid diagram of how a trainer interacts with the API is located [here](trainer_interaction.md).
* (In distributed setups, other ranks (1..N-1) might poll `GET /status` to wait for the step counter to increment before expecting the broadcasted batch from Rank 0).
* The envs periodically poll `GET /status-env` to check their status and sampling weight.
* In asynchronous setups, they may stop at a maximum off-policy step count.
* A mermaid diagram of how a rollout handler interacts with the API is located [here](env_interaction.md).
6. **Shutdown:** Handlers may call `POST /disconnect-env`.
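The trainer-side polling loop from step 5 can be sketched as follows. The `fetch` callable stands in for an HTTP GET to `/batch`; the function name and parameters are illustrative, not part of the API.

```python
import time


def wait_for_batch(fetch, poll_interval=1.0, max_tries=None):
    """Poll GET /batch until the server returns a non-null batch."""
    tries = 0
    while max_tries is None or tries < max_tries:
        resp = fetch()  # parsed JSON response from GET /batch
        if resp.get("batch") is not None:
            return resp["batch"]
        tries += 1
        time.sleep(poll_interval)  # wait briefly before retrying
    return None
```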
## Limitations & TODOs
* **In-Memory State:** The primary limitation is that all queues, configurations, and states are stored in the FastAPI application's memory (`app.state`).
* **No Persistence:** Data is lost if the server restarts.
* **Scalability Bottleneck:** API cannot scale beyond a single server instance easily.
---

```python
from .server import app

__all__ = ["app"]
```
---
```mermaid
sequenceDiagram
participant RH as Rollout Handler
participant API as AtroposLib API
%% --- Initialization ---
RH->>API: POST /register-env (Send env details)
activate API
API-->>RH: Response (env_id, starting_step, wandb_name, ...) %% wandb_name is unique to this handler
deactivate API
Note over RH: Store env_id and unique wandb_name.
Note over RH: Fetch W&B configuration (Assumes Trainer already called /register)
RH->>API: GET /wandb_info
activate API
API-->>RH: Response {"group": wb_group, "project": wb_project}
deactivate API
Note over RH: Initialize wandb logging (e.g., wandb.init) using group=wb_group, project=wb_project, name=wandb_name.
Note over RH: Know target batch_size (from config?). Set off_policy_tolerance (e.g., 3). Set internal state = 'Running'.
loop Simulation Loop
%% --- Check Pause State & Generate/Send Data ---
alt State is 'Running'
Note over RH: Generating data using internal environment logic...
%% (Internal simulation steps, action selection, etc., happen here - details are opaque to the API)
Note over RH: Trajectory chunk collected (contains tokens, masks, scores...). Log env-specific metrics to wandb (e.g., episode reward, length).
%% --- Send Data ---
RH->>API: POST /scored_data or /scored_data_list (Send collected chunk)
activate API
API-->>RH: Ack {"status": "received", ...}
deactivate API
else State is 'Paused'
Note over RH: Currently paused, skipping data generation and sending. Will check status again.
%% Implement delay/sleep here to avoid busy-checking status when paused
end
%% --- Periodic Queue Size Check (Pause/Resume Logic) ---
Note over RH: Checking API queue status to decide pause/resume state.
RH->>API: GET /status-env (using stored env_id)
activate API
API-->>RH: Response {"current_step": T_current, "queue_size": Q, "env_weight": W}
deactivate API
Note over RH: T_current might be logged or used for other internal reasons by the handler. Log queue size Q?
Note over RH: Calculate threshold = off_policy_tolerance * batch_size
alt Check if queue size exceeds threshold (Q > threshold)
Note over RH: Queue size (Q = Q) > threshold. Setting internal state to 'Paused'.
opt State was 'Running'
Note over RH: Stopping data generation. Log pause event to wandb.
end
else Queue size is acceptable (Q <= threshold)
Note over RH: Queue size (Q = Q) <= threshold. Ensuring state is 'Running'.
opt State was 'Paused'
Note over RH: Resuming data generation. Log resume event to wandb.
end
end
end %% End Simulation Loop
%% --- Optional Shutdown ---
RH->>API: POST /disconnect-env (using stored env_id)
activate API
API-->>RH: Ack {"status": "success"}
deactivate API
Note over RH: Finalize wandb logging (wandb.finish).
```
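The pause/resume decision in the diagram above can be sketched as a single predicate (an illustrative helper; the tolerance default of 3 follows the diagram's example):

```python
def should_pause(queue_size, batch_size, off_policy_tolerance=3):
    # Pause data generation when the API queue holds more than
    # off_policy_tolerance * batch_size items, i.e. data would become
    # too off-policy before the trainer consumes it.
    return queue_size > off_policy_tolerance * batch_size
```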
---

# atroposlib/api/server.py
import time
import uuid
from typing import Any, List, Optional

from fastapi import FastAPI, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel

from atroposlib.api.utils import grab_exact_from_heterogeneous_queue

app = FastAPI(title="AtroposLib API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    return {"message": "AtroposLib API"}


class Registration(BaseModel):
    wandb_group: str
    wandb_project: str
    batch_size: int
    max_token_len: int
    checkpoint_dir: str
    save_checkpoint_interval: int
    starting_step: int
    num_steps: int


class RegisterEnv(BaseModel):
    max_token_length: int
    desired_name: str
    weight: float


class EnvIdentifier(BaseModel):
    env_id: int


class ScoredData(BaseModel):
    tokens: List[List[int]]
    masks: List[List[int]]
    scores: List[float]
    ref_logprobs: Optional[List[List[float]]] = None
    overrides: Optional[List[dict]] = None
    group_overrides: Optional[dict] = None
    images: Optional[Any] = None


class Status(BaseModel):
    """
    basemodel for status information of the current server
    """

    current_step: int
    queue_size: int


class Info(BaseModel):
    """
    basemodel for useful information
    """

    batch_size: int = -1


@app.post("/register")
async def register(registration: Registration):
    try:
        isinstance(app.state.queue, list)
    except AttributeError:
        app.state.queue = []
    app.state.group = registration.wandb_group
    app.state.project = registration.wandb_project
    app.state.batchsize = int(registration.batch_size)
    app.state.max_token_len = int(registration.max_token_len)
    app.state.status_dict = {"step": registration.starting_step}
    app.state.checkpoint_dir = registration.checkpoint_dir
    app.state.save_checkpoint_interval = registration.save_checkpoint_interval
    app.state.num_steps = registration.num_steps
    app.state.curr_batch = []
    app.state.started = False
    app.state.envs = []
    try:
        app.state.requesters.append(uuid.uuid4().int)
    except AttributeError:
        # If requesters doesn't exist, create it
        app.state.requesters = [uuid.uuid4().int]
    return {"uuid": app.state.requesters[-1]}


@app.post("/register-env")
async def register_env_url(register_env: RegisterEnv):
    try:
        isinstance(app.state.envs, list)
    except AttributeError:
        app.state.envs = []
    checkpoint_dir = ""
    try:
        checkpoint_dir = app.state.checkpoint_dir
    except AttributeError:
        pass
    real_name = (
        f"{register_env.desired_name}_"
        f"{len([x for x in app.state.envs if x['desired_name'] == register_env.desired_name])}"
    )
    registered_id = len(app.state.envs)
    app.state.envs.append(
        {
            "max_context_len": register_env.max_token_length,
            "weight": register_env.weight if register_env.weight is not None else 1.0,
            "desired_name": register_env.desired_name,
            "real_name": real_name,
            "registered_id": registered_id,
            "last_update": time.time(),
            "connected": True,
        }
    )
    return {
        "status": "success",
        "env_id": registered_id,
        "wandb_name": real_name,
        "checkpoint_dir": checkpoint_dir,
        "starting_step": app.state.status_dict["step"],
        "checkpoint_interval": app.state.save_checkpoint_interval,
        "num_steps": app.state.num_steps,
    }


@app.post("/disconnect-env")
async def disconnect_env(disconnect_env: EnvIdentifier):
    try:
        app.state.envs[disconnect_env.env_id]["connected"] = False
        return {"status": "success"}
    except (AttributeError, IndexError) as e:
        return {"status": "failure", "error": str(e)}


@app.get("/wandb_info")
async def wandb_info():
    try:
        return {"group": app.state.group, "project": app.state.project}
    except AttributeError:
        return {"group": None, "project": None}


@app.get("/info")
async def info():
    try:
        return {
            "batch_size": app.state.batchsize,
            "max_token_len": app.state.max_token_len,
        }
    except AttributeError:
        return {"batch_size": -1, "max_token_len": -1}


@app.get("/batch")
async def get_batch():
    if not app.state.started:
        app.state.started = True
    if len(app.state.curr_batch) > 0:
        return {"batch": app.state.curr_batch.pop()}
    else:
        new_batches = []
        batch, app.state.queue = grab_exact_from_heterogeneous_queue(
            app.state.queue, app.state.batchsize
        )
        while batch is not None:
            new_batches.append(batch)
            batch, app.state.queue = grab_exact_from_heterogeneous_queue(
                app.state.queue, app.state.batchsize
            )
        steps_to_take = len(new_batches)
        if steps_to_take == 0:
            return {"batch": None}
        app.state.status_dict["step"] += steps_to_take
        # chunk it
        for batch in new_batches:
            app.state.curr_batch.append(batch)
        curr_batch = app.state.curr_batch.pop()
        # check length before sending
        print(f"Sending batch of length {sum(len(x['tokens']) for x in curr_batch)}")
        return {"batch": curr_batch}


@app.get("/latest_example")
async def get_latest_example():
    try:
        return app.state.latest
    except AttributeError:
        return {
            "tokens": [],
            "masks": [],
            "scores": [],
            "ref_logprobs": [],
            "images": [],
        }


@app.post("/scored_data")
async def scored_data(scored_data: ScoredData):
    app.state.queue.append(
        {
            "tokens": scored_data.tokens,
            "masks": scored_data.masks,
            "scores": scored_data.scores,
            "ref_logprobs": scored_data.ref_logprobs,
            "overrides": scored_data.overrides,
            "group_overrides": scored_data.group_overrides,
            "images": scored_data.images,
        }
    )
    app.state.latest = app.state.queue[-1]
    return {"status": "received"}


@app.post("/scored_data_list")
async def scored_data_list(scored_data_list: List[ScoredData]):
"""Handle a list of ScoredData objects for step-based learning"""
    for scored_data in scored_data_list:
        app.state.queue.append(
            {
                "tokens": scored_data.tokens,
                "masks": scored_data.masks,
                "scores": scored_data.scores,
                "ref_logprobs": scored_data.ref_logprobs,
                "overrides": scored_data.overrides,
                "group_overrides": scored_data.group_overrides,
                "images": scored_data.images,
            }
        )
if scored_data_list:
app.state.latest = app.state.queue[-1]
return {"status": "received", "groups_processed": len(scored_data_list)}
@app.get("/status")
async def get_status():
try:
return {
"current_step": app.state.status_dict["step"],
"queue_size": len(app.state.queue),
}
except AttributeError:
return {"current_step": 0, "queue_size": 0}
@app.get("/status-env")
async def get_status_env(env: EnvIdentifier):
total = sum(
[
x["max_context_len"] * max(0.0, x["weight"])
for x in app.state.envs
if x["connected"]
]
)
env_weight = (
app.state.envs[env.env_id]["max_context_len"]
* app.state.envs[env.env_id]["weight"]
/ total
)
env_weight = max(
0.01, env_weight
) # Minimum weight of 0.01 :) TODO: try to figure out a better way to do this
try:
ret_dict = {
"current_step": app.state.status_dict["step"],
"queue_size": len(app.state.queue),
}
except AttributeError:
ret_dict = {"current_step": 0, "queue_size": 0}
ret_dict["env_weight"] = env_weight
return ret_dict
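The `/status-env` handler above allocates each environment a share of generation throughput proportional to `max_context_len * weight` across connected environments, floored at 0.01. A standalone sketch of the same arithmetic (the `compute_env_weight` helper is hypothetical):

```python
def compute_env_weight(envs, env_id):
    # total capacity across connected envs, ignoring negative weights
    total = sum(
        e["max_context_len"] * max(0.0, e["weight"]) for e in envs if e["connected"]
    )
    w = envs[env_id]["max_context_len"] * envs[env_id]["weight"] / total
    return max(0.01, w)  # same 0.01 floor as the endpoint

envs = [
    {"max_context_len": 2048, "weight": 1.0, "connected": True},
    {"max_context_len": 2048, "weight": 3.0, "connected": True},
]
print(compute_env_weight(envs, 0))  # → 0.25
print(compute_env_weight(envs, 1))  # → 0.75
```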
@app.get("/reset_data")
async def reset_data():
try:
del app.state.queue
app.state.group = None
app.state.project = None
app.state.batchsize = -1
app.state.num_steps = -1
app.state.status_dict = {"step": 0}
app.state.curr_batch = []
app.state.started = False
app.state.requesters = []
app.state.envs = []
    except AttributeError:
        pass
return PlainTextResponse("Reset successful", status_code=status.HTTP_200_OK)

```mermaid
sequenceDiagram
participant R0 as Trainer Rank 0
participant R1N as Trainer Rank 1..N-1
participant API as AtroposLib API
R0->>API: POST /register (send Registration data)
activate API
API-->>R0: Respond with {'uuid': trainer_uuid}
deactivate API
Note over R0, R1N: Initialization complete. Trainer begins requesting data
loop Training Steps
%% --- Phase 2: Rank 0 fetches batch, others wait/poll ---
par Fetch vs Poll
loop While Batch is Null:
R0->>API: GET /batch
activate API
Note over API: Checks queue, potentially increments step counter if batch is formed.
alt Batch Available
API-->>R0: {'batch': [data_item_1, ...]}
Note over R0: Received batch for step S+1. Breaking loop.
else No Batch Available
API-->>R0: {'batch': null}
Note over R0: No batch ready yet. Will retry.
end
deactivate API
end
and
Note over R1N: Poll status until step increments from S.
loop While Server Step is S
R1N->>API: GET /status
activate API
API-->>R1N: {'current_step': S_new, 'queue_size': Q_new}
deactivate API
Note over R1N: Checking if S_new > S... (Current S_new = S_new)
%% In implementation, add delay here if S_new == S to avoid busy-wait
end
Note over R1N: Detected step incremented (S_new > S). Ready for broadcast.
end
%% --- Phase 3: Handle result ---
Note over R0: Broadcasts received batch data to Ranks 1..N-1 (External Mechanism)
Note over R1N: Receives broadcasted data from Rank 0.
Note over R0, R1N: All ranks now have the same batch for step S+1.
%% --- Phase 4: Perform Training Step ---
par Perform Training
R0->>R0: Perform training step with batch data
and
R1N->>R1N: Perform training step with batch data
end
Note over R0, R1N: Training step S+1 complete.
end # End Training Steps Loop
```
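The two polling loops in the diagram can be sketched as plain functions. This is an illustrative sketch, not trainer code: the fetch callables are injected (in a real trainer they would wrap `GET /batch` and `GET /status`), which also keeps the loops testable:

```python
import time

def wait_for_batch(fetch_batch, poll_interval=1.0):
    """Rank 0: poll until the server returns a non-null batch."""
    while True:
        batch = fetch_batch()["batch"]
        if batch is not None:
            return batch
        time.sleep(poll_interval)

def wait_for_step(fetch_status, last_step, poll_interval=1.0):
    """Ranks 1..N-1: poll /status until the step counter advances,
    then expect the batch via the trainer's own broadcast."""
    while True:
        step = fetch_status()["current_step"]
        if step > last_step:
            return step
        time.sleep(poll_interval)

# Simulated server responses: one empty poll, then a batch; step 4 then 5.
batches = iter([{"batch": None}, {"batch": [{"tokens": [[1, 2]]}]}])
statuses = iter([{"current_step": 4}, {"current_step": 5}])
print(wait_for_batch(lambda: next(batches), poll_interval=0))
print(wait_for_step(lambda: next(statuses), last_step=4, poll_interval=0))  # → 5
```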

atroposlib/api/utils.py
from typing import Dict, List, Optional, Tuple
def grab_exact_from_heterogeneous_queue(
queue: List[Dict[str, List]], batch_size: int
) -> Tuple[Optional[List], List]:
"""
    Grabs a batch of exactly batch_size sequences from a queue of variable-sized groups,
    e.g. queue = [{"tokens": [[1, 2, 3], [4, 5, 6, 7, 8]]}, {"tokens": [[9, 10]]}],
    without going over batch_size. Note that batch_size counts sequences (the entries in
    each group's "tokens" list), not tokens.
    Because all group sizes are powers of 2 that divide the batch size, we can simplify a
    bit by packing smaller groups together until each pack equals the maximum group size.
    Note that we cannot drop items from groups, so we must grab an entire group if we grab it.
    There may be a more efficient clearing mechanism that combines smaller groups
    heterogeneously, but forcing them into power-of-two packs is a simple way to ensure we
    can grab a batch of the correct size.
    :param queue: list of groups, each a dict whose "tokens" entry is a list of sequences
    :param batch_size: exact number of sequences the returned batch must contain
    :return: (batch, new_queue), or (None, queue) if no exact batch can be formed
"""
# check if we can even potentially grab a batch
if sum(len(item["tokens"]) for item in queue) < batch_size:
return None, queue
# Get max batch size
max_group_size = max(len(group["tokens"]) for group in queue)
group_sizes = set(len(group["tokens"]) for group in queue)
group_batching_storage = {i: [] for i in group_sizes}
# pack the groups into [max_group_size // group_size] packs
potential_batch = []
for i, item in enumerate(queue):
key = len(item["tokens"])
group_batching_storage[key].append({"group": item, "indx": i})
if len(group_batching_storage[key]) * key == max_group_size:
potential_batch.extend(group_batching_storage[key])
group_batching_storage[key] = []
if (
sum(len(grouped_items["group"]["tokens"]) for grouped_items in potential_batch)
< batch_size
):
return None, queue
# we have a batch
batch = []
indxes_to_remove_from_queue = []
for item in potential_batch:
group = item["group"]
indx = item["indx"]
batch.append(group)
indxes_to_remove_from_queue.append(indx)
if sum(len(item["tokens"]) for item in batch) == batch_size:
break
if sum(len(item["tokens"]) for item in batch) != batch_size:
return None, queue
# remove the items from the queue
new_queue = [
item for i, item in enumerate(queue) if i not in indxes_to_remove_from_queue
]
return batch, new_queue
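To make the contract concrete, here is a condensed sketch of the same packing algorithm with a worked example (labeled `grab_exact_sketch` to distinguish it from the implementation above; like the original, it assumes power-of-two group sizes):

```python
from typing import Dict, List, Optional, Tuple

def grab_exact_sketch(
    queue: List[Dict[str, list]], batch_size: int
) -> Tuple[Optional[List], List]:
    """Condensed sketch of grab_exact_from_heterogeneous_queue for illustration."""
    if sum(len(g["tokens"]) for g in queue) < batch_size:
        return None, queue
    max_group = max(len(g["tokens"]) for g in queue)
    pending = {len(g["tokens"]): [] for g in queue}
    packed = []  # (index, group) pairs whose sizes combine to max_group
    for i, g in enumerate(queue):
        k = len(g["tokens"])
        pending[k].append((i, g))
        if len(pending[k]) * k == max_group:
            packed.extend(pending[k])
            pending[k] = []
    batch, used, total = [], set(), 0
    for i, g in packed:
        batch.append(g)
        used.add(i)
        total += len(g["tokens"])
        if total == batch_size:
            return batch, [g for j, g in enumerate(queue) if j not in used]
    return None, queue

# Two 2-sequence groups fill a batch of 4; the lone 1-sequence group stays
# queued until partners arrive to complete a full pack.
queue = [
    {"tokens": [[1, 2], [3]]},
    {"tokens": [[4], [5, 6]]},
    {"tokens": [[7, 8]]},
]
batch, rest = grab_exact_sketch(queue, 4)
print(len(batch), len(rest))  # → 2 1
```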

atroposlib/cli/dpo.py
import argparse
import asyncio
import os
import random
import aiohttp
import jsonlines
from tqdm.asyncio import tqdm # Import tqdm for async
from transformers import AutoTokenizer
def find_common_prefix(strings):
"""
Finds the longest common prefix among a list of strings.
Args:
strings: A list of strings.
Returns:
The longest common prefix string, or an empty string if the list is empty
or no common prefix exists.
"""
if not strings:
return ""
prefix = strings[0]
for s in strings[1:]:
while not s.startswith(prefix):
prefix = prefix[:-1]
if not prefix:
return ""
return prefix
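For example, on a group whose members share the same rendered prompt, the helper recovers the prompt and leaves only the completions differing (function reproduced inline so the snippet runs standalone; the chat-template markers are made up):

```python
def find_common_prefix(strings):
    """Longest common prefix, as defined above."""
    if not strings:
        return ""
    prefix = strings[0]
    for s in strings[1:]:
        while not s.startswith(prefix):
            prefix = prefix[:-1]
            if not prefix:
                return ""
    return prefix

chats = [
    "<|user|>What is 2+2?<|assistant|>4",
    "<|user|>What is 2+2?<|assistant|>Four",
]
prefix = find_common_prefix(chats)
print(prefix)  # → <|user|>What is 2+2?<|assistant|>
completions = [c.split(prefix)[1] for c in chats]
print(completions)  # → ['4', 'Four']
```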
async def register_to_api(group_size, max_token_len, api_url, num_steps):
"""
Registers this data grabber instance with the Atropos API.
This involves resetting any previous data on the server and then sending
configuration parameters for the current session.
Args:
group_size: The number of sequences processed per group by the API.
max_token_len: The maximum token length for sequences.
api_url: The base URL of the Atropos API server.
num_steps: The number of steps to run the API for.
"""
async with aiohttp.ClientSession() as session:
# Reset data on the API server before registering
async with session.get(f"{api_url}/reset_data") as response:
print(await response.text())
# Register this instance with its configuration
async with session.post(
f"{api_url}/register",
json={
"wandb_group": "test",
"wandb_project": "test",
"batch_size": group_size * 8,
"max_token_len": max_token_len,
"checkpoint_dir": "checkpoints",
"save_checkpoint_interval": 10,
"starting_step": 0,
"num_steps": num_steps * 2, # For a bit of a buffer just in case
},
) as response:
print("output of register is")
print(await response.text())
async def check_for_batch(api_url):
"""
Continuously polls the Atropos API until a batch of data is available.
Args:
api_url: The base URL of the Atropos API server.
Returns:
The batch data received from the API.
"""
while True:
async with aiohttp.ClientSession() as session:
async with session.get(f"{api_url}/batch") as response:
data = await response.json()
if data["batch"] is not None:
return data["batch"]
await asyncio.sleep(1) # Wait before polling again
def grab_group_data(
tok,
datagroup,
save_messages,
save_n_pairs_per_group,
allow_negative_scores=False,
minimum_score_diff_max_min=0.0,
):
"""
Processes a single group of data received from the API.
This function sorts the sequences within the group by score, filters them
based on scoring criteria, and formats them for saving.
Args:
tok: The Hugging Face tokenizer instance.
datagroup: A dictionary representing a group of sequences and their scores.
save_messages: Boolean indicating whether to save raw message structures
or decoded text completions.
        save_n_pairs_per_group: The maximum number of preference pairs to save from this group.
allow_negative_scores: Boolean indicating whether to allow sequences with
negative scores.
minimum_score_diff_max_min: The minimum score difference required to save a pair.
Returns:
A list of processed and filtered sequences from the group, ready to be
written to the output file.
"""
if save_messages:
chats = datagroup["messages"]
else:
chats = [tok.decode(chat) for chat in datagroup["tokens"]]
# find common prefix
prefix = find_common_prefix(chats)
chats = [(prefix, chat.split(prefix)[1]) for chat in chats]
# sort chats by scores
scores = datagroup["scores"]
sorted_chats = [
(
{"prefix": x[0], "pos": x[1], "score": score}
if not save_messages
else {"pos": x, "score": score}
)
for score, x in sorted(
zip(scores, chats), key=lambda pair: pair[0], reverse=True
)
]
neg_sorted_chats = [
(
{"prefix": x[0], "completion": x[1], "score": score}
if not save_messages
else {"messages": x, "score": score}
)
for score, x in sorted(
zip(scores, chats), key=lambda pair: pair[0], reverse=False
)
]
neg_sorted_chats = neg_sorted_chats[:save_n_pairs_per_group]
if not allow_negative_scores:
sorted_chats = [x for x in sorted_chats if x["score"] > 0]
total_pairs = []
for i in range(min(save_n_pairs_per_group, len(sorted_chats))):
neg_candidates = [
x
for x in neg_sorted_chats
if x["score"] < sorted_chats[i]["score"] - minimum_score_diff_max_min
]
if len(neg_candidates) > 0:
if save_n_pairs_per_group > 0:
neg_candidate = random.choice(neg_candidates)
else:
neg_candidate = neg_sorted_chats[0] # worst negative candidate
# remove from neg_sorted_chats
neg_sorted_chats.remove(neg_candidate)
sorted_chats[i]["neg"] = (
neg_candidate["completion"]
if "completion" in neg_candidate
else neg_candidate["messages"]
)
total_pairs.append(sorted_chats[i])
return total_pairs
async def dpo_data_grabber(
filepath,
api_url,
group_size,
max_token_len,
tokenizer,
save_messages,
save_n_pairs_per_group,
num_seqs_to_save,
allow_negative_scores,
minimum_score_diff_max_min,
append_to_previous,
):
"""
Main asynchronous function to grab DPO data from the Atropos API.
It registers with the API, continuously fetches batches of data, processes
each batch, and writes the selected sequences to a JSONL file until the
desired number of sequences is saved.
Args:
filepath: Path to the output JSONL file.
api_url: Base URL of the Atropos API server.
group_size: Number of sequences processed per group by the API.
max_token_len: Maximum token length for sequences.
tokenizer: Hugging Face tokenizer model ID.
save_messages: Whether to save raw messages or decoded text.
save_n_pairs_per_group: Max sequences to save per group.
num_seqs_to_save: Total number of sequences to save.
allow_negative_scores: Whether to allow negative scores.
        minimum_score_diff_max_min: Min score gap required between a chosen and rejected completion.
append_to_previous: Whether to append to an existing file or overwrite.
"""
tok = AutoTokenizer.from_pretrained(tokenizer)
total_count = 0
    async def grab_batch(jsonl_writer: jsonlines.Writer):
        """Fetch and process one batch of data, returning the number of items saved."""
data = await check_for_batch(api_url)
count = 0
for group in data:
for item in grab_group_data(
tok,
group,
save_messages,
save_n_pairs_per_group,
allow_negative_scores,
minimum_score_diff_max_min,
):
jsonl_writer.write(item)
count += 1
return count
    await register_to_api(group_size, max_token_len, api_url, num_steps=num_seqs_to_save)
if os.path.exists(filepath) and not append_to_previous:
raise ValueError("File already exists and append_to_previous is False.")
with open(filepath, "w" if not append_to_previous else "a") as f:
jsonl_writer = jsonlines.Writer(f)
with tqdm(total=num_seqs_to_save, desc="Grabbing DPO data", unit="seq") as pbar:
            while total_count < num_seqs_to_save:
                batch_count = await grab_batch(jsonl_writer)
                # update the bar before counting so it never overshoots the total
                pbar.update(min(batch_count, num_seqs_to_save - total_count))
                total_count += batch_count
def main():
    parser = argparse.ArgumentParser(
        description="Grab DPO data from an Atropos API instance."
    )
parser.add_argument(
"filepath",
type=str,
default="sft_data.jsonl",
        help="Path to the output JSONL file for DPO data.",
)
parser.add_argument(
"--api-url",
type=str,
default="http://localhost:8000",
help="Base URL for the Atropos API server.",
)
parser.add_argument(
"--group-size",
type=int,
default=2,
help="Number of sequences processed per group by the API.",
)
parser.add_argument(
"--max-token-len",
type=int,
default=2048,
help="Maximum token length for sequences.",
)
parser.add_argument(
"--tokenizer",
type=str,
default="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
help="Hugging Face tokenizer model ID (used if --save-messages is not set).",
)
parser.add_argument(
"--save-messages",
action="store_true",
help="Save raw message structures instead of decoded text completions, if your environment supports it.",
)
parser.add_argument(
"--save-n-pairs-per-group",
type=int,
default=3,
help="Maximum number of paired sequences to save from each group.",
)
parser.add_argument(
"--num-seqs-to-save",
type=int,
default=100,
help="Total number of sequences to save before stopping.",
)
parser.add_argument(
"--allow-negative-scores",
action="store_true",
help="Allow sequences with negative scores to be saved.",
)
parser.add_argument(
"--minimum-score-diff-max-min",
type=float,
default=0.5,
        help="Minimum score gap required between the chosen and rejected completion in a saved pair.",
)
parser.add_argument(
"--append-to-previous",
action="store_true",
help="Append to the previous file instead of overwriting it.",
)
args = parser.parse_args()
asyncio.run(
dpo_data_grabber(
args.filepath,
args.api_url,
args.group_size,
args.max_token_len,
args.tokenizer,
args.save_messages,
args.save_n_pairs_per_group,
args.num_seqs_to_save,
args.allow_negative_scores,
args.minimum_score_diff_max_min,
args.append_to_previous,
)
)
if __name__ == "__main__":
main()

import argparse
import time
import requests
import wandb
def update_wandb(health_statuses):
wandb.log(health_statuses)
def run(api_addr, tp, node_num):
print(f"Starting up with {api_addr}, {tp}, {node_num}", flush=True)
while True:
try:
data = requests.get(f"{api_addr}/wandb_info").json()
wandb_group = data["group"]
wandb_project = data["project"]
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
wandb_project = None
wandb_group = None
print("Waiting for init...")
if wandb_project is None:
time.sleep(1)
else:
wandb.init(
project=wandb_project, group=wandb_group, name=f"inf_node_{node_num}"
)
break
curr_step = 0
    health_statuses = {
        f"server/server_health_{node_num}_{i}": 0.0 for i in range(8 // tp)
    }
while True:
data = requests.get(f"{api_addr}/status").json()
step = data["current_step"]
if step > curr_step:
wandb.log(health_statuses, step=step)
curr_step = step
time.sleep(60)
# Check on each server
for i in range(8 // tp):
try:
                health_status = requests.get(
                    f"http://localhost:{9000 + i}/health_generate"
                ).status_code
                health_statuses[f"server/server_health_{node_num}_{i}"] = (
                    1 if health_status == 200 else 0
                )
            except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
                health_statuses[f"server/server_health_{node_num}_{i}"] = 0
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--api_addr", type=str, required=True)
parser.add_argument("--tp", type=int, required=True)
parser.add_argument("--node_num", type=int, required=True)
args = parser.parse_args()
run(args.api_addr, args.tp, args.node_num)
if __name__ == "__main__":
main()

import uvicorn
def main():
uvicorn.run("atroposlib.api:app", host="0.0.0.0", port=8000, reload=True)
if __name__ == "__main__":
main()

atroposlib/cli/sft.py
import argparse
import asyncio
import os
import aiohttp
import jsonlines
from tqdm.asyncio import tqdm # Import tqdm for async
from transformers import AutoTokenizer
def find_common_prefix(strings):
"""
Finds the longest common prefix among a list of strings.
Args:
strings: A list of strings.
Returns:
The longest common prefix string, or an empty string if the list is empty
or no common prefix exists.
"""
if not strings:
return ""
prefix = strings[0]
for s in strings[1:]:
while not s.startswith(prefix):
prefix = prefix[:-1]
if not prefix:
return ""
return prefix
async def register_to_api(group_size, max_token_len, api_url, num_steps):
"""
Registers this data grabber instance with the Atropos API.
This involves resetting any previous data on the server and then sending
configuration parameters for the current session.
Args:
group_size: The number of sequences processed per group by the API.
max_token_len: The maximum token length for sequences.
api_url: The base URL of the Atropos API server.
num_steps: The number of steps to run the API for.
"""
async with aiohttp.ClientSession() as session:
# Reset data on the API server before registering
async with session.get(f"{api_url}/reset_data") as response:
print(await response.text())
# Register this instance with its configuration
async with session.post(
f"{api_url}/register",
json={
"wandb_group": "test",
"wandb_project": "test",
"batch_size": group_size * 8,
"max_token_len": max_token_len,
"checkpoint_dir": "checkpoints",
"save_checkpoint_interval": 10,
"starting_step": 0,
"num_steps": num_steps * 2, # For a bit of a buffer just in case
},
) as response:
print("output of register is")
print(await response.text())
async def check_for_batch(api_url):
"""
Continuously polls the Atropos API until a batch of data is available.
Args:
api_url: The base URL of the Atropos API server.
Returns:
The batch data received from the API.
"""
while True:
async with aiohttp.ClientSession() as session:
async with session.get(f"{api_url}/batch") as response:
data = await response.json()
if data["batch"] is not None:
return data["batch"]
await asyncio.sleep(1) # Wait before polling again
def grab_group_data(
tok,
datagroup,
save_messages,
save_top_n_per_group,
allow_negative_scores=False,
minimum_score_diff_max_min=0.0,
):
"""
Processes a single group of data received from the API.
This function sorts the sequences within the group by score, filters them
based on scoring criteria, and formats them for saving.
Args:
tok: The Hugging Face tokenizer instance.
datagroup: A dictionary representing a group of sequences and their scores.
save_messages: Boolean indicating whether to save raw message structures
or decoded text completions.
save_top_n_per_group: The maximum number of sequences to save from this group.
allow_negative_scores: Boolean indicating whether to allow sequences with
negative scores.
minimum_score_diff_max_min: The minimum score difference from the group's
minimum score required to save a sequence.
Returns:
A list of processed and filtered sequences from the group, ready to be
written to the output file.
"""
if save_messages:
# Use raw message structures if specified
chats = datagroup["messages"]
else:
# Decode tokens into text and find common prefix/completion pairs
chats = [tok.decode(chat) for chat in datagroup["tokens"]]
prefix = find_common_prefix(chats)
# Split each chat into (prefix, completion)
chats = [(prefix, chat.split(prefix)[1]) for chat in chats]
scores = datagroup["scores"]
# Sort chats by score in descending order
sorted_chats = [
(
{"prefix": x[0], "completion": x[1], "score": score}
if not save_messages
else {"messages": x, "score": score}
)
for score, x in sorted(
zip(scores, chats), key=lambda pair: pair[0], reverse=True
)
]
# Apply filtering based on score criteria
if not allow_negative_scores:
sorted_chats = [x for x in sorted_chats if x["score"] > 0]
if minimum_score_diff_max_min > 0:
# Ensure the score is sufficiently higher than the minimum score in the group
min_score = min(scores) if scores else 0 # Handle empty scores list
sorted_chats = [
x
for x in sorted_chats
if x["score"] - min_score > minimum_score_diff_max_min
]
# Return only the top N sequences
return sorted_chats[:save_top_n_per_group]
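A worked example of the filtering rules above, using made-up scores: with negative scores disallowed and a minimum gap of 1.0 over the group minimum, a group scored [2.0, 1.0, -0.5] keeps its top two completions:

```python
scores = [2.0, 1.0, -0.5]
chats = ["a", "b", "c"]
allow_negative_scores = False
minimum_score_diff_max_min = 1.0

# sort by score, descending, as grab_group_data does
ranked = [
    {"completion": c, "score": s}
    for s, c in sorted(zip(scores, chats), key=lambda p: p[0], reverse=True)
]
if not allow_negative_scores:
    ranked = [x for x in ranked if x["score"] > 0]
if minimum_score_diff_max_min > 0:
    min_score = min(scores)
    ranked = [x for x in ranked if x["score"] - min_score > minimum_score_diff_max_min]

print([x["completion"] for x in ranked])  # → ['a', 'b']
```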
async def sft_data_grabber(
filepath,
api_url,
group_size,
max_token_len,
tokenizer,
save_messages,
save_top_n_per_group,
num_seqs_to_save,
allow_negative_scores,
minimum_score_diff_max_min,
append_to_previous,
):
"""
Main asynchronous function to grab SFT data from the Atropos API.
It registers with the API, continuously fetches batches of data, processes
each batch, and writes the selected sequences to a JSONL file until the
desired number of sequences is saved.
Args:
filepath: Path to the output JSONL file.
api_url: Base URL of the Atropos API server.
group_size: Number of sequences processed per group by the API.
max_token_len: Maximum token length for sequences.
tokenizer: Hugging Face tokenizer model ID.
save_messages: Whether to save raw messages or decoded text.
save_top_n_per_group: Max sequences to save per group.
num_seqs_to_save: Total number of sequences to save.
allow_negative_scores: Whether to allow negative scores.
minimum_score_diff_max_min: Min score difference from group minimum.
append_to_previous: Whether to append to an existing file or overwrite.
"""
tok = AutoTokenizer.from_pretrained(tokenizer)
total_count = 0
async def grab_batch(jsonl_writer: jsonlines.Writer):
"""Fetches and processes one batch of data, returning the count."""
data = await check_for_batch(api_url)
count = 0
for group in data:
for item in grab_group_data(
tok,
group,
save_messages,
save_top_n_per_group,
allow_negative_scores,
minimum_score_diff_max_min,
):
jsonl_writer.write(item)
count += 1
return count
# Register with the API first
    await register_to_api(group_size, max_token_len, api_url, num_steps=num_seqs_to_save)
# Check for file existence before opening
if os.path.exists(filepath) and not append_to_previous:
raise ValueError(
f"File '{filepath}' already exists and --append-to-previous is False."
)
# Open the file in write or append mode
file_mode = "a" if append_to_previous and os.path.exists(filepath) else "w"
with open(filepath, file_mode) as f:
jsonl_writer = jsonlines.Writer(f)
# Use tqdm for progress bar
with tqdm(total=num_seqs_to_save, desc="Grabbing SFT data", unit="seq") as pbar:
            while total_count < num_seqs_to_save:
                batch_count = await grab_batch(jsonl_writer)
                # update the bar before counting so it never overshoots the total
                pbar.update(min(batch_count, num_seqs_to_save - total_count))
                total_count += batch_count
def main():
"""Parses command-line arguments and runs the SFT data grabber."""
parser = argparse.ArgumentParser(
description="Grab SFT data from an Atropos API instance."
)
parser.add_argument(
"filepath",
type=str,
default="sft_data.jsonl",
help="Path to the output JSONL file for SFT data.",
)
parser.add_argument(
"--api-url",
type=str,
default="http://localhost:8000",
help="Base URL for the Atropos API server.",
)
parser.add_argument(
"--group-size",
type=int,
default=2,
help="Number of sequences processed per group by the API.",
)
parser.add_argument(
"--max-token-len",
type=int,
default=2048,
help="Maximum token length for sequences.",
)
parser.add_argument(
"--tokenizer",
type=str,
default="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
help="Hugging Face tokenizer model ID (used if --save-messages is not set).",
)
parser.add_argument(
"--save-messages",
action="store_true",
help="Save raw message structures instead of decoded text completions, if your environment supports it.",
)
parser.add_argument(
"--save-top-n-per-group",
type=int,
default=3,
help="Maximum number of highest-scoring sequences to save from each group.",
)
parser.add_argument(
"--num-seqs-to-save",
type=int,
default=100,
help="Total number of sequences to save before stopping.",
)
parser.add_argument(
"--allow-negative-scores",
action="store_true",
help="Allow sequences with negative scores to be saved.",
)
parser.add_argument(
"--minimum-score-diff-max-min",
type=float,
default=0.0,
help="Minimum score difference from the group minimum required to save a sequence.",
)
parser.add_argument(
"--append-to-previous",
action="store_true",
help="Append to the previous file instead of overwriting it.",
)
args = parser.parse_args()
# Run the main async function
asyncio.run(
sft_data_grabber(
args.filepath,
args.api_url,
args.group_size,
args.max_token_len,
args.tokenizer,
args.save_messages,
args.save_top_n_per_group,
args.num_seqs_to_save,
args.allow_negative_scores,
args.minimum_score_diff_max_min,
args.append_to_previous,
)
)
if __name__ == "__main__":
main()

atroposlib/cli/view_run.py
import argparse
import asyncio
import aiohttp
import gradio as gr
from transformers import AutoTokenizer
def find_common_prefix(strings):
if not strings:
return ""
prefix = strings[0]
for s in strings[1:]:
while not s.startswith(prefix):
prefix = prefix[:-1]
if not prefix:
return ""
return prefix
async def register_to_api(group_size, max_token_len):
async with aiohttp.ClientSession() as session:
async with session.get("http://localhost:8000/reset_data") as response:
print(await response.text())
print(group_size)
async with session.post(
"http://localhost:8000/register",
json={
"wandb_group": "test",
"wandb_project": "test",
"batch_size": group_size
* 8, # * 8 just in case you want to just sample from a large group
"max_token_len": max_token_len,
"checkpoint_dir": "checkpoints",
"save_checkpoint_interval": 10,
"starting_step": 0,
"num_steps": 69,
},
) as response:
print("output of register is")
print(await response.text())
async def check_for_batch():
while True:
async with aiohttp.ClientSession() as session:
async with session.get("http://localhost:8000/batch") as response:
data = await response.json()
print(data)
if data["batch"] is not None:
return data["batch"]
await asyncio.sleep(1)
async def build_interface(group_size, max_token_len, tokenizer, port):
async def grab_batch():
tok = AutoTokenizer.from_pretrained(tokenizer)
data = await check_for_batch()
print(data)
chats = [tok.decode(chat) for chat in data[0]["tokens"]]
# find common prefix
prefix = find_common_prefix(chats)
return (
(prefix,)
+ tuple([chat.split(prefix)[1] for chat in chats[:group_size]])
+ tuple(data[0]["scores"][:group_size])
)
with gr.Blocks() as demo:
prefix_blk = gr.Textbox(label="Prefix")
with gr.Row():
score_blks = [gr.Textbox(label=f"Score_{i+1}") for i in range(group_size)]
with gr.Row():
outputs_blks = [
gr.Textbox(label=f"Output_{i+1}") for i in range(group_size)
]
with gr.Row():
grab_next = gr.Button(value="Grab Next Batch")
grab_next.click(
fn=grab_batch,
outputs=[prefix_blk] + outputs_blks + score_blks,
api_name="get_batch",
)
await register_to_api(group_size, max_token_len)
demo.launch(server_port=port, share=True)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=9001)
parser.add_argument("--group-size", type=int, default=2)
parser.add_argument("--max-token-len", type=int, default=2048)
parser.add_argument(
"--tokenizer", type=str, default="NousResearch/DeepHermes-3-Llama-3-8B-Preview"
)
args = parser.parse_args()
asyncio.run(
build_interface(args.group_size, args.max_token_len, args.tokenizer, args.port)
)
if __name__ == "__main__":
main()

atroposlib/envs/README.md
# Base Environment (`BaseEnv`)
The `BaseEnv` class (located in `atroposlib/envs/base.py`) provides a foundation for creating custom reinforcement learning environments that interact with Atropos. When creating your own environment, you will typically subclass `BaseEnv` and implement several key methods.
## Core Methods to Implement
These methods **must** be implemented in your subclass:
* **`async def setup(self)`**: This method is called once at the beginning of the environment's lifecycle (`env_manager`). Use it for any initial setup required for your specific environment, such as loading datasets, initializing models, or connecting to external resources.
* **`async def get_next_item(self) -> Item`**: This method is responsible for generating or retrieving the next piece of data (prompt, state, etc.) that will be used to start a new trajectory collection. If no more items are available or should be generated, it can return `None` to signal the worker to pause.
* **`async def collect_trajectory(self, item: Item) -> Tuple[Any | None, List[Item]]`**: This method defines the logic for a *single* logical trajectory collection step based on the input `item`.
  * **How it relates to multiple generations**: The `BaseEnv` uses `collect_trajectories` to run this method multiple times in parallel (controlled by `group_size`) to gather a batch of trajectories.
  * **Your implementation**: You can implement this method to generate *one* response/trajectory per call.
  * **Return value**: It returns a tuple containing:
    1. The collected data for this step (one trajectory). This data can be processed further in `postprocess_histories` if you require additional filtering right before sending to the API.
    2. A list of new `Item` objects to be added to the backlog for future processing (e.g., follow-up prompts).
* **`async def evaluate(self, *args, **kwargs)`**: This method is called periodically (controlled by `steps_per_eval` in the config) to perform evaluation runs. You define the evaluation logic here. The base class provides an example using `self.eval_workers` for parallel evaluation tasks, but you can implement any evaluation procedure suitable for your environment.
## Optional Methods to Override
These methods have default implementations or are optional based on your needs:
* **`async def collect_trajectories(self, item: Item) -> Tuple[Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None]], List[Item]]`**: The default implementation of this method runs `collect_trajectory` multiple times in parallel (controlled by `group_size`). You can override this *instead* of `collect_trajectory` if you have a more efficient way to generate the entire group of responses/trajectories at once based on the input `item`. It should return the collected group data and a list of backlog items.
* **`async def postprocess_histories(self, trajectories: Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]`**: This method is called after `collect_trajectories` and before the data is sent to the training server. It receives the collected data from the parallel runs (or from your custom `collect_trajectories` implementation). Use this to perform any final processing, scoring, or formatting required before the data is sent to the server. You usually won't need to override it.
* **`async def wandb_log(self, wandb_metrics: Optional[Dict] = None)`**: Called periodically to log metrics to Weights & Biases. If you override this to add custom metrics, **ensure you call `super().wandb_log(wandb_metrics)`** at the end of your implementation. This ensures that the base class's performance metrics and rollout tables are also logged.
```python
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
if wandb_metrics is None:
wandb_metrics = {}
# Add your custom metrics
wandb_metrics['my_custom_metric'] = calculate_my_metric()
# ... add more metrics
# Call the parent method to log base metrics
await super().wandb_log(wandb_metrics)
```
* **`save_checkpoint(self, step, data=None)`**: The base class calls this method automatically at checkpoint intervals determined by the server. It saves the provided `data` dictionary (which you might populate with environment-specific state) to a JSON file. You can override this to customize *what* data is saved or *how* it's saved (e.g., using a different format or location), but the triggering mechanism remains automatic.
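A sketch of such an override, assuming you want to persist a hypothetical `iter_idx` attribute; note that `load_checkpoint` restores attributes by JSON key, so keys should match attribute names:

```python
import json
import os
import tempfile

class CheckpointSketch:
    """Hypothetical override: same automatic trigger, custom payload."""

    def __init__(self, checkpoint_dir: str):
        self.checkpoint_dir = checkpoint_dir
        self.iter_idx = 42  # environment state worth persisting

    def save_checkpoint(self, step, data=None):
        if data is None:
            # Keys should match attribute names so load_checkpoint can setattr them back.
            data = {"iter_idx": self.iter_idx}
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        path = os.path.join(self.checkpoint_dir, f"step-{step}.json")
        with open(path, "w") as f:
            json.dump(data, f)
        return path

ckpt_dir = tempfile.mkdtemp()
saved = CheckpointSketch(ckpt_dir).save_checkpoint(step=100)
```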
* **`@classmethod config_init(cls) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[OpenaiConfig]]]`**: This class method is used by the default `get_cli_serve_config_cls` implementation to get the initial environment configuration (`BaseEnvConfig` subclass) and server configurations (`ServerBaseline` or `List[OpenaiConfig]`) when setting up the `serve` command. The default implementation returns `cls.env_config_cls(), ServerBaseline()`. You might override this if your environment requires different default configurations or specific server setups (like multiple `OpenaiConfig` instances) when run via the CLI `serve` command.
* **`async def cleanup(self)`**: Called after each call to `handle_env`. You can implement this for any cleanup needed after processing a single item, though it's often not required.
## Provided Functionality
`BaseEnv` provides several helpful features:
* **Parallel Trajectory Collection (`collect_trajectories`)**: The base implementation runs your `collect_trajectory` method multiple times in parallel (based on `group_size`) and gathers the results. You can override `collect_trajectories` directly for custom group generation logic (see Optional Methods).
* **Server Interaction**: Handles registration with the rollout server, fetching configuration (like `batch_size`), sending scored data (`handle_send_to_api` with retries), and status updates.
* **WandB Integration**: Sets up WandB logging (if enabled) based on server information and provides the `wandb_log` hook for custom metrics (remember to call `super().wandb_log()`). It uses helper methods `add_rollouts_for_wandb` (to temporarily store rollout data) and `create_rollout_table` (to format the data into a `wandb.Table`). You can override either of these helpers for custom logging behavior (e.g., changing what data is stored or how the final table is structured).
* **Checkpointing**:
* The environment automatically triggers checkpoint saves based on the `checkpoint_interval` received from the server, calling the `save_checkpoint` method (see Optional Methods).
* `load_checkpoint(self)`: Loads data from the checkpoint file corresponding to the environment's `curr_step`. It attempts to restore attributes of the environment object based on the keys in the loaded JSON data. This is called automatically if `curr_step > 0` during registration.
* **Worker Management**: Manages asynchronous worker tasks for collecting trajectories (`add_train_workers`, `handle_env`).
* **Performance Monitoring**: Tracks and logs various performance statistics (task durations, worker counts, etc.).
* **CLI Integration**: Provides a `cli()` class method using `pydantic-cli` to easily create command-line interfaces for your environment (e.g., `python your_env_module.py serve --port 8001 ...`). See `get_cli_serve_config_cls` and `get_cli_process_config_cls`.
By implementing the required methods and optionally overriding others, you can create diverse environments that leverage the distributed training infrastructure provided by the `Atropos` framework.
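As a concrete illustration of the data contract, here is a hand-built `ScoredDataGroup`-shaped dict (token and mask values are made up) that satisfies the checks `handle_send_to_api` applies before forwarding a group:

```python
# Hypothetical, hand-written values; a ScoredDataGroup is a plain dict.
# It must have one entry per group member, and (by default) scores that
# are not all identical, or the group is silently dropped.
group_size = 4
scored_group = {
    "tokens": [[101, 7, 9]] * group_size,  # token ids (made up here)
    "masks": [[0, 1, 1]] * group_size,     # per-token mask; exact convention is trainer-defined
    "scores": [1.0, 0.0, 1.0, 0.0],        # not all equal, so it won't be dropped
    "advantages": None,
    "ref_logprobs": None,
    "messages": None,
    "group_overrides": None,
    "overrides": None,
}
assert len(scored_group["tokens"]) == group_size
assert len(set(scored_group["scores"])) > 1
```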

atroposlib/envs/base.py
import asyncio
import json
import logging
import os
import random
import sys
import time
import uuid
import warnings
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
import aiohttp
import jsonlines
import numpy as np
from pydantic import BaseModel, Field
from pydantic_cli import Cmd, FailedExecutionException, run_and_exit
from tenacity import retry, stop_after_attempt, wait_random_exponential
from transformers import AutoTokenizer
import wandb
from atroposlib.type_definitions import UUID
from atroposlib.utils.metrics import get_std_min_max_avg
from ..type_definitions import Item, Message
from .server_handling.server_manager import (
OpenaiConfig,
ServerBaseline,
ServerManager,
ServerManagerConfig,
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class ScoredDataGroup(TypedDict):
tokens: List[List[int]]
masks: List[List[int]]
scores: List[float]
advantages: Optional[List[List[float]]]
ref_logprobs: Optional[List[List[float]]]
messages: Optional[List[List[Message]]]
group_overrides: Optional[Dict]
overrides: Optional[List[Dict]]
class EvalHandlingEnum(Enum):
"""
Enum for handling evals.
"""
STOP_TRAIN = "STOP_TRAIN"
LIMIT_TRAIN = "LIMIT_TRAIN"
NONE = "NONE"
class BaseEnvConfig(BaseModel):
"""
Basic env configuration.
"""
group_size: int = Field(
default=4, description="How many responses are grouped together for scoring"
)
max_num_workers: int = Field(
default=-1,
description="Maximum number of workers to use, -1 calculates from max_num_workers_per_node",
)
max_eval_workers: int = Field(
default=16, description="Maximum number of workers to use for evaluation"
)
max_num_workers_per_node: int = Field(
default=8, description="Maximum number of workers to use per node"
)
steps_per_eval: int = Field(
default=100, description="Number of steps to take before evaluating"
)
max_token_length: int = Field(
default=2048, description="Maximum token length used in generations"
)
eval_handling: EvalHandlingEnum = Field(
default=EvalHandlingEnum.STOP_TRAIN, description="How to handle evaluations"
)
eval_limit_ratio: float = Field(
default=0.5, description="Ratio of training workers to limit during evals"
)
inference_weight: float = Field(
default=1.0,
description="Inference weight, set to -1 to ignore it if you're doing something special here.",
)
batch_size: int = Field(
default=-1,
description="Batch size for training, will be set by the trainer and passed in via the fastapi interface, if applicable", # noqa: E501
)
max_batches_offpolicy: int = Field(
default=3, description="Maximum number of batches to have in queue."
)
tokenizer_name: str = Field(
default="NousResearch/DeepHermes-3-Llama-3-1B-Preview",
        description="Hugging Face tokenizer to use.",
)
use_wandb: bool = Field(default=True, description="Whether to use wandb")
rollout_server_url: str = Field(
default="http://localhost:8000", description="URL of the rollout server"
)
total_steps: int = Field(default=1000, description="Total number of steps to run")
wandb_name: str | None = Field(
default=None,
description="Name to be grouped by in wandb",
)
num_rollouts_to_keep: int = Field(
default=32, description="Number of rollouts to display on wandb"
)
num_rollouts_per_group_for_logging: int = Field(
default=1,
description="Number of rollouts per group to keep for logging. If -1, keep all rollouts",
)
ensure_scores_are_not_same: bool = Field(
default=True,
description="Ensure that the scores are not the same, should usually be True",
)
data_path_to_save_groups: Optional[str] = Field(
default=None,
description="Path to save the groups, if set, will write groups to this jsonl",
)
min_items_sent_before_logging: int = Field(
default=2,
description="Minimum number of items sent before logging, if 0 or less, logs every time",
)
class BaseEnv(ABC):
name = None
env_config_cls = BaseEnvConfig
def __init__(
self,
config: BaseEnvConfig,
server_configs: Union[ServerBaseline, List[OpenaiConfig]],
slurm=True,
testing=False,
):
self.items_sent_this_step = 0
self.eval_runner = None # type: Optional[asyncio.Task]
self.workers_added_list = list()
self.succeeded_task_duration = list()
self.failed_task_duration = list()
self.task_duration = list()
self.mainloop_timings = list()
self.task_successful = list()
self.last_loop_time = None
self.last_completed_item = None
self.config = config
self.server = ServerManager(server_configs, slurm=slurm, testing=testing)
self.workers = set()
self.eval_workers = set()
self.backlog = []
self.rollouts_for_wandb = []
self.running_items: dict[UUID, Item] = dict()
self.wandb_project = None
self.wandb_group = None
self.curr_step = 0
self.max_token_len = -1
self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
self.completion_lengths = []
self.max_num_workers = config.max_num_workers
if self.max_num_workers == -1:
self.max_num_workers = config.max_num_workers_per_node * len(
self.server.servers
)
self.wandb_prepend = None
self.checkpoint_dir = ""
self.checkpoint_interval = -1
if self.config.data_path_to_save_groups is not None:
if os.path.exists(self.config.data_path_to_save_groups):
raise FileExistsError(
"Data path already exists! Please remove it or change it."
)
self.jsonl_writer = jsonlines.open(
self.config.data_path_to_save_groups, "w"
) # type: jsonlines.Writer
else:
self.jsonl_writer = None
@classmethod
def config_init(
cls,
) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[OpenaiConfig]]]:
"""
Initialize the config
"""
return cls.env_config_cls(), ServerBaseline()
async def collect_trajectory(self, item: Item) -> Tuple[Any | None, List[Item]]:
        raise NotImplementedError(
            "collect_trajectory must be implemented in subclass"
        )
async def collect_trajectories(self, item: Item) -> Tuple[
Union[
Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None]
],
List[Item],
]:
"""
:param item:
:return:
"""
tasks = []
for _ in range(self.config.group_size):
tasks.append(self.collect_trajectory(item))
results = await asyncio.gather(*tasks)
backlog = []
to_postprocess = []
for result in results:
if result[0] is not None:
to_postprocess.append(result[0])
backlog.extend(result[1])
random.shuffle(backlog)
return to_postprocess, backlog
async def postprocess_histories(
self,
trajectories: Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]],
) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
"""
Postprocess the histories, this is called after the collect_trajectories method
If you don't need to do anything to the trajectories, you may safely ignore this.
:param trajectories:
:return:
"""
return trajectories
@abstractmethod
async def get_next_item(self) -> Item:
"""
Get the next items to be rolled out
"""
        raise NotImplementedError(
            "get_next_item must be implemented in subclass"
        )
@abstractmethod
async def evaluate(self, *args, **kwargs):
"""
Evaluate the environment, this is called every steps_per_eval steps
Included here is an example on how to use eval workers to run a task.
You may however do whatever you want in this method.
:param args:
:param kwargs:
:return: None.
"""
for data in ["my", "eval", "data"]:
while len(self.eval_workers) >= self.config.max_eval_workers:
await asyncio.sleep(0.1)
worker = asyncio.create_task(asyncio.sleep(0.1))
self.eval_workers.add(worker)
worker.add_done_callback(self.eval_workers.discard)
raise NotImplementedError("Evaluate method must be implemented in subclass ")
def load_checkpoint(self):
# check if file exists...
ckpt_path = os.path.join(
self.checkpoint_dir,
"env_checkpoints",
self.wandb_prepend,
f"step-{self.curr_step}.json",
)
if os.path.exists(ckpt_path):
with open(ckpt_path, "r") as f:
data = json.load(f)
# now load the data
for key in data:
setattr(self, key, data[key])
def save_checkpoint(self, step, data=None):
print(f"Saving checkpoint at step {step} with data {data}")
if data is None:
# Don't have anything to save, abort
return
# check if file exists...
ckpt_dir = os.path.join(
self.checkpoint_dir, "env_checkpoints", self.wandb_prepend
)
# create directory if necessary
os.makedirs(ckpt_dir, exist_ok=True)
ckpt_path = os.path.join(
self.checkpoint_dir,
"env_checkpoints",
self.wandb_prepend,
f"step-{step}.json",
)
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
with open(ckpt_path, "w") as f:
json.dump(data, f)
async def setup(self):
"""Setup the environment"""
raise NotImplementedError("Setup method must be implemented in subclass")
async def setup_wandb(self):
if self.config.use_wandb:
# Setup wandb getting the group and project via the server
while self.wandb_project is None:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.config.rollout_server_url}/wandb_info"
) as resp:
data = await resp.json()
self.wandb_group = data["group"]
self.wandb_project = data["project"]
if self.wandb_project is None:
await asyncio.sleep(1)
else:
wandb.init(
project=self.wandb_project,
group=self.wandb_group,
config=self.config.model_dump(),
)
break
async def register_env(self):
# Now register the env...
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.config.rollout_server_url}/register-env",
json={
"max_token_length": self.config.max_token_length,
"desired_name": self.config.wandb_name,
"weight": self.config.inference_weight,
},
) as resp:
data = await resp.json()
self.env_id = data["env_id"]
self.wandb_prepend = data["wandb_name"]
self.curr_step = data["starting_step"]
self.checkpoint_dir = data["checkpoint_dir"]
self.checkpoint_interval = data["checkpoint_interval"]
if self.config.total_steps == -1:
self.config.total_steps = data["num_steps"]
if self.config.total_steps == -1:
raise ValueError("Total steps not set in config or server!")
print(
f"Initialized env with id {self.env_id}: "
f"curr_step: {self.curr_step}, "
f"checkpoint_dir: {self.checkpoint_dir}, "
f"checkpoint_interval: {self.checkpoint_interval}"
)
if self.curr_step > 0:
self.load_checkpoint()
async def get_server_info(self):
"""
Get the server info
"""
async with aiohttp.ClientSession() as session:
async with session.get(f"{self.config.rollout_server_url}/info") as resp:
data = await resp.json()
if data["batch_size"] != -1:
# update the batch size
self.config.batch_size = data["batch_size"]
if data["max_token_len"] != -1:
self.max_token_len = data["max_token_len"]
if self.config.batch_size == -1:
logging.warning("Batch size not set by config or server!")
if self.config.group_size > self.config.batch_size:
raise ValueError(
                f"group_size ({self.config.group_size}) "
                f"must not exceed batch_size ({self.config.batch_size})"
)
def perf_stats(self, metrics_dict):
"""
returns wandb metrics for performance
"""
if len(self.task_duration) > 1:
get_std_min_max_avg(
"train_perf/task_duration", self.task_duration, metrics_dict
)
self.task_duration = list()
if len(self.succeeded_task_duration) > 1:
get_std_min_max_avg(
"train_perf/succeeded_task_duration",
self.succeeded_task_duration,
metrics_dict,
)
metrics_dict["train/items_sent_to_api"] = len(self.succeeded_task_duration)
self.succeeded_task_duration = list()
if len(self.failed_task_duration) > 1:
get_std_min_max_avg(
"train_perf/failed_task_duration",
self.failed_task_duration,
metrics_dict,
)
metrics_dict["train/items_rejected"] = len(self.failed_task_duration)
self.failed_task_duration = list()
if len(self.mainloop_timings) > 1:
get_std_min_max_avg(
"train_perf/mainloop_timings",
self.mainloop_timings,
metrics_dict,
)
self.mainloop_timings = list()
if len(self.workers_added_list) > 1:
get_std_min_max_avg(
"train_perf/workers_added_per_attempt",
self.workers_added_list,
metrics_dict,
)
self.workers_added_list = list()
return metrics_dict
async def create_rollout_table(self, wandb_metrics):
if len(self.rollouts_for_wandb) > 0:
table = wandb.Table(columns=["text", "score"])
for group in self.rollouts_for_wandb:
for item in group:
table.add_data(item[0], item[1])
wandb_metrics["train/rollouts"] = table
return wandb_metrics
async def add_rollouts_for_wandb(
self,
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
item: Item = None,
):
# Save rollout to trajectory
num_keep = self.config.num_rollouts_per_group_for_logging
if num_keep == -1:
num_keep = self.config.group_size
self.rollouts_for_wandb.append(
[
(
self.tokenizer.decode(scored_data["tokens"][i]),
scored_data["scores"][i],
)
for i in range(num_keep)
]
)
if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
self.rollouts_for_wandb.pop(0)
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
"""
Log to wandb.
To use this in your subclass, please ensure this is called after you do your metrics
e.g.
def wandb_log(self, wandb_metrics: Optional[Dict] = None):
wandb_metrics = {}
wandb_metrics['my_metric'] = 0.5
super().wandb_log(wandb_metrics)
"""
if wandb_metrics is None:
wandb_metrics = dict()
for i, server in enumerate(self.server.servers):
server_wandb_metrics = await server.wandb_metrics({}, f"server_{i}")
if len(self.completion_lengths) > 0:
wandb_metrics["train/completion_lengths"] = sum(
self.completion_lengths
) / len(self.completion_lengths)
wandb_metrics["train/completion_lengths_std"] = np.std(
self.completion_lengths
)
wandb_metrics["train/completion_lengths_max"] = np.max(
self.completion_lengths
)
wandb_metrics["train/completion_lengths_min"] = np.min(
self.completion_lengths
)
wandb_metrics["train/completion_lengths_p95"] = (
np.array(self.completion_lengths) > (0.95 * self.max_token_len)
).mean()
wandb_metrics = await self.create_rollout_table(wandb_metrics)
wandb_metrics = self.perf_stats(wandb_metrics)
self.rollouts_for_wandb = []
self.completion_lengths = []
if self.config.use_wandb:
if self.wandb_prepend is not None:
wandb_metrics = {
f"{self.wandb_prepend}_{k}": v for k, v in wandb_metrics.items()
}
# add server metrics to wandb without prepend to collate them all
wandb_metrics.update(server_wandb_metrics)
wandb.log(wandb_metrics, step=self.curr_step)
@retry(
stop=stop_after_attempt(3),
wait=wait_random_exponential(multiplier=1, max=10),
)
async def _send_scored_data_to_api(self, scored_data):
"""
Send scored data to the API with retry logic for timeouts and server errors.
"""
url = (
f"{self.config.rollout_server_url}/scored_data_list"
if isinstance(scored_data, list)
else f"{self.config.rollout_server_url}/scored_data"
)
async with aiohttp.ClientSession() as session:
async with session.post(
url,
json=scored_data,
) as resp:
if resp.status >= 500:
# Server errors (5xx) should trigger a retry
logging.debug(f"Server error: {resp.status}, retrying...")
raise Exception(f"Server error: {resp.status}")
elif resp.status >= 400:
# Client errors (4xx) are logged but not retried
logging.error(f"Client error: {resp.status}, not retrying")
return
# Success case: print response text
print(await resp.text())
async def handle_send_to_api(
self,
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
item: Item = None,
):
"""
Send the chats to the API with robust error handling and support for multiple ScoredDataGroups.
Args:
scored_data: List of scored items to send
item: Optional item for context
"""
group_size = scored_data.get("group_overrides", {}).get(
"group_size", self.config.group_size
)
if (
(scored_data is not None)
and (None not in scored_data)
and (len(scored_data["tokens"]) == group_size)
):
if self.config.ensure_scores_are_not_same:
if len(set(scored_data["scores"])) == 1:
# Scores are the same, don't send to API
return
await self.add_rollouts_for_wandb(scored_data, item)
# Check for ref_logprobs
if "ref_logprobs" not in scored_data:
# Strongly typed dict, so we need to add it
scored_data["ref_logprobs"] = None
if "overrides" not in scored_data:
scored_data["overrides"] = None
if "group_overrides" not in scored_data:
scored_data["group_overrides"] = None
# Track completion lengths
for mask in scored_data["masks"]:
self.completion_lengths.append(len(mask))
# Add the scores to the queue
if any([len(x) >= self.max_token_len for x in scored_data["tokens"]]):
# Don't send to API if the token length is too long
return
# Save data, if applicable:
if self.jsonl_writer is not None:
self.jsonl_writer.write(scored_data)
# Send data with retries and error handling
try:
self.items_sent_this_step += 1
await self._send_scored_data_to_api(scored_data)
except (Exception, TimeoutError) as e:
print(f"Failed to send scored data after retries: {e}")
async def handle_env(
self, item_uuid: str
) -> Optional[Union[ScoredDataGroup, List[ScoredDataGroup]]]:
"""
Handle the rollout of an item
"""
item = self.running_items.get(item_uuid)
if item is None:
print(f"item {item_uuid} not found... returning")
return None
start_time = time.time()
logger.debug(f"handle_env: Starting with item: {item}")
# do a rollout with item
try:
to_postprocess, to_backlog = await self.collect_trajectories(item)
except Exception:
to_postprocess = None
to_backlog = []
# add the items to the queue
if len(to_backlog) > 0:
self.backlog.extend(to_backlog)
try:
if (to_postprocess is None) or (len(to_postprocess) == 0):
pass
else:
to_postprocess = await self.postprocess_histories(to_postprocess)
except Exception as e:
logger.error(f"Error in scoring: {item}")
print(e)
to_postprocess = None
self.running_items.pop(item_uuid, None)
duration = max(0.0, time.time() - start_time)
self.task_duration.append(duration)
if to_postprocess is not None:
self.task_successful.append(1)
self.succeeded_task_duration.append(duration)
logger.debug(f"handle_env: Collected {len(to_postprocess)} trajectories")
await self.handle_send_to_api(to_postprocess, item)
else:
self.task_successful.append(0)
self.failed_task_duration.append(duration)
logger.debug("handle_env: No trajectories collected")
# Finally pop it
await self.cleanup()
return to_postprocess
async def cleanup(self):
"""
Optional: Cleanup the environment
"""
pass
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def get_status(self):
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.config.rollout_server_url}/status-env",
json={"env_id": self.env_id},
) as resp:
self.status_dict = await resp.json()
new_weight = self.status_dict["env_weight"]
max_num_workers = self.config.max_num_workers
if max_num_workers == -1:
max_num_workers = self.config.max_num_workers_per_node * len(
self.server.servers
)
self.max_num_workers = max_num_workers
await self.server.update_weight(new_weight)
async def env_step_checks(self):
# Check if we need to run an eval or log...
if self.curr_step != self.status_dict["current_step"]:
if self.config.steps_per_eval > 0:
if (self.curr_step % self.config.steps_per_eval) > (
self.status_dict["current_step"] % self.config.steps_per_eval
):
if (self.eval_runner is None) or (self.eval_runner.done()):
eval_task = asyncio.create_task(self.evaluate())
self.eval_runner = eval_task
if self.config.eval_handling == EvalHandlingEnum.STOP_TRAIN:
# Stop training if eval is running
self.backlog.extend(self.running_items.values())
for worker in self.workers:
worker.cancel()
self.workers = set()
self.running_items: dict[UUID, Item] = dict()
else:
warnings.warn(
"Eval is not finished in this iteration of the loop, skipping this eval step..."
)
if self.checkpoint_interval > 0:
if (self.curr_step % self.checkpoint_interval) > (
self.status_dict["current_step"] % self.checkpoint_interval
):
checkpoint_step = (
self.status_dict["current_step"] // self.checkpoint_interval
) * self.checkpoint_interval
self.save_checkpoint(checkpoint_step)
self.curr_step = self.status_dict["current_step"]
if self.items_sent_this_step >= self.config.min_items_sent_before_logging:
self.items_sent_this_step = 0
await self.wandb_log({})
async def add_train_workers(self):
if (self.eval_runner is not None) and (not self.eval_runner.done()):
if self.config.eval_handling == EvalHandlingEnum.STOP_TRAIN:
return
elif self.config.eval_handling == EvalHandlingEnum.LIMIT_TRAIN:
max_num_workers = int(
self.max_num_workers * self.config.eval_limit_ratio
)
else:
max_num_workers = self.max_num_workers
else:
max_num_workers = self.max_num_workers
# set max_num_workers to whatever is max off policy and num workers
max_num_workers = min(
max_num_workers,
(
self.config.max_batches_offpolicy
* self.config.batch_size
// self.config.group_size
)
- (self.status_dict["queue_size"]),
)
if (self.curr_step == 0) and (len(self.workers) == 0):
# We are starting up, so we should just skip the append to the list
pass
else:
self.workers_added_list.append(max_num_workers - len(self.workers))
while len(self.workers) < max_num_workers:
# Generate a UUID for tracking this item
item_uuid = str(uuid.uuid4())
if len(self.backlog) > 0:
item = self.backlog.pop()
else:
item = await self.get_next_item()
if item is None:
break
self.running_items[item_uuid] = item
worker = asyncio.create_task(self.handle_env(item_uuid))
self.workers.add(worker)
worker.add_done_callback(
lambda fut, i=item: (
(
self.workers.discard(fut),
(
setattr(self, "last_completed_item", i)
if fut.result()
else None
),
)[1]
if fut.done() and not fut.cancelled()
else None
)
)
async def env_manager(self):
"""
Rollout manager
"""
await self.setup()
await self.setup_wandb()
await self.register_env()
await self.get_server_info()
# Wait for other instances to get setup :)
await asyncio.sleep(5)
while True:
if self.last_loop_time is not None:
self.mainloop_timings.append(
max(0.0, time.time() - self.last_loop_time)
)
# get status from server
self.last_loop_time = time.time()
await self.get_status()
await self.env_step_checks()
logger.info(f"env_manager: Status dict: {self.status_dict}")
if (
self.status_dict["current_step"]
+ (
self.status_dict["queue_size"]
* self.config.group_size
// self.config.batch_size
)
) > self.config.total_steps:
for worker in self.workers:
worker.cancel()
break
if (
(
self.status_dict["queue_size"] * self.config.group_size
>= self.config.max_batches_offpolicy * self.config.batch_size
)
and (self.config.max_batches_offpolicy > 0)
) or (self.config.batch_size == -1):
# We have too many, lets cleanup the tasks and wait a bit
self.backlog.extend(self.running_items.values())
for worker in self.workers:
worker.cancel()
self.running_items = dict()
self.workers = set()
elif len(self.workers) >= self.max_num_workers:
pass
else:
await self.add_train_workers()
await asyncio.sleep(0.1)
@classmethod
def cli(cls):
"""
Command-line interface entry point for the environment.
This method handles the CLI commands for serve and process.
"""
# Create subcommands dictionary
subcommands = {
"serve": cls.get_cli_serve_config_cls(),
"process": cls.get_cli_process_config_cls(),
}
# Custom exception handler for cleaner error output
def custom_error_handler(ex: Exception) -> int:
"""Handles exceptions with clean output for known error types."""
if isinstance(ex, FailedExecutionException):
# Handle argparse errors (already printed by argparse)
print()
print(ex.message.split("error: ")[-1])
return 2
else:
# For any other exception
print(f"Error: {str(ex)}", file=sys.stderr)
return 1
run_and_exit(
subcommands,
description=f"CLI for {cls.__name__}",
exception_handler=custom_error_handler,
)
@classmethod
def get_cli_serve_config_cls(cls) -> type:
"""
Returns the CLI configuration class for serving commands.
Returns:
type: The CliServeConfig class for serving commands.
"""
env_config, server_configs = cls.config_init()
class CliServeConfig(
cls.env_config_cls, OpenaiConfig, ServerManagerConfig, Cmd
):
"""
Configuration for the serve command.
This combines BaseEnvConfig and OpenaiConfig into a single command.
"""
def run(self) -> None:
"""The logic to execute for the 'serve' command."""
# Convert this config into the formats needed by BaseEnv
if self.wandb_name is None and cls.name is not None:
self.wandb_name = cls.name
model_dumped = self.model_dump(exclude_unset=True)
server_manager_config = ServerManagerConfig(**model_dumped)
# Create the environment instance
env = cls(
config=env_config,
server_configs=server_configs,
slurm=server_manager_config.slurm,
testing=server_manager_config.testing,
)
# Run the environment
asyncio.run(env.env_manager())
return CliServeConfig
@classmethod
def get_cli_process_config_cls(cls) -> type:
"""
Returns the CLI configuration class for processing commands.
Returns:
type: The CliProcessConfig class for processing commands.
"""
class CliProcessConfig(Cmd):
"""
Configuration for the process command.
This is a placeholder for future implementation.
"""
# Add process-specific fields here
group_size: int = Field(
default=4, description="Number of responses per prompt"
)
n_groups: int = Field(default=1, description="Number of groups to process")
output_file: str = Field(
..., description="Path to jsonl file to write results"
)
def run(self) -> None:
"""The logic to execute for the 'process' command."""
print(
f"Processing {self.n_groups} groups of "
f"{self.group_size} responses and "
f"writing to {self.output_file}"
)
print("This is a placeholder implementation for the process command.")
# Actual implementation would go here
return CliProcessConfig

"""
Reward functions for evaluating model outputs in various environments.
This module provides a framework for creating, composing, and applying reward functions
to evaluate model outputs. Reward functions can be used for both dataset environments
and online/gymnasium environments.
Key components:
- RewardFunction: Abstract base class for all reward functions
- RewardRegistry: Registry for registering and loading reward functions
- CombinedReward: Meta reward function that combines multiple reward functions
Usage:
# Define a reward function
@registry.register
class MyReward(RewardFunction):
def compute(self, completions, **kwargs):
# Implementation
            return [1.0 for _ in completions]  # placeholder: one score per completion
# Create and use a reward function
reward_fn = registry.create("my_reward", weight=1.5)
scores = reward_fn(completions, **kwargs)
"""
from .combined_reward import CombinedReward
from .registry import registry
from .reward_function import RewardFunction
__all__ = ["RewardFunction", "registry", "CombinedReward"]

"""Reward function for checking if completions match ground truth answers."""
import logging
import re
from typing import Any, List, Optional, Union
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from .registry import registry
from .reward_function import RewardFunction
logger = logging.getLogger(__name__)
def _normalize_numerical_value(value_str: str) -> float:
"""Convert a string representation of a number to float, handling formatting."""
return float(value_str.replace(",", "").strip())
def _extract_final_answer(text: str) -> str:
"""
Extract the final answer from text that might include a full solution.
Handles formats like:
- "#### 42" (GSM8K style)
- "The answer is 42"
- "\\boxed{42}"
Returns the extracted answer or the original text if no pattern is found.
"""
# Check for GSM8K style answers (#### 42)
if "####" in text:
match = re.search(r"####\s*(.*?)(?:\s*$|\n)", text)
if match:
return match.group(1).strip()
# Check for boxed answers
if "\\boxed{" in text:
match = re.search(r"\\boxed\{([^}]+)\}", text)
if match:
return match.group(1).strip()
# If no special format is found, return the original text
return text
def _verify_answer(
content: str, gold_answer: Union[float, int, str], tolerance: float = 1e-6
) -> bool:
"""
Verifies if the provided content contains an answer matching the gold answer.
Uses a robust approach with multiple fallback strategies.
Args:
content: The model's response content to evaluate
gold_answer: The correct answer to compare against
tolerance: Tolerance for floating point comparisons
Returns:
Boolean indicating whether the answer is correct
"""
# Extract the final answer from the gold answer if it has a special format
if isinstance(gold_answer, str):
# Check for GSM8K style answers (#### number)
if "####" in gold_answer:
gold_answer = _extract_final_answer(gold_answer)
logger.warning(f"Extracted gold answer: {gold_answer}")
# Convert gold_answer to numerical if it's not already and if possible
gold_value = None
if isinstance(gold_answer, (int, float)):
gold_value = gold_answer
elif isinstance(gold_answer, str):
# Try to extract numerical value if it's in boxed format
if "\\boxed{" in gold_answer:
try:
gold_value = _normalize_numerical_value(
gold_answer.replace("\\boxed{", "").replace("}", "")
)
except ValueError:
# Not a numerical value, keep as string for LaTeX parsing
pass
else:
# Try to convert to float if possible
try:
gold_value = _normalize_numerical_value(gold_answer)
except ValueError:
# Not a numerical value, keep as string for LaTeX parsing
pass
# First attempt: Try to parse with math_verify
try:
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed="all",
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
logger.warning(f"Answer parsed result: {answer_parsed}")
# If we got a valid parse, verify it against the gold answer
if answer_parsed:
# Format gold answer for verification
gold_str = (
f"\\boxed{{{gold_answer}}}"
if not isinstance(gold_answer, str) or "\\boxed" not in gold_answer
else gold_answer
)
gold_parsed = parse(
gold_str,
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
logger.warning(f"Gold parsed result: {gold_parsed}")
if gold_parsed:
return verify(answer_parsed, gold_parsed)
except Exception as e:
logger.warning(f"Exception in primary parsing: {e}")
# Fallback: Use regex to extract boxed content for numerical comparison
if gold_value is not None: # Only try numerical comparison if gold is a number
try:
# Try to extract a boxed answer first
boxed_matches = re.findall(r"\\boxed\{([^}]+)\}", content)
if boxed_matches:
logger.warning(f"Regex boxed matches: {boxed_matches}")
# Try to extract a numerical value from the boxed content
try:
extracted_value = _normalize_numerical_value(boxed_matches[0])
logger.warning(
f"Extracted value: {extracted_value}, Gold value: {gold_value}"
)
# Allow for small floating point differences
return abs(extracted_value - gold_value) < tolerance
except ValueError:
logger.warning(f"Could not convert '{boxed_matches[0]}' to float")
# If no boxed answer, check for a final answer after ####
if "####" in content:
                match = re.search(r"####\s*(-?[\d,\.]+)", content)
if match:
extracted_value = _normalize_numerical_value(match.group(1))
logger.warning(
f"Extracted value from ####: {extracted_value}, Gold value: {gold_value}"
)
return abs(extracted_value - gold_value) < tolerance
except Exception as e:
logger.warning(f"Exception in regex parsing: {e}")
return False
@registry.register
class AccuracyReward(RewardFunction):
"""
Reward function that checks if completions match ground truth answers.
Works with boxed LaTeX answers, GSM8K-style answers, and other formats.
Uses a robust approach with multiple fallback strategies for parsing and verification.
"""
def __init__(
self,
tolerance: float = 1e-6,
split_on_think_tag: bool = True,
max_boxed_threshold: int = 6,
weight: float = 1.0,
**kwargs,
):
"""
Initialize the accuracy reward function.
Args:
tolerance: Tolerance for floating point comparisons
split_on_think_tag: Whether to use only the text after </think> tag
max_boxed_threshold: Maximum number of boxed expressions before marking as incorrect
weight: Weight for this reward
**kwargs: Additional configuration
"""
super().__init__(weight=weight, **kwargs)
self.tolerance = tolerance
self.split_on_think_tag = split_on_think_tag
self.max_boxed_threshold = max_boxed_threshold
def compute(
self,
completions: List[Any],
solution: Optional[Union[str, List[str]]] = None,
ground_truth: Optional[Union[str, List[str]]] = None,
**kwargs,
) -> List[float]:
"""
Check if completions match ground truth answers.
Args:
completions: List of model completions to evaluate
solution: Ground truth solution(s) - can be a single value or list of values
ground_truth: Optional canonical ground truth answers (used instead of solution if provided)
**kwargs: Additional context
Returns:
List of reward values (1.0 for correct, 0.0 for incorrect)
"""
rewards = []
# Check if we have a solution or ground truth
if solution is None and ground_truth is None:
logger.warning("No solution or ground_truth provided to accuracy_reward")
return [0.0] * len(completions)
# Use ground_truth instead of solution if available
gold_answers = ground_truth if ground_truth is not None else solution
if isinstance(gold_answers, list):
answers = gold_answers
else:
answers = [gold_answers] * len(completions)
for completion, ans in zip(completions, answers):
try:
content = self.get_content(completion)
if (
self.split_on_think_tag
and "</think>" in content
and content.split("</think>")[-1].count("\\boxed")
> self.max_boxed_threshold
):
logger.warning(
"Too many \\boxed commands in response, marking as incorrect"
)
reward = 0.0
else:
if self.split_on_think_tag and "</think>" in content:
answer_part = content.split("</think>")[-1]
else:
answer_part = content
reward = float(_verify_answer(answer_part, ans, self.tolerance))
except Exception as e:
logger.warning(f"Error in accuracy_reward: {e}")
logger.exception(e)
reward = 0.0
rewards.append(reward)
# Calculate statistics
if rewards:
logger.info(
f"Accuracy: {sum(rewards)}/{len(rewards)} ({sum(rewards)/len(rewards):.2f})"
)
return rewards
# Legacy function for backward compatibility
def accuracy_reward(
    completions: List[Any],
    solution: Optional[Union[str, List[str]]] = None,
    ground_truth: Optional[Union[str, List[str]]] = None,
    **kwargs,
) -> List[float]:
"""
Legacy function wrapper for AccuracyReward.
Args:
completions: List of model completions to evaluate
solution: Ground truth solution(s) - can be a single value or list of values
ground_truth: Optional canonical ground truth answers (used instead of solution if provided)
**kwargs: Additional parameters
Returns:
List of reward values (1.0 for correct, 0.0 for incorrect)
"""
reward_fn = AccuracyReward()
return reward_fn.compute(
completions, solution=solution, ground_truth=ground_truth, **kwargs
)
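The fallback path of `_verify_answer` can be sketched without the `math_verify` dependency: extract a final answer from a `#### 42` or `\boxed{42}` marker, strip comma formatting, and compare within a tolerance. This is a simplified, self-contained illustration; the class above tries the full LaTeX parser first, and these helper names are local to the sketch.

```python
import re


def extract_final_answer(text: str) -> str:
    # GSM8K-style "#### 42" answers take priority
    if "####" in text:
        m = re.search(r"####\s*(.*?)(?:\s*$|\n)", text)
        if m:
            return m.group(1).strip()
    # Otherwise look for a boxed LaTeX answer
    m = re.search(r"\\boxed\{([^}]+)\}", text)
    if m:
        return m.group(1).strip()
    return text


def numeric_match(content: str, gold: str, tolerance: float = 1e-6) -> bool:
    # Normalize "1,234"-style formatting before the float comparison
    try:
        predicted = float(extract_final_answer(content).replace(",", ""))
        expected = float(extract_final_answer(gold).replace(",", ""))
    except ValueError:
        return False
    return abs(predicted - expected) < tolerance


print(numeric_match("The total is \\boxed{1,234}", "#### 1234"))  # True
```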


@@ -0,0 +1,200 @@
import logging
import re
from typing import Any, List, Union
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
logger = logging.getLogger(__name__)
def get_completion_content(completion) -> str:
"""Extract content from completion in various formats."""
if isinstance(completion, str):
return completion
elif isinstance(completion, dict):
if "content" in completion:
return completion["content"]
elif isinstance(completion.get("message", {}), dict):
return completion["message"].get("content", "")
elif isinstance(completion, list) and len(completion) > 0:
if isinstance(completion[0], dict) and "content" in completion[0]:
return completion[0]["content"]
logger.warning(f"Could not extract content from completion: {completion}")
return str(completion)
def _normalize_numerical_value(value_str: str) -> float:
"""Convert a string representation of a number to float, handling formatting."""
return float(value_str.replace(",", "").strip())
def _extract_final_answer(text: str) -> str:
"""Extract the final answer from text with various formats (GSM8K, boxed, etc)."""
if "####" in text:
match = re.search(r"####\s*(.*?)(?:\s*$|\n)", text)
if match:
return match.group(1).strip()
if "\\boxed{" in text:
match = re.search(r"\\boxed\{([^}]+)\}", text)
if match:
return match.group(1).strip()
return text
def _verify_answer(content: str, gold_answer: Union[float, int, str]) -> bool:
"""Verifies if the content matches the gold answer using multiple strategies."""
if isinstance(gold_answer, str):
if "####" in gold_answer:
gold_answer = _extract_final_answer(gold_answer)
logger.warning(f"Extracted gold answer: {gold_answer}")
gold_value = None
if isinstance(gold_answer, (int, float)):
gold_value = gold_answer
elif isinstance(gold_answer, str):
if "\\boxed{" in gold_answer:
try:
gold_value = _normalize_numerical_value(
gold_answer.replace("\\boxed{", "").replace("}", "")
)
except ValueError:
pass
else:
try:
gold_value = _normalize_numerical_value(gold_answer)
except ValueError:
pass
# Try math_verify parsing first
try:
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed="all",
units=True,
),
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
if answer_parsed:
gold_str = (
f"\\boxed{{{gold_answer}}}"
if not isinstance(gold_answer, str) or "\\boxed" not in gold_answer
else gold_answer
)
gold_parsed = parse(
gold_str,
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
if gold_parsed:
return verify(answer_parsed, gold_parsed)
except Exception as e:
logger.warning(f"Exception in primary parsing: {e}")
# Fallback to numerical comparison
if gold_value is not None:
try:
boxed_matches = re.findall(r"\\boxed\{([^}]+)\}", content)
if boxed_matches:
try:
extracted_value = _normalize_numerical_value(boxed_matches[0])
return abs(extracted_value - gold_value) < 1e-6
except ValueError:
pass
if "####" in content:
                match = re.search(r"####\s*(-?[\d,\.]+)", content)
if match:
extracted_value = _normalize_numerical_value(match.group(1))
return abs(extracted_value - gold_value) < 1e-6
except Exception as e:
logger.warning(f"Exception in regex parsing: {e}")
return False
def format_reward(completions, reward_value=0.5, **kwargs):
"""Checks if completion has proper think tag formatting."""
pattern = r"^<think>[^<]*</think>[^<]*$"
try:
completion_contents = [get_completion_content(c) for c in completions]
matches = [
re.match(pattern, content, re.DOTALL) for content in completion_contents
]
return [reward_value if match else 0.0 for match in matches]
except Exception as e:
logger.error(f"Error in format reward calculation: {e}")
return [0.0] * len(completions)
def accuracy_reward(
completions: List[Any], solution: Union[str, List[str]], **kwargs
) -> List[float]:
"""Checks answer accuracy using sophisticated verification."""
rewards = []
if not isinstance(solution, list):
solution = [solution] * len(completions)
for completion, sol in zip(completions, solution):
try:
content = get_completion_content(completion)
if (
"</think>" in content
and content.split("</think>")[-1].count("\\boxed") > 6
):
logger.warning(
"Too many \\boxed commands in response, marking as incorrect"
)
reward = 0.0
else:
answer_part = (
content.split("</think>")[-1] if "</think>" in content else content
)
reward = float(_verify_answer(answer_part, sol))
except Exception as e:
logger.warning(f"Error in accuracy reward: {e}")
reward = 0.0
rewards.append(reward)
return rewards
def cascading_r1_math_reward(completions, solution, **kwargs) -> list[float]:
"""Combines sophisticated accuracy checking with format verification."""
try:
accuracy_rewards = accuracy_reward(completions, solution, **kwargs)
format_rewards = format_reward(completions)
combined_rewards = []
for accuracy_score, format_score in zip(accuracy_rewards, format_rewards):
# Only add format bonus if answer is correct
format_bonus = format_score if accuracy_score > 0 else 0.0
total_reward = accuracy_score + format_bonus
combined_rewards.append(total_reward)
        logger.info(
            f"cascading_r1_math_reward: accuracy={accuracy_rewards}, "
            f"format={format_rewards}, combined={combined_rewards}"
        )
return combined_rewards
except Exception as e:
        logger.error(f"Error in cascading_r1_math_reward: {e}")
return [0.0] * len(completions)
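The cascading rule above grants the format bonus only when the accuracy check already passed, so a well-formatted wrong answer still scores 0. A minimal self-contained sketch of that combination, with the accuracy scores stubbed in (the real ones come from `accuracy_reward`):

```python
import re


def format_score(content: str, bonus: float = 0.5) -> float:
    # One <think>...</think> block followed only by the final answer
    pattern = r"^<think>[^<]*</think>[^<]*$"
    return bonus if re.match(pattern, content, re.DOTALL) else 0.0


def cascade(accuracy_scores, format_scores):
    # Format bonus is conditional on a correct answer
    return [a + (f if a > 0 else 0.0) for a, f in zip(accuracy_scores, format_scores)]


completions = ["<think>2+2</think>\\boxed{4}", "\\boxed{5}"]
accuracy = [1.0, 0.0]  # assume graded elsewhere
fmt = [format_score(c) for c in completions]
print(cascade(accuracy, fmt))  # [1.5, 0.0]
```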


@@ -0,0 +1,91 @@
"""Combined reward function that combines multiple reward functions."""
import logging
from typing import Any, Dict, List, Union
from .registry import registry
from .reward_function import RewardFunction
logger = logging.getLogger(__name__)
@registry.register
class CombinedReward(RewardFunction):
"""Meta reward function that combines multiple reward functions"""
def __init__(
self,
rewards: List[Union[str, Dict]],
normalization: str = "none",
weight: float = 1.0,
**kwargs,
):
"""
Initialize with a list of reward functions to combine.
Args:
rewards: List of reward functions (names or config dicts)
normalization: How to normalize rewards, one of:
- "none": No normalization
- "sum": Divide by sum of weights
- "minmax": Scale to range [0,1] based on min/max values
weight: Weight for this combined reward
**kwargs: Additional parameters
"""
super().__init__(weight=weight, **kwargs)
self.normalization = normalization
self.reward_functions = []
# Initialize all sub-reward functions
for reward_config in rewards:
self.reward_functions.append(registry.create(reward_config))
@property
def name(self) -> str:
"""Get a descriptive name for this combined reward"""
return f"combined({','.join(r.name for r in self.reward_functions)})"
def set_wandb_logger(self, logger):
"""Propagate the WandB logger to all sub-rewards"""
super().set_wandb_logger(logger)
for reward_fn in self.reward_functions:
reward_fn.set_wandb_logger(logger)
def compute(self, completions: List[Any], **kwargs) -> List[float]:
"""Compute combined rewards by calling all sub-rewards"""
if not completions:
return []
# Initialize with zeros
combined_rewards = [0.0] * len(completions)
# Collect all sub-reward values
all_rewards = []
for reward_fn in self.reward_functions:
try:
rewards = reward_fn.compute(completions, **kwargs)
all_rewards.append(rewards)
                # Add to combined total (pre-normalization), scaled by the
                # sub-reward's weight so "sum" normalization divides a
                # weighted sum by the total weight
                for i, r in enumerate(rewards):
                    combined_rewards[i] += reward_fn.weight * r
except Exception as e:
logger.error(f"Error computing reward for {reward_fn.name}: {e}")
logger.exception(e)
# Apply normalization if needed
if self.normalization == "sum":
total_weight = sum(r.weight for r in self.reward_functions)
if total_weight > 0:
combined_rewards = [r / total_weight for r in combined_rewards]
elif self.normalization == "minmax":
# Avoid division by zero
reward_min = min(combined_rewards) if combined_rewards else 0
reward_max = max(combined_rewards) if combined_rewards else 0
if reward_max > reward_min:
combined_rewards = [
(r - reward_min) / (reward_max - reward_min)
for r in combined_rewards
]
return combined_rewards
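One plausible reading of the normalization contract above, sketched as a plain function: each sub-reward yields per-completion scores, the combiner takes a weighted sum, and `"sum"` divides by the total weight while `"minmax"` rescales to [0, 1]. That weighting interpretation is an assumption based on the docstring; the registry-backed class may apply weights elsewhere.

```python
def combine(reward_outputs, weights, normalization="none"):
    # reward_outputs: one score list per sub-reward, aligned by completion
    n = len(reward_outputs[0])
    combined = [0.0] * n
    for scores, w in zip(reward_outputs, weights):
        for i, s in enumerate(scores):
            combined[i] += w * s
    if normalization == "sum":
        total = sum(weights)
        if total > 0:
            combined = [c / total for c in combined]
    elif normalization == "minmax":
        lo, hi = min(combined), max(combined)
        if hi > lo:  # guard against division by zero
            combined = [(c - lo) / (hi - lo) for c in combined]
    return combined


print(combine([[1.0, 0.0], [0.5, 0.5]], weights=[1.0, 1.0], normalization="sum"))
# [0.75, 0.25]
```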


@@ -0,0 +1,201 @@
"""Reward function for evaluating semantic similarity between completions and solutions."""
import logging
from typing import Any, List, Optional, Union
import scipy.spatial.distance  # submodule must be imported explicitly
import torch
from transformers import AutoModel, AutoTokenizer
from .registry import registry
from .reward_function import RewardFunction
logger = logging.getLogger(__name__)
@registry.register
class CosineScaledReward(RewardFunction):
"""
Reward function that measures semantic similarity between completions and solutions.
Uses sentence embeddings to compute cosine similarity, providing higher rewards
for completions that are semantically similar to the reference solution.
"""
# Class-level variables for model caching
_model = None
_tokenizer = None
_model_name = "sentence-transformers/all-MiniLM-L6-v2"
def __init__(
self,
model_name: Optional[str] = None,
scale_factor: float = 1.0,
min_reward: float = -1.0,
max_reward: float = 1.0,
default_reward: float = 0.0,
weight: float = 1.0,
**kwargs,
):
"""
Initialize the cosine similarity reward function.
Args:
model_name: Name of embedding model to use (default: "sentence-transformers/all-MiniLM-L6-v2")
scale_factor: Factor to scale similarity by (default: 1.0)
min_reward: Minimum reward value (default: -1.0)
max_reward: Maximum reward value (default: 1.0)
default_reward: Default reward when similarity can't be calculated (default: 0.0)
weight: Weight for this reward
**kwargs: Additional configuration
"""
super().__init__(weight=weight, **kwargs)
self.model_name = model_name or self._model_name
self.scale_factor = scale_factor
self.min_reward = min_reward
self.max_reward = max_reward
self.default_reward = default_reward
# Initialize model and tokenizer if needed
self._ensure_model_loaded()
def _ensure_model_loaded(self):
"""Ensure the model and tokenizer are loaded, loading them if needed."""
# Check if we need to load a different model than what's cached
if self.model_name != self._model_name or CosineScaledReward._model is None:
try:
CosineScaledReward._tokenizer = AutoTokenizer.from_pretrained(
self.model_name
)
CosineScaledReward._model = AutoModel.from_pretrained(self.model_name)
CosineScaledReward._model_name = self.model_name
logger.info(
f"Loaded model and tokenizer for cosine similarity: {self.model_name}"
)
except Exception as e:
logger.error(f"Error loading model for cosine similarity: {e}")
logger.exception(e)
CosineScaledReward._tokenizer = None
CosineScaledReward._model = None
def _mean_pooling(self, model_output, attention_mask):
"""Mean Pooling - Take attention mask into account for correct averaging"""
token_embeddings = model_output[0]
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
)
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
input_mask_expanded.sum(1), min=1e-9
)
def _get_embeddings(self, text):
"""Get embeddings for text using the model"""
if CosineScaledReward._model is None or CosineScaledReward._tokenizer is None:
logger.error("Model or tokenizer not available for embeddings")
return None
try:
# Tokenize and prepare for the model
encoded_input = CosineScaledReward._tokenizer(
text, padding=True, truncation=True, return_tensors="pt"
)
# Get model output
with torch.no_grad():
model_output = CosineScaledReward._model(**encoded_input)
# Perform mean pooling
sentence_embeddings = self._mean_pooling(
model_output, encoded_input["attention_mask"]
)
return sentence_embeddings.numpy()
except Exception as e:
logger.error(f"Error getting embeddings: {e}")
logger.exception(e)
return None
def compute(
self,
completions: List[Any],
solution: Optional[Union[str, List[str]]] = None,
**kwargs,
) -> List[float]:
"""
Calculate reward based on cosine similarity between completion and solution.
Args:
completions: List of completions to evaluate
solution: The reference solution to compare against
**kwargs: Additional context
Returns:
List of rewards based on cosine similarity, scaled to [min_reward, max_reward]
"""
# Extract content from different possible formats
completion_contents = [
self.get_content(completion) for completion in completions
]
# If no solution provided, can't calculate similarity
if not solution:
logger.warning("No solution provided for cosine similarity calculation")
return [self.default_reward] * len(completion_contents)
solution_text = (
solution if isinstance(solution, str) else self.get_content(solution)
)
        rewards = []
        # Embed the solution once; it is identical for every completion
        solution_embedding = self._get_embeddings(solution_text)
        for content in completion_contents:
            try:
                completion_embedding = self._get_embeddings(content)
if solution_embedding is None or completion_embedding is None:
logger.warning("Could not get embeddings for cosine similarity")
rewards.append(self.default_reward)
continue
                # scipy's cosine() returns a *distance*: 0 (similar) to 2 (dissimilar)
                distance = scipy.spatial.distance.cosine(
                    solution_embedding.flatten(), completion_embedding.flatten()
                )
                # Map distance 0 → max_reward (good) and 2 → min_reward (bad),
                # then clamp into [min_reward, max_reward]
                scaled = 1.0 - distance * self.scale_factor
                reward = min(self.max_reward, max(self.min_reward, scaled))
                logger.info(f"Cosine distance: {distance}, scaled reward: {reward}")
rewards.append(reward)
except Exception as e:
logger.error(f"Error in cosine similarity calculation: {e}")
logger.exception(e)
rewards.append(self.default_reward)
return rewards
# Legacy function for backward compatibility
def cosine_scaled_reward(
completions: List[Any], solution=None, **kwargs
) -> List[float]:
"""
Legacy function wrapper for CosineScaledReward.
Args:
completions: List of completions to evaluate
solution: The reference solution to compare against
**kwargs: Additional parameters
Returns:
List of rewards based on cosine similarity, scaled to [-1, 1]
"""
reward_fn = CosineScaledReward()
return reward_fn.compute(completions, solution=solution, **kwargs)
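The reward mapping above hinges on the fact that cosine *distance* lives in [0, 2]: distance 0 maps to `max_reward` and distance 2 to `min_reward` via `clamp(1 - distance * scale_factor)`. A pure-Python sketch with hand-rolled vectors, standing in for the embedding model and `scipy`:

```python
import math


def cosine_distance(u, v):
    # 1 - cos(angle); 0 for parallel vectors, 2 for opposite ones
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / norm


def scaled_reward(distance, scale_factor=1.0, min_reward=-1.0, max_reward=1.0):
    # Same clamp-to-range mapping as CosineScaledReward.compute
    return min(max_reward, max(min_reward, 1.0 - distance * scale_factor))


identical = cosine_distance([1.0, 2.0], [2.0, 4.0])  # ~0.0 (parallel)
opposite = cosine_distance([1.0, 0.0], [-1.0, 0.0])  # 2.0
print(scaled_reward(identical), scaled_reward(opposite))
```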


@@ -0,0 +1,121 @@
"""Reward function for evaluating crossword puzzle answer formatting."""
import logging
import re
from typing import Any, List, Optional, Pattern
from .registry import registry
from .reward_function import RewardFunction
logger = logging.getLogger(__name__)
@registry.register
class CrosswordFormatReward(RewardFunction):
"""
Reward function for crossword puzzle game answers.
Checks if completions follow the expected formatting for crossword puzzle answers:
- Contains answer patterns like "1-Across: WORD"
- Uses only valid characters (letters, no numbers or special chars in answers)
- Follows specified formatting patterns
"""
def __init__(
self,
format_patterns: Optional[List[Pattern]] = None,
reward_value: float = 1.0,
penalize_invalid_chars: bool = True,
valid_chars: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
weight: float = 1.0,
**kwargs,
):
"""
Initialize the crossword format reward function.
Args:
format_patterns: List of regex patterns to match (optional)
reward_value: Value to award for correct formatting
penalize_invalid_chars: Whether to penalize invalid characters
valid_chars: String of valid characters for answers
weight: Weight for this reward
**kwargs: Additional configuration
"""
super().__init__(weight=weight, **kwargs)
self.reward_value = reward_value
self.penalize_invalid_chars = penalize_invalid_chars
self.valid_chars = valid_chars.upper()
# Default patterns if none provided
self.format_patterns = format_patterns or [
re.compile(
r"\d+-(?:Across|Down):\s+[A-Z\s]+", re.IGNORECASE
), # Basic format pattern
re.compile(
r"^(?:\d+-(?:Across|Down):\s+[A-Z\s]+[\s,]*)+$", re.IGNORECASE
), # Full response format
]
def compute(self, completions: List[Any], **kwargs) -> List[float]:
"""
Check if completions follow crossword answer formatting.
Args:
completions: List of completions to evaluate
**kwargs: Additional context
Returns:
List of rewards (reward_value for correct format, 0.0 otherwise)
"""
# Extract content from different possible formats
completion_contents = [
self.get_content(completion) for completion in completions
]
rewards = []
for content in completion_contents:
try:
# Check for format patterns
format_match = any(
pattern.search(content) for pattern in self.format_patterns
)
# Look for answers and check for invalid characters
valid_chars = True
if self.penalize_invalid_chars:
# Extract answers (text after "Across:" or "Down:")
answers = re.findall(
r"(?:Across|Down):\s+([A-Za-z]+)", content, re.IGNORECASE
)
for answer in answers:
# Check if answer contains only valid characters
if not all(c.upper() in self.valid_chars for c in answer):
valid_chars = False
break
# Both format and valid chars must be correct for full reward
correct_format = format_match and valid_chars
rewards.append(self.reward_value if correct_format else 0.0)
except Exception as e:
logger.error(f"Error in crossword format reward calculation: {e}")
logger.exception(e)
rewards.append(0.0)
return rewards
# Legacy function for backward compatibility
def crossword_format_reward(completions: List[Any], **kwargs) -> List[float]:
"""
Legacy function wrapper for CrosswordFormatReward.
Args:
completions: List of completions to evaluate
**kwargs: Additional parameters
Returns:
List of rewards for crossword format quality
"""
reward_fn = CrosswordFormatReward()
return reward_fn.compute(completions, **kwargs)
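The two checks above can be condensed into a stdlib-only sketch: a `"N-Across: WORD"` pattern match plus a character-validity check on each extracted answer. Note that since answers are captured with `[A-Za-z]+`, the validity check only bites when `valid_chars` is narrowed below the full alphabet; the function name here is illustrative.

```python
import re

FORMAT = re.compile(r"\d+-(?:Across|Down):\s+[A-Z\s]+", re.IGNORECASE)


def crossword_reward(content, reward_value=1.0, valid_chars="ABCDEFGHIJKLMNOPQRSTUVWXYZ"):
    # Require at least one "N-Across: WORD" style answer
    if not FORMAT.search(content):
        return 0.0
    # Extract answers and reject any letter outside the allowed set
    answers = re.findall(r"(?:Across|Down):\s+([A-Za-z]+)", content, re.IGNORECASE)
    if any(c.upper() not in valid_chars for ans in answers for c in ans):
        return 0.0
    return reward_value


print(crossword_reward("1-Across: APPLE, 2-Down: PEAR"))  # 1.0
print(crossword_reward("1-Across: QUIZ", valid_chars="ABCDEFGHIJKLMNOPRSTU"))  # 0.0 (no Q)
```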


@@ -0,0 +1,105 @@
"""Reward function for checking if completions have specific XML-style tags."""
import logging
import re
from typing import Any, List, Optional
from .registry import registry
from .reward_function import RewardFunction
logger = logging.getLogger(__name__)
@registry.register
class FormatReward(RewardFunction):
"""Reward function that checks if completions have XML-style tags."""
def __init__(
self,
preferred_tags: Optional[List[str]] = None,
require_all_tags: bool = False,
case_sensitive: bool = False,
weight: float = 1.0,
**kwargs,
):
"""
Initialize the format reward function.
Args:
preferred_tags: List of tag names to search for (defaults to ['think', 'answer'])
require_all_tags: If True, require all tags to be present for a reward
case_sensitive: If True, perform case-sensitive tag matching
weight: Weight for this reward
**kwargs: Additional configuration
"""
super().__init__(weight=weight, **kwargs)
self.preferred_tags = preferred_tags or ["think", "answer"]
self.require_all_tags = require_all_tags
self.case_sensitive = case_sensitive
def compute(self, completions: List[Any], **kwargs) -> List[float]:
"""
Check if completions have the expected XML-style tags.
Args:
completions: List of completions to evaluate
**kwargs: Additional context
Returns:
List of rewards for each completion (1.0 for good format, 0.0 otherwise)
"""
# Extract content from different possible formats
completion_contents = [
self.get_content(completion) for completion in completions
]
# For each completion, check for the preferred tags
rewards = []
flags = 0 if self.case_sensitive else re.IGNORECASE
flags |= re.DOTALL # Allow . to match newlines
for content in completion_contents:
if self.require_all_tags:
# All tags must be present
all_tags_present = True
for tag in self.preferred_tags:
                    pattern = f"<{re.escape(tag)}>.*?</{re.escape(tag)}>"
if not re.search(pattern, content, flags):
all_tags_present = False
break
rewards.append(1.0 if all_tags_present else 0.0)
else:
# Any tag can be present
has_tags = False
for tag in self.preferred_tags:
                    pattern = f"<{re.escape(tag)}>.*?</{re.escape(tag)}>"
if re.search(pattern, content, flags):
has_tags = True
break
rewards.append(1.0 if has_tags else 0.0)
# Log the results
logger.info(
f"Format reward results: {sum(rewards)}/{len(rewards)} completions match format"
)
return rewards
# Legacy function for backward compatibility
def format_reward(
completions: List[Any], preferred_tags: Optional[List[str]] = None, **kwargs
) -> List[float]:
"""
Legacy function wrapper for FormatReward.
Args:
completions: List of completions to evaluate
preferred_tags: List of tag names to search for (defaults to ['think', 'answer'])
**kwargs: Additional keyword arguments
Returns:
List of rewards for each completion (1.0 for good format, 0.0 otherwise)
"""
reward_fn = FormatReward(preferred_tags=preferred_tags)
return reward_fn.compute(completions, **kwargs)
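The tag check above reduces to searching for `<tag>...</tag>` pairs with DOTALL (so tags may span lines) and either any-of or all-of semantics. A condensed sketch; `re.escape` around the tag name is a hypothetical hardening against regex metacharacters in tags:

```python
import re


def has_tags(content, tags=("think", "answer"), require_all=False, case_sensitive=False):
    flags = re.DOTALL | (0 if case_sensitive else re.IGNORECASE)
    # One boolean per configured tag
    found = [
        bool(re.search(f"<{re.escape(t)}>.*?</{re.escape(t)}>", content, flags))
        for t in tags
    ]
    return all(found) if require_all else any(found)


text = "<think>plan</think>\n<answer>42</answer>"
print(has_tags(text, require_all=True))  # True
print(has_tags("plain text"))  # False
```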


@@ -0,0 +1,363 @@
"""Reward function that combines reasoning format and accuracy rewards."""
import logging
import re
from typing import Any, Dict, List, Optional, Union
from .registry import registry
from .reward_function import RewardFunction
logger = logging.getLogger(__name__)
def parse_reasoning_response(text: str) -> Dict[str, Any]:
"""
Parse text to extract thinking section and response section.
Args:
text: Text to parse for thinking and response sections
Returns:
Dictionary with thinking_content, response, and multiple_thinking flag
"""
# Check if text is actually a string
if not isinstance(text, str):
logger.warning(f"Expected string but got {type(text)}: {text}")
return {
"thinking_content": "",
"response": str(text),
"multiple_thinking": False,
}
# Find all thinking blocks
thinking_blocks = re.findall(r"<think>.*?</think>", text, re.DOTALL)
# If there's more than one thinking block, fail
if len(thinking_blocks) > 1:
return {"thinking_content": "", "response": text, "multiple_thinking": True}
# Match the single thinking block if it exists
pattern = r"<think>\s*(.*?)\s*</think>\s*(.*)"
match = re.search(pattern, text, re.DOTALL)
if not match:
return {"thinking_content": "", "response": text, "multiple_thinking": False}
return {
"thinking_content": match.group(1).strip(),
"response": match.group(2).strip(),
"multiple_thinking": False,
}
@registry.register
class FormatReasoningReward(RewardFunction):
"""
Reward function that checks for proper reasoning format.
Checks if completions have:
1. A thinking section in <think> tags
2. A response section after the thinking tags
3. No multiple thinking sections (only one <think> block)
"""
def __init__(
self,
reward_value: float = 0.5,
require_thinking: bool = True,
require_response: bool = True,
allow_multiple_thinking: bool = False,
weight: float = 1.0,
**kwargs,
):
"""
Initialize the format reasoning reward function.
Args:
reward_value: Value to award for correct formatting
require_thinking: Whether to require thinking content
require_response: Whether to require response content
allow_multiple_thinking: Whether to allow multiple thinking sections
weight: Weight for this reward
**kwargs: Additional configuration
"""
super().__init__(weight=weight, **kwargs)
self.reward_value = reward_value
self.require_thinking = require_thinking
self.require_response = require_response
self.allow_multiple_thinking = allow_multiple_thinking
def compute(self, completions: List[Any], **kwargs) -> List[float]:
"""
Check if completions have proper reasoning format.
Args:
completions: List of completions to evaluate
**kwargs: Additional context
Returns:
List of rewards (reward_value for correct format, 0.0 otherwise)
"""
parsed = []
for completion in completions:
try:
content = self.get_content(completion)
parsed.append(parse_reasoning_response(content))
except Exception as e:
logger.error(f"Error parsing response: {e}")
logger.exception(e)
parsed.append(
{"thinking_content": "", "response": "", "multiple_thinking": False}
)
rewards = []
for p in parsed:
try:
# Check if response meets format requirements
valid_format = True
if self.require_thinking and not p["thinking_content"]:
valid_format = False
if self.require_response and not p["response"]:
valid_format = False
if not self.allow_multiple_thinking and p["multiple_thinking"]:
valid_format = False
rewards.append(self.reward_value if valid_format else 0.0)
except Exception as e:
logger.error(f"Error in format reward calculation: {e}")
logger.exception(e)
rewards.append(0.0)
return rewards
@registry.register
class AccuracyXReward(RewardFunction):
"""
Reward function that checks if completion responses contain the solution.
First parses completions to extract the response part after thinking tags,
then checks if the solution string is contained in the response.
"""
def __init__(
self,
exact_match: bool = False,
case_sensitive: bool = False,
reward_value: float = 1.0,
weight: float = 1.0,
**kwargs,
):
"""
Initialize the accuracy reward function.
Args:
exact_match: Whether to require exact match vs. contained
case_sensitive: Whether to do case-sensitive matching
reward_value: Value to award for correct answer
weight: Weight for this reward
**kwargs: Additional configuration
"""
super().__init__(weight=weight, **kwargs)
self.exact_match = exact_match
self.case_sensitive = case_sensitive
self.reward_value = reward_value
def compute(
self,
completions: List[Any],
solution: Optional[Union[str, List[str]]] = None,
**kwargs,
) -> List[float]:
"""
Check if completion responses contain the solution.
Args:
completions: List of completions to evaluate
solution: The solution to check for
**kwargs: Additional context
Returns:
List of rewards (reward_value for correct, 0.0 otherwise)
"""
if solution is None:
logger.warning("No solution provided for accuracy reward")
return [0.0] * len(completions)
parsed_responses = []
for completion in completions:
try:
content = self.get_content(completion)
parsed_responses.append(parse_reasoning_response(content))
except Exception as e:
logger.error(f"Error parsing response: {e}")
logger.exception(e)
parsed_responses.append(
{"thinking_content": "", "response": "", "multiple_thinking": False}
)
rewards = []
# Ensure solution is in the right format
if not isinstance(solution, list):
solution = [solution] * len(parsed_responses)
for resp, sol in zip(parsed_responses, solution):
try:
# Extract solution content if needed
sol_content = self.get_content(sol) if not isinstance(sol, str) else sol
resp_content = resp["response"]
# Do the matching based on settings
if not self.case_sensitive:
sol_content = sol_content.lower()
resp_content = resp_content.lower()
if self.exact_match:
match = resp_content == sol_content
else:
match = sol_content in resp_content
rewards.append(self.reward_value if match else 0.0)
except Exception as e:
logger.error(f"Error in accuracy reward calculation: {e}")
logger.exception(e)
rewards.append(0.0)
return rewards
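Stripped of the parsing and error handling, the matching step is lowercase-normalized containment (or equality when `exact_match=True`). A worked example with the default settings:

```python
# Defaults: exact_match=False, case_sensitive=False, reward_value=1.0
resp_content = "After simplifying both sides, the result is 42 Apples."
sol_content = "42 apples"

exact_match = False
case_sensitive = False

if not case_sensitive:
    resp_content = resp_content.lower()
    sol_content = sol_content.lower()

# Containment by default; equality only when exact matching is requested
match = resp_content == sol_content if exact_match else sol_content in resp_content
reward = 1.0 if match else 0.0
```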
@registry.register
class R1Reward(RewardFunction):
"""
Combined reward function that rewards both reasoning format and accuracy.
This reward function combines:
1. FormatReasoningReward - rewards for proper <think> and response formatting
2. AccuracyXReward - rewards for having the solution in the response
"""
def __init__(
self,
format_weight: float = 0.5,
accuracy_weight: float = 1.0,
weight: float = 1.0,
**kwargs,
):
"""
Initialize the R1 reward function.
Args:
format_weight: Weight for the format component
accuracy_weight: Weight for the accuracy component
weight: Weight for the overall reward
**kwargs: Additional configuration
"""
super().__init__(weight=weight, **kwargs)
self.format_weight = format_weight
self.accuracy_weight = accuracy_weight
# Create component reward functions
        self.format_reward_fn = FormatReasoningReward(
            reward_value=1.0,  # component weights are applied in compute()
            weight=1.0,  # unused: compute() is called directly, bypassing __call__
        )
        self.accuracy_reward_fn = AccuracyXReward(
            reward_value=1.0,  # component weights are applied in compute()
            weight=1.0,  # unused: compute() is called directly, bypassing __call__
        )
def compute(
self,
completions: List[Any],
solution: Optional[Union[str, List[str]]] = None,
**kwargs,
) -> List[float]:
"""
Calculate combined format and accuracy rewards.
Args:
completions: List of completions to evaluate
solution: The solution to check for accuracy
**kwargs: Additional context
Returns:
List of combined rewards
"""
try:
# Calculate component rewards
format_rewards = self.format_reward_fn.compute(completions, **kwargs)
accuracy_rewards = self.accuracy_reward_fn.compute(
completions, solution=solution, **kwargs
)
# Apply component weights and combine
rewards = [
(f * self.format_weight) + (a * self.accuracy_weight)
for f, a in zip(format_rewards, accuracy_rewards)
]
logger.info(
f"R1 rewards: accuracy={accuracy_rewards}, format={format_rewards}, combined={rewards}"
)
return rewards
except Exception as e:
logger.error(f"Error in r1_reward: {e}")
logger.exception(e)
# Return zero rewards for all completions
return [0.0] * len(completions)
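With the default weights (`format_weight=0.5`, `accuracy_weight=1.0`), the combination is a per-completion weighted sum:

```python
format_weight, accuracy_weight = 0.5, 1.0

# One completion per column: formatted+correct, formatted only, correct only
format_rewards = [1.0, 1.0, 0.0]
accuracy_rewards = [1.0, 0.0, 1.0]

combined = [
    f * format_weight + a * accuracy_weight
    for f, a in zip(format_rewards, accuracy_rewards)
]
```

Note that a correct-but-unformatted answer (1.0) still outscores a formatted-but-wrong one (0.5) under these defaults.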
# Legacy function for backward compatibility
def format_reasoning_reward(completions: List[Any], **kwargs) -> List[float]:
"""
Legacy function wrapper for FormatReasoningReward.
Args:
completions: List of completions to evaluate
**kwargs: Additional parameters
Returns:
List of rewards for format quality
"""
reward_fn = FormatReasoningReward(reward_value=0.5)
return reward_fn.compute(completions, **kwargs)
def accuracy_reward(
completions: List[Any], solution: Union[str, List[str]], **kwargs
) -> List[float]:
"""
Legacy function wrapper for AccuracyXReward.
Args:
completions: List of completions to evaluate
solution: The solution to check for
**kwargs: Additional parameters
Returns:
List of rewards for accuracy
"""
reward_fn = AccuracyXReward()
return reward_fn.compute(completions, solution=solution, **kwargs)
def r1_reward(
completions: List[Any], solution: Union[str, List[str]], **kwargs
) -> List[float]:
"""
Legacy function wrapper for R1Reward.
Args:
completions: List of completions to evaluate
solution: The solution to check for
**kwargs: Additional parameters
Returns:
List of combined rewards
"""
reward_fn = R1Reward()
return reward_fn.compute(completions, solution=solution, **kwargs)


@ -0,0 +1,138 @@
"""Reward function for evaluating step-by-step reasoning in completions."""
import logging
import re
from typing import Any, Dict, List, Optional
from .registry import registry
from .reward_function import RewardFunction
logger = logging.getLogger(__name__)
@registry.register
class ReasoningStepsReward(RewardFunction):
r"""
Reward function that evaluates step-by-step reasoning in completions.
Looks for several types of step-by-step reasoning indicators:
1. Numbered step patterns like "Step 1:", "Step 2:"
2. Numbered lists like "1.", "2." at start of line
3. Bullet points with hyphens or asterisks
4. Sequential transition words (First, Second, Next, Finally, etc.)
"""
def __init__(
self,
min_words: int = 10,
min_steps: int = 3,
base_score: float = 0.1,
pattern_weights: Optional[Dict[str, float]] = None,
weight: float = 1.0,
**kwargs,
):
"""
Initialize the reasoning steps reward function.
Args:
min_words: Minimum number of words to consider for base score
min_steps: Number of steps needed for full points in each category
base_score: Base score for having content longer than min_words
pattern_weights: Custom weights for each pattern type (optional)
weight: Weight for this reward
**kwargs: Additional configuration
"""
super().__init__(weight=weight, **kwargs)
self.min_words = min_words
self.min_steps = min_steps
self.base_score = base_score
# Default pattern weights
self.pattern_weights = {
"numbered_steps": 0.5, # Strong indicators
"list_numbers": 0.5, # Strong indicators
"bullet_points": 0.4, # Medium indicators
"transition_words": 0.3, # Weaker indicators
}
# Override with custom weights if provided
if pattern_weights:
self.pattern_weights.update(pattern_weights)
# Patterns for different types of step indicators
self.patterns = {
# Step 1: style numbered steps
"numbered_steps": r"Step\s+\d+[\s:]+",
# Numbered lists (1., 2., etc.)
"list_numbers": r"(?:^|\n)\s*\d+\.\s+",
# Bullet points
"bullet_points": r"(?:^|\n)\s*[\-\*•]\s+",
# Sequential transition words - expanded to include more phrases
"transition_words": r"\b(?:First|Second|Third|Fourth|Fifth|Next|Then|Finally|"
            r"Subsequently|Afterward|Lastly|Initially|To begin|Let's begin|"
            r"I'll first|After that|In conclusion|Eventually|"
r"To solve|begin by|understand|analyze|apply|compute)\b",
}
def compute(self, completions: List[Any], **kwargs) -> List[float]:
"""
Calculate reasoning quality scores based on pattern matching.
Args:
completions: List of completions to evaluate
**kwargs: Additional context
Returns:
List of reward scores between 0.0 and 1.0
"""
# Extract content from different possible formats
completion_contents = [
self.get_content(completion) for completion in completions
]
rewards = []
for content in completion_contents:
score = 0.0
pattern_matches = {}
# Check for each type of pattern
for pattern_type, pattern in self.patterns.items():
matches = re.findall(pattern, content, re.IGNORECASE | re.MULTILINE)
pattern_matches[pattern_type] = len(matches)
# Add score based on matches and pattern weight
weight = self.pattern_weights.get(
pattern_type, 0.3
) # Default weight if not specified
score += min(1.0, len(matches) / self.min_steps) * weight
# Add a small base score for any content that has more than just an answer
# This helps differentiate minimal reasoning from no reasoning
if len(content.split()) > self.min_words:
score += self.base_score
# Cap the total score at 1.0
score = min(1.0, score)
rewards.append(score)
logger.info(
f"Reasoning steps reward for completion: {pattern_matches}, score: {score}"
)
return rewards
# Legacy function for backward compatibility
def reasoning_steps_reward(completions: List[Any], **kwargs) -> List[float]:
"""
Legacy function wrapper for ReasoningStepsReward.
Args:
completions: List of completions to evaluate
**kwargs: Additional parameters
Returns:
List of reward scores between 0.0 and 1.0
"""
reward_fn = ReasoningStepsReward()
return reward_fn.compute(completions, **kwargs)
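The scoring loop can be reproduced in a few lines. This sketch keeps just two of the four patterns with their default weights and `min_steps=3`, and omits the length-based base score:

```python
import re

patterns = {
    "numbered_steps": r"Step\s+\d+[\s:]+",
    "list_numbers": r"(?:^|\n)\s*\d+\.\s+",
}
weights = {"numbered_steps": 0.5, "list_numbers": 0.5}
min_steps = 3

content = "Step 1: restate the problem. Step 2: plan an approach. Step 3: solve."
score = 0.0
for kind, pattern in patterns.items():
    matches = re.findall(pattern, content, re.IGNORECASE | re.MULTILINE)
    # Each category saturates once min_steps matches are found
    score += min(1.0, len(matches) / min_steps) * weights[kind]
score = min(1.0, score)
# Three "Step N:" markers saturate that category; no numbered-list matches.
```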


@ -0,0 +1,279 @@
import importlib.util
import inspect
import logging
from pathlib import Path
from typing import Any, Callable, Dict, List, Set, Type, Union
from .reward_function import RewardFunction
logger = logging.getLogger(__name__)
class RewardRegistry:
"""Registry for reward functions with factory pattern"""
def __init__(self):
self._registry: Dict[str, Type[RewardFunction]] = {}
self._reward_fns_dir = Path(__file__).parent
def register(self, cls=None, name=None):
"""
Register a reward function class.
Can be used as a decorator:
@registry.register
class MyReward(RewardFunction):
...
or with a custom name:
@registry.register(name="custom_name")
class MyReward(RewardFunction):
...
Args:
cls: The reward function class to register
name: Optional custom name to register the class under
Returns:
The registered class (for decorator use)
"""
def _register(cls):
# Validate that it's a subclass of RewardFunction
if not inspect.isclass(cls) or not issubclass(cls, RewardFunction):
raise TypeError(
f"Class {cls.__name__} is not a subclass of RewardFunction"
)
registered_name = name or cls.__name__.lower()
            if registered_name.endswith("reward"):
                # Strip the trailing "reward", e.g. formatreasoningreward -> formatreasoning
                registered_name = registered_name[:-6].lower()
self._registry[registered_name] = cls
logger.debug(f"Registered reward function: {registered_name}")
return cls
if cls is None:
return _register
return _register(cls)
def register_function(self, name, function):
"""
Register a legacy function-based reward function.
This is temporary for backward compatibility.
Args:
name: Name to register the function under
function: The reward function to register
"""
self._registry[name] = function
def create(self, name_or_config: Union[str, Dict], **kwargs) -> RewardFunction:
"""
Create a reward function from name or config dict.
Args:
name_or_config: Either a string name of a registered reward function,
or a dict with 'type' key and optional parameters
**kwargs: Default parameters that can be overridden by config
Returns:
Instantiated RewardFunction object
"""
if isinstance(name_or_config, str):
# Simple case: just a name
reward_type = name_or_config
reward_params = kwargs
else:
# Dict case with config
reward_config = name_or_config.copy()
reward_type = reward_config.pop("type")
# Handle params dictionary if present
if "params" in reward_config:
params = reward_config.pop("params")
reward_config.update(params)
# Start with kwargs as defaults, override with config
reward_params = {**kwargs}
reward_params.update(reward_config)
# Make sure the reward function is loaded
if reward_type not in self._registry:
self._load_reward_function(reward_type)
reward_class = self._registry[reward_type]
# Handle legacy function-based reward functions
if not inspect.isclass(reward_class):
# This is a function not a class - handle legacy case
return LegacyFunctionWrapper(reward_class, **reward_params)
# Create instance of the reward function class
return reward_class(**reward_params)
def get(self, name: str) -> Union[Type[RewardFunction], Callable]:
"""
Get a reward function class by name.
This is for backward compatibility with the old registry interface.
New code should use create() instead.
Args:
name: The name of the reward function to get
Returns:
The reward function class or function
"""
if name not in self._registry:
self._load_reward_function(name)
return self._registry[name]
def _load_reward_function(self, name: str) -> None:
"""
Load a reward function from a file.
This supports both new class-based and legacy function-based reward functions.
Files can be named either "name.py" or "name_reward.py".
Args:
name: The name of the reward function to load
Raises:
ImportError: If the reward function file is not found or can't be loaded
"""
try:
# Try different file name patterns
base_name = name
if name.endswith("_reward"):
base_name = name[:-7] # Remove "_reward" suffix
module_paths = [
self._reward_fns_dir / f"{base_name}.py",
self._reward_fns_dir / f"{base_name}_reward.py",
self._reward_fns_dir / f"{name}.py",
]
module_path = None
for path in module_paths:
if path.exists():
module_path = path
break
if module_path is None:
raise ImportError(
f"No reward function file found for {name} (tried {', '.join(str(p) for p in module_paths)})"
)
# Generate a unique module name to avoid import conflicts
module_name = f"atroposlib.envs.reward_fns.{module_path.stem}"
spec = importlib.util.spec_from_file_location(module_name, str(module_path))
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# First try to find a class that inherits from RewardFunction
for obj_name, obj in inspect.getmembers(module):
if (
inspect.isclass(obj)
and issubclass(obj, RewardFunction)
and obj is not RewardFunction
):
# Register the class with the requested name
# This ensures it's accessible by the name the test expects
self.register_function(name, obj)
return
# If no class found, look for functions with matching name patterns
func_patterns = [f"{base_name}", f"{base_name}_reward", "format_reward"]
for func_name in func_patterns:
if hasattr(module, func_name):
self.register_function(name, getattr(module, func_name))
return
raise AttributeError(f"No reward function found in {module_path}")
except Exception as e:
raise ImportError(f"Failed to load reward function {name}: {str(e)}")
def list_registered(self) -> List[str]:
"""Return list of all registered reward function names"""
return list(self._registry.keys())
def load_required_functions(self, config) -> Set[str]:
"""
Load all reward functions required by a config.
This is for backward compatibility with the old registry interface.
Args:
config: The config object to load reward functions from
Returns:
A set of all loaded reward function names
"""
required_funcs = set()
if hasattr(config, "datasets"):
for dataset in config.datasets:
for field in ["dataset_reward_funcs", "reward_funcs"]:
if hasattr(dataset, field) and getattr(dataset, field):
required_funcs.update(getattr(dataset, field))
if hasattr(dataset, "types") and dataset.types:
for type_config in dataset.types:
if "reward_funcs" in type_config:
required_funcs.update(type_config["reward_funcs"])
# Also check for reward_functions and reward_funcs at the top level
for field in ["reward_functions", "reward_funcs"]:
if hasattr(config, field) and getattr(config, field):
field_value = getattr(config, field)
for item in field_value:
if isinstance(item, str):
required_funcs.add(item)
elif isinstance(item, dict) and "type" in item:
required_funcs.add(item["type"])
for func_name in required_funcs:
self.get(func_name)
return required_funcs
class LegacyFunctionWrapper(RewardFunction):
"""Wrapper for legacy function-based reward functions to fit the new class-based interface"""
def __init__(self, func: Callable, weight: float = 1.0, **kwargs):
"""
Initialize with a legacy reward function.
Args:
func: The legacy reward function to wrap
weight: The weight for this reward function
**kwargs: Additional configuration parameters
"""
super().__init__(weight=weight, **kwargs)
self.func = func
self._func_name = func.__name__
@property
def name(self) -> str:
"""Get the name of the wrapped function"""
return self._func_name
def compute(self, completions: List[Any], **kwargs) -> List[float]:
"""Call the wrapped function"""
result = self.func(completions, **kwargs)
# Convert to list if it's a single value
if not isinstance(result, list):
result = [result] * len(completions)
return result
# Global registry instance
registry = RewardRegistry()
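The decorator/factory round trip can be sketched without the lazy file loading. A stripped-down registry following the same naming convention (class names lose their `reward` suffix), with a hypothetical `LengthReward` as the registrant:

```python
from typing import Dict, Type


class RewardFunction:
    def __init__(self, weight: float = 1.0, **kwargs):
        self.weight = weight


class RewardRegistry:
    def __init__(self):
        self._registry: Dict[str, Type[RewardFunction]] = {}

    def register(self, cls=None, name=None):
        def _register(cls):
            key = name or cls.__name__.lower()
            if key.endswith("reward"):
                key = key[:-6]  # LengthReward registers as "length"
            self._registry[key] = cls
            return cls

        return _register if cls is None else _register(cls)

    def create(self, name: str, **kwargs) -> RewardFunction:
        return self._registry[name](**kwargs)


registry = RewardRegistry()


@registry.register
class LengthReward(RewardFunction):
    pass


fn = registry.create("length", weight=0.5)
```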


@ -0,0 +1,286 @@
"""Reward function for penalizing repetitive content in completions."""
import logging
import re
from collections import Counter
from typing import Any, Dict, List, Optional, Set
from .registry import registry
from .reward_function import RewardFunction
logger = logging.getLogger(__name__)
@registry.register
class RepetitionPenaltyReward(RewardFunction):
"""
Reward function that penalizes repetitive content in completions.
Analyzes various types of repetition:
1. Repeated sentences/paragraphs
2. Repeated words beyond natural frequency
3. Repeated phrases (n-grams)
4. Consecutive word repetition ("stuttering")
5. Repeated sentence beginnings
"""
def __init__(
self,
threshold: float = 0.05,
min_words: int = 10,
min_sentences: int = 2,
short_text_penalty: float = -0.1,
paragraph_repetition_base_penalty: float = -0.6,
component_weights: Optional[Dict[str, float]] = None,
stopwords: Optional[Set[str]] = None,
weight: float = 1.0,
**kwargs,
):
"""
Initialize repetition penalty reward function.
Args:
threshold: Maximum acceptable repetition rate
min_words: Minimum words required for full analysis
min_sentences: Minimum sentences required for full analysis
short_text_penalty: Penalty to apply for very short texts
paragraph_repetition_base_penalty: Base penalty for repeated paragraphs
component_weights: Custom weights for each repetition component
stopwords: Set of common words to ignore in word repetition check
weight: Weight for this reward
**kwargs: Additional configuration
"""
super().__init__(weight=weight, **kwargs)
self.threshold = threshold
self.min_words = min_words
self.min_sentences = min_sentences
self.short_text_penalty = short_text_penalty
self.paragraph_repetition_base_penalty = paragraph_repetition_base_penalty
# Default component weights
self.component_weights = {
"word_repetition": 0.3,
"phrase_repetition": 0.4,
"consecutive_repetition": 0.6,
"beginning_repetition": 0.5,
}
# Override with custom weights if provided
if component_weights:
self.component_weights.update(component_weights)
# Default stopwords (common words to ignore)
self.stopwords = stopwords or {
"this",
"that",
"with",
"from",
"what",
"when",
"where",
"which",
"who",
"whom",
"whose",
"will",
"shall",
"should",
"would",
"could",
"have",
"has",
"had",
"been",
"being",
"than",
"then",
"there",
"these",
"those",
"their",
}
def compute(self, completions: List[Any], **kwargs) -> List[float]:
"""
Calculate penalties for repetitive content.
Args:
completions: List of completions to evaluate
**kwargs: Additional context
Returns:
List of penalty scores between -1.0 and 0.0 (0.0 = no repetition)
"""
# Extract content from different possible formats
completion_contents = [
self.get_content(completion) for completion in completions
]
rewards = []
for content in completion_contents:
try:
# Split text into words and sentences
words = re.findall(r"\b\w+\b", content.lower())
sentences = re.split(r"[.!?]", content)
sentences = [s.strip() for s in sentences if s.strip()]
# For very short responses, apply a small penalty to be safe
if len(words) < self.min_words or len(sentences) < self.min_sentences:
logger.info("Text too short for detailed repetition analysis")
rewards.append(self.short_text_penalty)
continue
# Check for identical sentences (paragraph repetition)
sentence_counts = Counter(sentences)
repeated_sentences = sum(
count - 1
for sentence, count in sentence_counts.items()
if count > 1
and len(sentence.split()) > 3 # Only count substantial sentences
)
# Severe penalty for repeated paragraphs or sentences
if repeated_sentences > 0:
paragraph_repetition_score = min(
1.0, repeated_sentences / len(sentences)
)
# Apply a strong penalty for paragraph repetition
penalty = self.paragraph_repetition_base_penalty - (
paragraph_repetition_score * 0.4
)
logger.info(
f"Paragraph repetition detected: {repeated_sentences} instances"
)
rewards.append(penalty)
continue
# Check for word repetition
word_counts = Counter(words)
content_words = [
w
for w in word_counts.keys()
if len(w) > 3 and w not in self.stopwords
]
repeated_words = sum(
count - 1 # Count repetitions beyond the first occurrence
for word, count in word_counts.items()
if count > 2
and word in content_words # Only count meaningful words
)
word_repetition_rate = repeated_words / len(words) if words else 0
# Check for phrase repetition (n-grams)
phrase_repetition = 0
if len(words) >= 5:
for n in range(3, 6): # Check 3, 4, and 5-grams
if len(words) >= n:
ngrams = [
" ".join(words[i : i + n])
for i in range(len(words) - n + 1)
]
ngram_counts = Counter(ngrams)
repeated_ngrams = sum(
count - 1
for phrase, count in ngram_counts.items()
if count > 1
)
phrase_repetition += repeated_ngrams * (
n / 3
) # Weight by n-gram size
phrase_repetition_rate = phrase_repetition / len(words) if words else 0
# Check for consecutive word repetition (stuttering)
consecutive_repeats = 0
for i in range(1, len(words)):
if (
words[i] == words[i - 1] and len(words[i]) > 2
): # Ignore short words
consecutive_repeats += 1
consecutive_repetition_rate = (
consecutive_repeats / len(words) if words else 0
)
# Check for beginning of sentence repetition
                sentence_beginnings = []
                for sentence in sentences:
                    # Use a distinct name so the outer `words` list is not clobbered
                    sentence_words = sentence.split()
                    if len(sentence_words) >= 3:
                        sentence_beginnings.append(" ".join(sentence_words[:3]))
beginning_counts = Counter(sentence_beginnings)
repeated_beginnings = sum(
count - 1
for beginning, count in beginning_counts.items()
if count > 1
)
beginning_repetition_rate = (
repeated_beginnings / len(sentences) if sentences else 0
)
# Calculate overall repetition score with weights
repetition_score = (
(
word_repetition_rate
* self.component_weights.get("word_repetition", 0.3)
)
+ (
phrase_repetition_rate
* self.component_weights.get("phrase_repetition", 0.4)
)
+ (
consecutive_repetition_rate
* self.component_weights.get("consecutive_repetition", 0.6)
)
+ (
beginning_repetition_rate
* self.component_weights.get("beginning_repetition", 0.5)
)
)
# Calculate penalty: 0.0 for no repetition, up to -1.0 for high repetition
if repetition_score <= self.threshold:
penalty = 0.0
else:
# Scale penalty from 0 to -1 based on how much repetition exceeds threshold
penalty = -min(
1.0, (repetition_score - self.threshold) / (1 - self.threshold)
)
# Make sure we have at least some penalty for any repetition
penalty = min(penalty, -0.1)
                logger.info(
                    f"Word rep: {word_repetition_rate:.3f}, Phrase rep: {phrase_repetition_rate:.3f}, "
                    f"Consecutive rep: {consecutive_repetition_rate:.3f}, "
                    f"Sentence rep: {beginning_repetition_rate:.3f}, "
                    f"Overall score: {repetition_score:.3f}, penalty: {penalty:.3f}"
                )
rewards.append(penalty)
except Exception as e:
logger.error(f"Error in repetition penalty calculation: {e}")
logger.exception(e)
rewards.append(-0.2) # Apply small penalty on error
return rewards
# Legacy function for backward compatibility
def repetition_penalty_reward(
completions: List[Any], threshold: float = 0.05, **kwargs
) -> List[float]:
"""
Legacy function wrapper for RepetitionPenaltyReward.
Args:
completions: List of completions to evaluate
threshold: Maximum acceptable repetition rate (default 0.05)
Lower threshold means stricter penalties for repetition
**kwargs: Additional parameters
Returns:
List of rewards between -1.0 and 0.0, where 0.0 means no repetition
"""
reward_fn = RepetitionPenaltyReward(threshold=threshold)
return reward_fn.compute(completions, **kwargs)
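The phrase-repetition component counts n-gram recurrences beyond their first occurrence. The 3-gram case in isolation, on a toy sentence:

```python
from collections import Counter

words = "the cat sat on the mat and then the cat sat down".split()
n = 3
ngrams = [" ".join(words[i : i + n]) for i in range(len(words) - n + 1)]
counts = Counter(ngrams)
# Count each n-gram's occurrences beyond the first
repeated = sum(count - 1 for phrase, count in counts.items() if count > 1)
# "the cat sat" occurs twice, so there is one repetition beyond the first.
```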


@ -0,0 +1,125 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, List, Optional
logger = logging.getLogger(__name__)
class RewardFunction(ABC):
"""Abstract base class for all reward functions"""
def __init__(self, weight: float = 1.0, name: Optional[str] = None, **kwargs):
"""
Initialize reward function with a weight and optional configuration.
Args:
weight: Importance factor when combining with other rewards
name: Optional custom name for this reward function instance
**kwargs: Additional configuration parameters specific to the reward function
"""
self.weight = weight
self._name = name
self.config = kwargs
self.wandb_logger = None
@property
def name(self) -> str:
"""Unique identifier for this reward function"""
return self._name or self.__class__.__name__.lower()
@abstractmethod
def compute(self, completions: List[Any], **kwargs) -> List[float]:
"""
Compute reward scores for the given completions.
Args:
completions: List of completions to evaluate
**kwargs: Additional context like solution, ground_truth, etc.
Returns:
List of reward scores, one for each completion
"""
pass
def __call__(self, completions: List[Any], **kwargs) -> List[float]:
"""Wrapper that applies weight to the computed rewards"""
try:
rewards = self.compute(completions, **kwargs)
# Apply weight
weighted_rewards = [r * self.weight for r in rewards]
# Log to wandb if available
if self.wandb_logger:
self.log_metrics(rewards, weighted_rewards)
return weighted_rewards
except Exception as e:
logger.error(f"Error in reward function {self.name}: {e}")
logger.exception(e)
return [0.0] * len(completions)
def set_wandb_logger(self, logger):
"""Set the WandB logger for this reward function"""
self.wandb_logger = logger
def log_metrics(self, raw_rewards: List[float], weighted_rewards: List[float]):
"""Log reward metrics to WandB"""
if not self.wandb_logger or not raw_rewards:
return
metrics = {
f"reward/{self.name}/mean_raw": sum(raw_rewards) / len(raw_rewards),
f"reward/{self.name}/mean_weighted": sum(weighted_rewards)
/ len(weighted_rewards),
f"reward/{self.name}/min": min(raw_rewards),
f"reward/{self.name}/max": max(raw_rewards),
}
self.wandb_logger.log(metrics)
@staticmethod
def get_content(completion: Any) -> str:
"""
Extract content from different completion formats.
Supports:
- String completions
- Dict with {"role": "assistant", "content": "text"}
- Dict with {"message": {"role": "assistant", "content": "text"}}
- List of messages where one has role "assistant"
Args:
completion: The completion in any supported format
Returns:
The extracted content as a string
"""
if isinstance(completion, str):
return completion
elif isinstance(completion, dict):
if (
"role" in completion
and completion["role"] == "assistant"
and "content" in completion
):
return completion["content"]
if "message" in completion and isinstance(completion["message"], dict):
if (
"role" in completion["message"]
and completion["message"]["role"] == "assistant"
and "content" in completion["message"]
):
return completion["message"]["content"]
elif isinstance(completion, list) and len(completion) > 0:
# Look for assistant messages
for msg in completion:
if (
isinstance(msg, dict)
and "role" in msg
and msg["role"] == "assistant"
and "content" in msg
):
return msg["content"]
# If no assistant content found, return empty string
return ""


@ -0,0 +1,296 @@
import asyncio
import collections
import time
from asyncio import exceptions
from typing import Optional
import aiohttp
import numpy as np
import openai
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from pydantic import BaseModel, Field
from tenacity import retry, stop_after_attempt, wait_random_exponential
class OpenaiConfig(BaseModel):
    """
    Configuration for a single OpenAI-compatible API server.
    """
api_key: Optional[str] = Field(
default=None, description="API key for OpenAI API. Use 'x' for local servers."
)
base_url: Optional[str] = Field(
default=None,
description="URL of the API endpoint. None if using official OpenAI API, otherwise local server URL.",
)
timeout: int = Field(
default=1200, description="Timeout for the request in seconds."
)
num_max_requests_at_once: int = Field(
default=512,
description="Maximum number of concurrent requests. Note: You should divide this by the n kwarg.",
)
num_requests_for_eval: int = Field(
default=64, description="Maximum number of concurrent requests for evaluation."
)
model_name: str = Field(
default="default",
description="The model name to use. Required for both OpenAI and local models.",
)
rolling_buffer_length: int = Field(
default=1000, description="Length of the rolling buffer to store metrics."
)
class AsyncSemWithAdaptiveWeight(asyncio.Semaphore):
def __init__(self, value: int):
super().__init__(value=value)
self.max_val = value
self.weight = 1.0
def update_weight(self, weight: float) -> None:
self.weight = weight
def min_val(self):
return self.max_val * (1.0 - self.weight)
def release(self):
"""Release a semaphore, incrementing the internal counter by one.
When it was zero on entry and another coroutine is waiting for it to
become larger than zero again, wake up that coroutine.
If weight is set, it'll only wake up next if the value is greater than the max_val * weight
"""
self._value += 1
if self._value > self.min_val():
self._wake_up_next()
def locked(self):
"""Returns True if semaphore cannot be acquired immediately."""
return self._value <= self.min_val() or (
any(not w.cancelled() for w in (self._waiters or ()))
)
async def acquire(self):
"""Acquire a semaphore.
If the internal counter is larger than zero on entry,
decrement it by one and return True immediately. If it is
zero on entry, block, waiting until some other coroutine has
called release() to make it larger than 0, and then return
True.
"""
if not self.locked():
self._value -= 1
return True
if self._waiters is None:
self._waiters = collections.deque()
fut = self._get_loop().create_future()
self._waiters.append(fut)
# Finally block should be called before the CancelledError
# handling as we don't want CancelledError to call
# _wake_up_first() and attempt to wake up itself.
try:
try:
await fut
finally:
self._waiters.remove(fut)
except exceptions.CancelledError:
if not fut.cancelled():
self._value += 1
self._wake_up_next()
raise
if self._value > self.min_val():
self._wake_up_next()
return True
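The effect of `update_weight` is easiest to see numerically: the floor below which `acquire` blocks rises as the weight drops, shrinking usable concurrency. A small arithmetic sketch using the default 512 permits:

```python
max_val = 512


def usable_slots(weight: float) -> float:
    # Same floor as AsyncSemWithAdaptiveWeight.min_val(): acquire() blocks
    # once the counter falls to max_val * (1 - weight).
    min_val = max_val * (1.0 - weight)
    return max_val - min_val


full = usable_slots(1.0)        # full concurrency
throttled = usable_slots(0.25)  # three quarters of the permits held back
```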
class OpenAIServer:
def __init__(self, config: OpenaiConfig):
self.config = config
self.openai = openai.AsyncClient(
api_key=config.api_key,
base_url=config.base_url,
timeout=config.timeout,
)
self.sem = AsyncSemWithAdaptiveWeight(config.num_max_requests_at_once)
self.eval_sem = AsyncSemWithAdaptiveWeight(config.num_requests_for_eval)
self.server_healthy = True
self.attempts_list = []
self.request_timings = []
# in case eval is much different, we should keep different buffers
self.eval_attempts_list = []
self.eval_request_timings = []
self.check_task = None
self.initialized = False
async def update_weight(self, weight: float) -> None:
# need to update sems
self.sem.update_weight(weight)
self.eval_sem.update_weight(weight)
async def check_server_status_task(self):
while True:
try:
await self.openai.completions.create(
model=self.config.model_name,
prompt="hi",
max_tokens=1,
)
self.server_healthy = True
            except Exception:
                # aiohttp client errors, OpenAI API errors, and timeouts all
                # mark the server unhealthy until the next successful probe.
                self.server_healthy = False
await asyncio.sleep(1)
async def wandb_metrics(
self, metrics_dict: Optional[dict], server_name: Optional[str]
):
if server_name is None:
server_name = "server"
if len(self.request_timings) > 0:
metrics_dict[f"server/{server_name}_request_time_avg"] = np.mean(
self.request_timings
)
metrics_dict[f"server/{server_name}_request_time_std"] = np.std(
self.request_timings
)
metrics_dict[f"server/{server_name}_request_time_99p"] = np.percentile(
self.request_timings, 99
)
if len(self.eval_request_timings) > 0:
metrics_dict[f"server/{server_name}_eval_request_time_avg"] = np.mean(
self.eval_request_timings
)
metrics_dict[f"server/{server_name}_eval_request_time_std"] = np.std(
self.eval_request_timings
)
metrics_dict[f"server/{server_name}_eval_request_time_99p"] = np.percentile(
self.eval_request_timings, 99
)
if len(self.attempts_list) > 0:
metrics_dict[f"server/{server_name}_average_num_attempts"] = np.mean(
self.attempts_list
)
if len(self.eval_attempts_list) > 0:
metrics_dict[f"server/{server_name}_eval_retry_rate"] = np.mean(
self.eval_attempts_list
)
return metrics_dict
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def _chat_comp(self, stat_dict, **kwargs) -> ChatCompletion:
while not self.server_healthy:
await asyncio.sleep(1)
async with self.sem:
if stat_dict.get("start", None) is None:
stat_dict["start"] = time.time()
stat_dict["attempts"] += 1
completions = await self.openai.chat.completions.create(**kwargs)
stat_dict["end"] = time.time()
return completions
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def _chat_eval(self, stat_dict, **kwargs) -> ChatCompletion:
while not self.server_healthy:
await asyncio.sleep(1)
async with self.eval_sem:
if stat_dict.get("start", None) is None:
stat_dict["start"] = time.time()
stat_dict["attempts"] += 1
completions = await self.openai.chat.completions.create(**kwargs)
stat_dict["end"] = time.time()
return completions
    async def chat_completion(self, **kwargs) -> ChatCompletion:
if not self.initialized:
if (
self.config.base_url is not None
): # skip health check if using OpenAI API
self.check_task = asyncio.create_task(self.check_server_status_task())
else:
self.server_healthy = True
self.initialized = True
kwargs["model"] = self.config.model_name
split = kwargs.pop("split", "train")
stat_dict = {}
stat_dict["attempts"] = 0
if split == "train":
ret_data = await self._chat_comp(stat_dict, **kwargs)
self.request_timings.append(stat_dict["end"] - stat_dict["start"])
self.attempts_list.append(stat_dict["attempts"])
else:
            # Route eval traffic through its own semaphore so evals aren't starved by training requests
ret_data = await self._chat_eval(stat_dict, **kwargs)
self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
self.eval_attempts_list.append(stat_dict["attempts"])
return ret_data
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def _comp(self, stat_dict, **kwargs) -> Completion:
while not self.server_healthy:
await asyncio.sleep(1)
async with self.sem:
if stat_dict.get("start", None) is None:
stat_dict["start"] = time.time()
stat_dict["attempts"] += 1
completions = await self.openai.completions.create(**kwargs)
stat_dict["end"] = time.time()
return completions
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
)
async def _comp_eval(self, stat_dict, **kwargs) -> Completion:
while not self.server_healthy:
await asyncio.sleep(1)
async with self.eval_sem:
if stat_dict.get("start", None) is None:
stat_dict["start"] = time.time()
stat_dict["attempts"] += 1
completions = await self.openai.completions.create(**kwargs)
stat_dict["end"] = time.time()
return completions
async def completion(self, **kwargs) -> Completion:
if not self.initialized:
if (
self.config.base_url is not None
): # skip health check if using OpenAI API
self.check_task = asyncio.create_task(self.check_server_status_task())
else:
self.server_healthy = True
self.initialized = True
kwargs["model"] = self.config.model_name
split = kwargs.pop("split", "train")
stat_dict = {}
stat_dict["attempts"] = 0
if split == "train":
ret_data = await self._comp(stat_dict, **kwargs)
self.request_timings.append(stat_dict["end"] - stat_dict["start"])
self.attempts_list.append(stat_dict["attempts"])
else:
            # Route eval traffic through its own semaphore so evals aren't starved by training requests
ret_data = await self._comp_eval(stat_dict, **kwargs)
self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"])
self.eval_attempts_list.append(stat_dict["attempts"])
return ret_data


@ -0,0 +1,146 @@
import asyncio
from typing import Dict, List, Literal, Union
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai.types.completion import Completion, CompletionChoice
def create_chat_completion(
resp: Union[str, List[str]],
n: int = 1,
finish_reason: Union[
Literal["stop", "length", "tool_calls", "content_filter", "function_call"],
List[
Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
],
] = "stop",
) -> ChatCompletion:
"""
Simple helper for creating a ChatCompletion object, if you need it
:param resp:
:param n:
:param finish_reason:
:return:
"""
choices = [
Choice(
finish_reason=(
finish_reason if isinstance(finish_reason, str) else finish_reason[i]
),
index=i,
message=ChatCompletionMessage(
content=resp if isinstance(resp, str) else resp[i],
role="assistant",
),
)
for i in range(n)
]
return ChatCompletion(
id="test_id",
created=0,
model="test_model",
object="chat.completion",
choices=choices,
)
def create_completion(
resp: Union[str, List[str]],
n: int = 1,
finish_reason: Union[
Literal["stop", "length", "tool_calls", "content_filter", "function_call"],
List[
Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
],
] = "stop",
) -> Completion:
"""
Simple helper for creating a Completion object, if you need it
:param resp:
:param n:
:param finish_reason:
:return:
"""
choices = [
CompletionChoice(
finish_reason=(
finish_reason if isinstance(finish_reason, str) else finish_reason[i]
),
index=i,
text=resp if isinstance(resp, str) else resp[i],
)
for i in range(n)
]
return Completion(
id="test_id",
created=0,
model="test_model",
object="text_completion",
choices=choices,
)
class ServerHarness:
    def __init__(self):
        self.response_map = dict()
        self.sem = asyncio.Semaphore(1)
        self.eval_sem = asyncio.Semaphore(1)
def conv_to_dictkey(self, input_message: List[Dict[str, str]]) -> str:
dictkey = list()
for item in input_message:
dictkey.append(f"role:{item['role']}")
dictkey.append(f"content:{item['content']}")
return "\n".join(dictkey)
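The harness keys its response map on a flattened role/content rendering of the message list. A standalone copy of that key derivation, for illustration of the exact string produced:

```python
def conv_to_dictkey(input_message):
    """Flatten a chat message list into the newline-joined lookup key."""
    parts = []
    for item in input_message:
        parts.append(f"role:{item['role']}")
        parts.append(f"content:{item['content']}")
    return "\n".join(parts)


key = conv_to_dictkey([{"role": "user", "content": "hi"}])
```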
async def update_weight(self, weight):
pass
def set_desired_response(
self, input_message: List[Dict[str, str]], desired_response: ChatCompletion
):
dictkey = self.conv_to_dictkey(input_message)
self.response_map[dictkey] = desired_response
def set_desired_completion(self, input_message: str, completion: Completion):
self.response_map[input_message] = completion
    async def chat_completion(self, *args, **kwargs) -> ChatCompletion:
        messages = kwargs.get("messages")
        dictkey = self.conv_to_dictkey(messages)
        try:
            # Index (not .get) so a missing key actually raises KeyError.
            return self.response_map[dictkey]
        except KeyError as e:
            raise KeyError(f"KeyError: {e} for key:\n{dictkey}")

    async def completion(self, *args, **kwargs) -> Completion:
        prompt = kwargs.get("prompt")
        try:
            return self.response_map[prompt]
        except KeyError as e:
            raise KeyError(f"KeyError: {e} for key:\n{prompt}")
if __name__ == "__main__":
async def main():
test_compl = create_chat_completion("hello")
harness = ServerHarness()
harness.set_desired_response([{"role": "user", "content": "hi"}], test_compl)
print(harness.response_map)
print(harness.conv_to_dictkey([{"role": "user", "content": "hi"}]))
print(
await harness.chat_completion(messages=[{"role": "user", "content": "hi"}])
)
# now, let's test the completion
test_completion = create_completion("\nhello")
harness.set_desired_completion("hi", test_completion)
print(harness.response_map)
print(await harness.completion(prompt="hi"))
asyncio.run(main())


@ -0,0 +1,208 @@
import asyncio
import os
from contextlib import asynccontextmanager
from typing import AsyncGenerator, List, Union
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion
from pydantic import BaseModel, Field
from atroposlib.envs.server_handling.openai_server import OpenaiConfig, OpenAIServer
from atroposlib.envs.server_handling.server_harness import ServerHarness
class ServerManagerConfig(BaseModel):
slurm: bool = Field(
default=True, description="Whether environment is running on slurm or not."
)
testing: bool = Field(
default=False, description="If set to True, environment uses mock OpenAI data."
)
class ServerBaseline(BaseModel):
"""
Baseline configuration for server information. If local, uses ports 9004-9007 for the servers,
assuming a 1:1 split of GPUs.
"""
timeout: int = Field(
default=1200, description="Timeout for the request in seconds."
)
num_max_requests_at_once: int = Field(
default=512,
description="Maximum number of concurrent requests. You should divide this by the n kwarg.",
)
num_requests_for_eval: int = Field(
default=64, description="Maximum number of concurrent requests for evaluation."
)
    model_name: str = Field(
        default="default",
        description="The model name to use. The 'default' value only works with sglang; "
        "otherwise, provide the served model's name.",
    )
rolling_buffer_length: int = Field(
default=1000, description="Length of the rolling buffer to store metrics."
)
class ServerManager:
def __init__(
self,
configs: Union[ServerBaseline, List[OpenaiConfig]],
slurm=False,
testing=False,
):
if testing:
# testing :)
self.servers = [ServerHarness()]
return
if isinstance(configs, ServerBaseline):
urls = []
if os.environ.get("SLURM_JOB_NODELIST", None) is not None:
nodelist = (
os.popen(
f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}'
)
.read()
.split("\n")
)
nodelist = [node for node in nodelist if node != ""]
if len(nodelist) < 2:
# localhost!
for i in range(4):
urls.append(f"http://localhost:{9000 + i + 4}/v1")
else:
                    num_training_nodes = int(os.environ.get("NUM_TRAINING_NODES"))
                    for node in nodelist[num_training_nodes:]:
                        for i in range(8 // int(os.environ.get("INFER_TP", 1))):
urls.append(f"http://{node}:{9000 + i}/v1")
openai_configs = []
else:
# localhost!
for i in range(4):
urls.append(f"http://localhost:{9000 + i + 4}/v1")
openai_configs = []
for url in urls:
openai_configs.append(
OpenaiConfig(
base_url=url,
timeout=configs.timeout,
num_max_requests_at_once=configs.num_max_requests_at_once,
num_requests_for_eval=configs.num_requests_for_eval,
model_name=configs.model_name,
rolling_buffer_length=configs.rolling_buffer_length,
api_key="x",
)
)
            self.servers = [OpenAIServer(config) for config in openai_configs]
            # The ServerBaseline path is fully handled; don't fall through to
            # the list-of-configs handling below.
            return
if not slurm:
self.servers = [OpenAIServer(config) for config in configs]
else:
nodelist = (
os.popen(f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}')
.read()
.split("\n")
)
nodelist = [node for node in nodelist if node != ""]
if len(nodelist) < 2:
print(
"Not enough nodes to distribute to, assuming single node"
" and you've setup your sglang appropriately."
)
self.servers = [OpenAIServer(config) for config in configs]
return
urls = []
num_training_nodes = int(os.environ.get("NUM_TRAINING_NODES"))
for node in nodelist[num_training_nodes:]:
if node == "":
continue
                for i in range(8 // int(os.environ.get("INFER_TP", 1))):
urls.append(f"http://{node}:{9000 + i}/v1")
# assume at least one good config is passed in
new_configs = []
for i in range(len(urls)):
new_conf = configs[0].model_copy(deep=True)
new_conf.base_url = urls[i]
new_configs.append(new_conf)
self.servers = [OpenAIServer(config) for config in new_configs]
async def update_weight(self, weight: float):
for server in self.servers:
await server.update_weight(weight)
    async def wait_for_sem(self, is_training):
        """
        Wait until at least one server has spare capacity. This is used to
        prevent the client from overwhelming the servers with requests.
        """
        while True:
            if is_training:
                eval_vals = [
                    (
                        max(0, server.eval_sem._value - server.eval_sem.min_val())
                        if server.eval_sem._value != server.eval_sem.max_val
                        else 0
                    )
                    for server in self.servers
                ]
                sem_vals = [
                    max(0, (server.sem._value - server.sem.min_val()) - eval_val)
                    for server, eval_val in zip(self.servers, eval_vals)
                ]
            else:
                sem_vals = [
                    max(0, server.eval_sem._value - server.eval_sem.min_val())
                    for server in self.servers
                ]
            if any(sem_val > 0 for sem_val in sem_vals):
                return
            # No capacity anywhere... recheck after a short sleep.
            await asyncio.sleep(1)
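The capacity arithmetic in `wait_for_sem` can be isolated as a pure function. A sketch under the assumption that `min_val()`/`max_val` are plain bounds on the semaphore counter (`train_capacity` is a hypothetical helper, not part of this class):

```python
def train_capacity(sem_value, sem_min, eval_value, eval_min, eval_max):
    """Spare train slots on one server, reserving room for in-flight eval work.

    Eval requests that are currently possible are subtracted from the train
    semaphore's headroom; a fully idle eval semaphore reserves nothing.
    """
    eval_spare = 0 if eval_value == eval_max else max(0, eval_value - eval_min)
    return max(0, (sem_value - sem_min) - eval_spare)


# a server with 10 free train slots, 2 of which eval work could claim
slots = train_capacity(sem_value=10, sem_min=0, eval_value=2, eval_min=0, eval_max=4)
```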
async def chat_completion(self, **kwargs) -> ChatCompletion:
is_train = kwargs.get("split", "train") == "train"
most_available_server = 0
most_available_server_num_slots = -1
await self.wait_for_sem(is_train)
for i, server in enumerate(self.servers):
if not server.server_healthy:
continue
if (
server.sem._value if is_train else server.eval_sem._value
) > most_available_server_num_slots:
most_available_server = i
most_available_server_num_slots = (
server.sem._value if is_train else server.eval_sem._value
)
return await self.servers[most_available_server].chat_completion(**kwargs)
async def completion(self, **kwargs) -> Completion:
is_train = kwargs.get("split", "train") == "train"
most_available_server = 0
most_available_server_num_slots = -1
await self.wait_for_sem(is_train)
for i, server in enumerate(self.servers):
if not server.server_healthy:
continue
if (
server.sem._value if is_train else server.eval_sem._value
) > most_available_server_num_slots:
most_available_server = i
most_available_server_num_slots = (
server.sem._value if is_train else server.eval_sem._value
)
return await self.servers[most_available_server].completion(**kwargs)
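Both `chat_completion` and `completion` pick the healthy server with the most free slots; the selection reduces to an argmax over (healthy, slots). A standalone sketch of that loop (`pick_least_loaded` is a hypothetical helper):

```python
def pick_least_loaded(slots: list[int], healthy: list[bool]) -> int:
    """Index of the healthy server with the most free semaphore slots.

    Mirrors the selection loop above; falls back to index 0 when no server
    is healthy, as the original does.
    """
    best, best_slots = 0, -1
    for i, (n, ok) in enumerate(zip(slots, healthy)):
        if ok and n > best_slots:
            best, best_slots = i, n
    return best


choice = pick_least_loaded([3, 7, 5], [True, False, True])
```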
@asynccontextmanager
async def dedicated_server(self) -> AsyncGenerator[OpenAIServer, None]:
most_available_server = 0
most_available_server_num_slots = -1
for i, server in enumerate(self.servers):
if not server.server_healthy:
continue
if server.sem._value > most_available_server_num_slots:
most_available_server = i
most_available_server_num_slots = server.sem._value
async with self.servers[most_available_server].sem:
try:
yield self.servers[most_available_server]
finally:
pass


@ -0,0 +1,169 @@
import math
import pytest
import torch
# Adjust the import below if your functions are in a different module.
from atroposlib.utils.advantages import (
allclose_to_first,
compute_discounted_returns,
compute_grpo_process_supervision_advantages,
compute_stats,
)
def test_allclose_to_first_all_close():
"""Test that identical values return True."""
values = [1.0, 1.0, 1.0]
result = allclose_to_first(values)
assert result is True
def test_allclose_to_first_vector():
"""Test that return_vector=True returns a tensor of booleans."""
values = [1.0, 1.000000001, 1.000000002]
result = allclose_to_first(values, return_vector=True)
assert isinstance(result, torch.Tensor)
# All comparisons should be True.
assert torch.all(result)
def test_allclose_to_first_not_close():
"""Test that values which are not close yield False."""
values = [1.0, 1.0, 1.1]
result = allclose_to_first(values)
assert result is False
def test_allclose_to_first_nan():
"""Test handling of NaN values with equal_nan parameter."""
values = [float("nan"), float("nan")]
# With equal_nan False, the result should be False.
result = allclose_to_first(values, equal_nan=False)
assert result is False
# With equal_nan True, NaNs are treated as equal.
result = allclose_to_first(values, equal_nan=True)
assert result is True
def test_compute_stats():
"""Test compute_stats with a nested list of numbers."""
data = [1, 2, 3, [4, 5]]
stats = compute_stats(data)
# mean = (1+2+3+4+5)/5 = 3.0
assert math.isclose(stats["mean"], 3.0, rel_tol=1e-5)
# variance = (11 - 9) = 2.0, since average of squares = 55/5 = 11 and mean^2 = 9.
assert math.isclose(stats["var"], 2.0, rel_tol=1e-5)
def test_compute_stats_empty():
"""Test that an empty list raises a ValueError."""
with pytest.raises(ValueError):
compute_stats([])
def test_compute_stats_jagged():
"""Test compute_stats with a deeper, jagged nested list."""
data = [[1, 2], 3, [4, [5, 6]]]
stats = compute_stats(data)
expected_mean = (1 + 2 + 3 + 4 + 5 + 6) / 6 # 21/6 = 3.5
expected_var = ((1**2 + 2**2 + 3**2 + 4**2 + 5**2 + 6**2) / 6) - expected_mean**2
assert math.isclose(stats["mean"], expected_mean, rel_tol=1e-5)
assert math.isclose(stats["var"], expected_var, rel_tol=1e-5)
def test_compute_discounted_returns():
"""Test compute_discounted_returns with a tensor input."""
rewards = torch.tensor([1.0, 1.0, 1.0])
gamma = 0.9
returns = compute_discounted_returns(rewards, gamma)
# For a 3-element vector:
# t=2: 1.0
# t=1: 1.0 + 0.9*1.0 = 1.9
# t=0: 1.0 + 0.9*1.9 = 2.71
expected = torch.tensor([2.71, 1.9, 1.0])
assert torch.allclose(returns, expected, rtol=1e-5, atol=1e-8)
def test_compute_discounted_returns_list_input():
"""Test compute_discounted_returns when the input is a list."""
rewards = [1, 1, 1]
gamma = 0.0 # With gamma=0, the returns should equal the rewards.
returns = compute_discounted_returns(rewards, gamma)
expected = torch.tensor([1.0, 1.0, 1.0])
assert torch.allclose(returns, expected, rtol=1e-5, atol=1e-8)
def test_compute_grpo_process_supervision_advantages_cumsum():
"""
Test compute_grpo_process_supervision_advantages with gamma=None,
which should now compute a reversed cumulative sum on normalized rewards.
For each trajectory, the expected advantage at index i is the sum of normalized rewards from i to the end.
"""
rewards = [[1, 2, 3], [4, 5]]
advantages = compute_grpo_process_supervision_advantages(rewards, gamma=None)
# Compute normalized rewards using flattened stats (mean=3, var=2 so std=sqrt(2))
sqrt2 = math.sqrt(2)
# For trajectory 1, normalized rewards:
# Reversed cumulative sum:
# index 0: sum(traj1) = (-2/sqrt2) + (-1/sqrt2) + 0 = -3/sqrt2
# index 1: sum(traj1[1:]) = (-1/sqrt2) + 0 = -1/sqrt2
# index 2: sum(traj1[2:]) = 0
expected_traj1 = [-3 / sqrt2, -1 / sqrt2, 0]
# For trajectory 2, normalized rewards:
# Reversed cumulative sum:
# index 0: (1/sqrt2) + (2/sqrt2) = 3/sqrt2
# index 1: (2/sqrt2)
expected_traj2 = [3 / sqrt2, 2 / sqrt2]
adv1 = advantages[0].tolist()
adv2 = advantages[1].tolist()
for computed, expected in zip(adv1, expected_traj1):
assert math.isclose(
computed, expected, rel_tol=1e-5
), f"Computed {computed} vs expected {expected} in trajectory 1"
for computed, expected in zip(adv2, expected_traj2):
assert math.isclose(
computed, expected, rel_tol=1e-5
), f"Computed {computed} vs expected {expected} in trajectory 2"
def test_compute_grpo_process_supervision_advantages_discounted():
"""
Test compute_grpo_process_supervision_advantages with a provided gamma,
which should compute discounted returns on normalized rewards.
"""
rewards = [[1, 2, 3], [4, 5]]
gamma = 0.9
advantages = compute_grpo_process_supervision_advantages(rewards, gamma=gamma)
sqrt2 = math.sqrt(2)
# Normalized first trajectory:
a1 = (1 - 3) / sqrt2 # -2/sqrt2
a2 = (2 - 3) / sqrt2 # -1/sqrt2
a3 = (3 - 3) / sqrt2 # 0
# Discounted returns for trajectory 1:
# t=2: a3
# t=1: a2 + gamma * a3 = a2
# t=0: a1 + gamma * (a2 + gamma * a3) = a1 + gamma * a2
expected_traj1 = [a1 + gamma * a2, a2, a3]
# Normalized second trajectory:
b1 = (4 - 3) / sqrt2 # 1/sqrt2
b2 = (5 - 3) / sqrt2 # 2/sqrt2
# Discounted returns for trajectory 2:
# t=1: b2
# t=0: b1 + gamma * b2
expected_traj2 = [b1 + gamma * b2, b2]
adv1 = advantages[0].tolist()
adv2 = advantages[1].tolist()
for computed, expected in zip(adv1, expected_traj1):
assert math.isclose(computed, expected, rel_tol=1e-5)
for computed, expected in zip(adv2, expected_traj2):
assert math.isclose(computed, expected, rel_tol=1e-5)
def test_compute_grpo_process_supervision_advantages_std_tol():
"""Test that a constant reward trajectory raises ValueError due to low std."""
rewards = [[1, 1, 1]]
with pytest.raises(ValueError):
compute_grpo_process_supervision_advantages(rewards)


@ -0,0 +1,28 @@
import random
from atroposlib.api.utils import grab_exact_from_heterogeneous_queue
def test_grab_exact_from_heterogeneous_queue():
"randomly samples from the space of potential inputs to grab_exact_from_heterogeneous_queue"
for random_bs in range(10000):
bs = 64 * random.randint(1, 20)
queue = []
for i in range(random.randint(1, 100)):
# queue.append(
# {
# "tokens": [[2 * i] for _ in range(2)],
# }
# )
queue.append(
{
"tokens": [[2 * i + 1] for _ in range(8)],
}
)
batch, queue = grab_exact_from_heterogeneous_queue(queue, bs)
if random_bs == 0:
print(batch)
if batch is not None:
assert (
sum(len(item["tokens"]) for item in batch) == bs
), f"expected batch size {bs}, got {len(batch)}"


@ -0,0 +1,70 @@
from typing import Any, Dict, List, Literal, Optional, TypedDict
from openai.types.chat import ChatCompletionContentPartParam
Content = str | list[ChatCompletionContentPartParam]
Item = Any
number = int | float
UUID = str
class Message(TypedDict):
role: Literal["system", "user", "assistant", "tool"]
content: Content
reward: Optional[float]
class AgentStep(TypedDict, total=False):
"""Represents a single step in an agent's history.
Attributes:
step: The step number.
messages: A list of messages exchanged during the step.
reward: The reward received at this step.
"""
step: int
messages: List[Message]
reward: float
# AgentHistory maps agent ids (e.g. "Player 1", "Player 2") to their respective list of steps.
AgentHistory = Dict[str, List[AgentStep]]
class Observation(TypedDict):
"""Represents an observation in a game history.
Attributes:
raw: The raw observation data (as a dictionary).
rendered: The rendered string of the observation suitable for input into an LLM.
"""
raw: Dict[str, Any]
rendered: Content
class GameStep(TypedDict):
"""Represents a single step in a game history. Essentially an (s,a,r) triple with metadata.
Attributes:
step: The step number.
        agent_id: The id of the agent who took the action (optional for final steps).
observation: The observation at this step.
action: The action taken by the agent (if any).
reward: The reward received; can be a float or a dictionary mapping agent names to rewards.
done: A flag indicating whether the game has ended after this step.
info: Additional information related to the step.
"""
step: int
agent_id: str
observation: Observation
action: str
reward: float | Dict[str, float]
done: bool
info: Dict[str, Any]
# GameHistory is represented as a list of game steps.
GameHistory = List[GameStep]
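A `GameStep` is a plain dict conforming to the `TypedDict` above. A minimal illustrative instance (the field values here are made up for a hypothetical two-player game):

```python
step = {
    "step": 0,
    "agent_id": "Player 1",
    "observation": {
        "raw": {"board": [0] * 9},
        "rendered": "empty 3x3 board",
    },
    "action": "place X at center",
    "reward": {"Player 1": 0.0, "Player 2": 0.0},
    "done": False,
    "info": {},
}

# GameHistory is just a list of such steps.
game_history = [step]
```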


@ -0,0 +1,7 @@
"""
Utility functions and classes for the atroposlib package.
"""
from .config_handler import ConfigHandler
__all__ = ["ConfigHandler"]


@ -0,0 +1,173 @@
from typing import Sequence
import torch
from atroposlib.type_definitions import number
TensorLike = torch.Tensor | Sequence[torch.Tensor] | Sequence[Sequence]
# Type alias for vector of bools
BoolVector = torch.Tensor
def allclose_to_first(
values: TensorLike,
rtol: float = 1e-05,
atol: float = 1e-08,
equal_nan: bool = False,
return_vector: bool = False,
) -> BoolVector | bool:
"""
Check if all tensors in `values` are close to the first tensor `values[0]` using a vectorized approach.
If `return_vector` is False (default), returns a single boolean indicating whether
    every tensor is close to the first tensor. If `return_vector` is True, returns a 1D
    boolean tensor where each element indicates whether the respective tensor in
    `values` is close to the first tensor. The first element is always True.
Args:
values (torch.Tensor | Sequence[torch.Tensor] | Sequence[Sequence]):
Nested list of values to compare. Must be rectangular, but not necessarily 2D.
rtol (float, optional): Relative tolerance. Defaults to 1e-05.
atol (float, optional): Absolute tolerance. Defaults to 1e-08.
equal_nan (bool, optional): Whether to consider NaNs as equal. Defaults to False.
return_vector (bool, optional): If True, returns a list of booleans for each comparison.
Defaults to False.
Returns:
bool or BoolVector:
- If `return_vector` is False, returns True if all tensors are close to the first tensor;
otherwise, returns False.
- If `return_vector` is True, returns a 1D tensor of bools where the first element is True
(as the reference tensor is trivially close to itself), and each subsequent element indicates
whether the corresponding tensor is close to the first tensor.
"""
if not isinstance(values, torch.Tensor):
values = torch.tensor(values)
reference = values[0]
is_close = torch.isclose(
values, reference, rtol=rtol, atol=atol, equal_nan=equal_nan
)
# flatten dimensions after first
result_vector = torch.all(is_close.view(is_close.size(0), -1), dim=1)
return result_vector if return_vector else bool(torch.all(result_vector))
def compute_stats(data: Sequence[number | Sequence]) -> dict[str, float]:
"""Compute mean and standard deviation from a possibly jagged nested list.
This function recursively traverses a nested list of numbers (ints or floats)
and computes the overall mean and standard deviation of the flattened list.
Args:
data (Sequence[number | Sequence]): A possibly jagged nested list of numbers.
Returns:
dict: A dictionary with two keys:
- "mean": The mean of all numerical elements.
- "var": The variance of all numerical elements.
"""
def accumulate(x):
"""Recursively accumulate the sum, sum of squares, and count of numerical values.
Args:
x (int | float | list): A number or a nested list of numbers.
Returns:
tuple: A tuple of three elements:
- total (float): Sum of all numbers.
- total_sq (float): Sum of squares of all numbers.
- count (int): Count of numerical elements.
"""
if isinstance(x, (int, float)):
return x, x * x, 1
elif isinstance(x, list):
total, total_sq, count = 0.0, 0.0, 0
for item in x:
s, ss, c = accumulate(item)
total += s
total_sq += ss
count += c
return total, total_sq, count
else:
raise ValueError(f"Invalid element type encountered: {type(x)}")
total, total_sq, count = accumulate(data)
if count == 0:
raise ValueError("No numerical elements found in the input data.")
mean = total / count
variance = total_sq / count - mean * mean
return {"mean": mean, "var": variance}
def compute_discounted_returns(rewards: torch.Tensor, gamma: float) -> torch.Tensor:
"""Compute discounted returns from a 1D vector of rewards.
Given a list or torch tensor of rewards and a discount factor, this function computes
the discounted return at each timestep. The discounted return at time t is defined as:
G_t = rewards[t] + gamma * rewards[t+1] + gamma^2 * rewards[t+2] + ...
Args:
rewards (list[float] or torch.Tensor): A 1D list or tensor of rewards.
gamma (float): The discount factor (should be between 0 and 1).
Returns:
        torch.Tensor: A 1D tensor containing the discounted return at each timestep.
"""
if not isinstance(rewards, torch.Tensor):
rewards = torch.tensor(rewards, dtype=torch.float)
discounted_returns = torch.empty_like(rewards)
running_return = 0.0
for t in reversed(range(len(rewards))):
running_return = rewards[t] + gamma * running_return
discounted_returns[t] = running_return
return discounted_returns
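The backward recurrence G_t = r_t + gamma * G_{t+1} implemented above, worked out in plain Python for rewards [1, 1, 1] with gamma = 0.9 (this standalone copy avoids the torch dependency; the values match the unit tests):

```python
def discounted_returns(rewards, gamma):
    """Plain-Python version of the backward recurrence G_t = r_t + gamma * G_{t+1}."""
    out = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        out[t] = running
    return out


returns = discounted_returns([1.0, 1.0, 1.0], 0.9)  # ≈ [2.71, 1.9, 1.0]
```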
def compute_grpo_process_supervision_advantages(
    rewards: Sequence[Sequence[number]],
    gamma: float | None = None,
    std_tol: float = 1e-8,
) -> list[torch.Tensor]:
"""
Given a (possibly jagged) list of list of rewards, compute advantages for GRPO.
Args:
rewards (Sequence[Sequence[number]]): A list of list of rewards. Each inner list
contains a reward for each "step" in a trajectory, where a "step" is an
abstract unit of time left up to the modeler to define.
gamma (float): The discount factor.
std_tol (float): The tolerance for the standard deviation.
Returns:
A list of tensors of advantages.
Raises:
ValueError: If the standard deviation of the flattened rewards is smaller than the tolerance.
"""
stats = compute_stats(rewards)
mean = stats["mean"]
std = stats["var"] ** 0.5
if std < std_tol:
raise ValueError(f"`std` is smaller than tolerance of {std_tol}.")
normalized_rewards = [
(torch.tensor(trajectory) - mean) / std for trajectory in rewards
]
if gamma is None:
advantages = [
trajectory.flip(dims=[0]).cumsum(dim=0).flip(dims=[0])
for trajectory in normalized_rewards
]
else:
advantages = [
compute_discounted_returns(trajectory, gamma)
for trajectory in normalized_rewards
]
return advantages
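What the `gamma=None` branch computes, worked through in stdlib Python: flatten, normalize by the global mean and (population) std, then take a reversed cumulative sum per trajectory. This sketch mirrors the torch implementation above for illustration only:

```python
import math


def grpo_advantages_cumsum(rewards):
    """Stdlib sketch of compute_grpo_process_supervision_advantages(gamma=None)."""
    flat = [r for traj in rewards for r in traj]
    mean = sum(flat) / len(flat)
    std = math.sqrt(sum(r * r for r in flat) / len(flat) - mean * mean)
    advantages = []
    for traj in rewards:
        norm = [(r - mean) / std for r in traj]
        # reversed cumulative sum: advantage at i sums normalized rewards i..end
        acc, out = 0.0, []
        for r in reversed(norm):
            acc += r
            out.append(acc)
        advantages.append(out[::-1])
    return advantages


adv = grpo_advantages_cumsum([[1, 2, 3], [4, 5]])
```

For `[[1, 2, 3], [4, 5]]` the flattened stats are mean 3 and std sqrt(2), giving the same expected values as the unit tests.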


@ -0,0 +1,184 @@
import argparse
import os
from typing import Any, Dict, Optional
import torch
import yaml
class ConfigHandler:
"""Handles loading and merging of configuration files with CLI overrides"""
def __init__(self, config_dir: Optional[str] = None):
self.config_dir = config_dir or os.path.join(
os.path.dirname(__file__), "../../configs"
)
self.parser = self._setup_argument_parser()
def _setup_argument_parser(self) -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Training configuration")
# Config files
parser.add_argument(
"--env",
type=str,
default="crosswords",
help="Environment config file name (without .yaml)",
)
parser.add_argument(
"--agent",
type=str,
default="nous_hermes",
help="Agent config file name (without .yaml)",
)
parser.add_argument(
"--config",
type=str,
help="Configuration file name (without .yaml)",
)
# CLI overrides
parser.add_argument("--group-size", type=int, help="Override group size")
parser.add_argument("--total-steps", type=int, help="Override total steps")
parser.add_argument("--batch-size", type=int, help="Override batch size")
parser.add_argument("--seed", type=int, help="Override random seed")
parser.add_argument("--device", type=str, help="Override device (cuda/cpu/mps)")
parser.add_argument("--server-url", type=str, help="Override server URL")
# Dataset-specific overrides
parser.add_argument("--dataset-name", type=str, help="Override dataset name")
parser.add_argument("--dataset-split", type=str, help="Override dataset split")
parser.add_argument(
"--prompt-field", type=str, help="Override prompt field name"
)
parser.add_argument(
"--answer-field", type=str, help="Override answer field name"
)
parser.add_argument("--system-prompt", type=str, help="Override system prompt")
parser.add_argument(
"--max-generations", type=int, help="Override max generations per prompt"
)
parser.add_argument(
"--reward-funcs",
type=str,
nargs="+",
help="Override reward functions to use",
)
return parser
def _load_yaml(self, path: str) -> Dict[str, Any]:
"""Load a YAML configuration file"""
with open(path, "r") as f:
return yaml.safe_load(f)
def _determine_device(self, config: Dict[str, Any]) -> str:
if config.get("device") == "auto":
if torch.backends.mps.is_available():
return "mps"
elif torch.cuda.is_available():
return "cuda"
return "cpu"
return config.get("device", "cpu")
def load_config(self, args: Optional[argparse.Namespace] = None) -> Dict[str, Any]:
"""Load and merge configurations with CLI overrides"""
if args is None:
args = self.parser.parse_args()
# environment config
config = self._load_yaml(os.path.join(self.config_dir, f"envs/{args.env}.yaml"))
# agent/model config
agent_config = self._load_yaml(
os.path.join(self.config_dir, f"agents/{args.agent}.yaml")
)
config["agent"] = agent_config
        # CLI overrides (`is not None` so explicit 0 values still apply)
        if args.group_size is not None:
            config["group_size"] = args.group_size
        if args.total_steps is not None:
            config["total_steps"] = args.total_steps
        if args.batch_size is not None:
            config["batch_size"] = args.batch_size
        if args.seed is not None:
            config["initial_seed"] = args.seed
        if args.device is not None:
            config["agent"]["device"] = args.device
        if args.server_url is not None:
            config["rollout_server_url"] = args.server_url
# Ensure player_names is populated based on group_size
if "env_kwargs" in config and "player_names" in config["env_kwargs"]:
config["env_kwargs"]["player_names"] = {
i: f"Player_{i}" for i in range(config["group_size"])
}
config["agent"]["device"] = self._determine_device(config["agent"])
return config
def load_dataset_config(
self, args: Optional[argparse.Namespace] = None
) -> Dict[str, Any]:
"""Load and merge dataset environment configurations with CLI overrides"""
if args is None:
args = self.parser.parse_args()
# Start with base environment config
config = self._load_yaml(os.path.join(self.config_dir, f"envs/{args.env}.yaml"))
# Load agent config
agent_config = self._load_yaml(
os.path.join(self.config_dir, f"agents/{args.agent}.yaml")
)
config["agent"] = agent_config
# Load dataset config if specified
if args.config:
dataset_config = self._load_yaml(
os.path.join(self.config_dir, f"datasets/{args.config}.yaml")
)
# Merge dataset config with main config instead of nesting
for key, value in dataset_config.items():
config[key] = value
        # Apply CLI overrides for common parameters (`is not None` so 0 is honored)
        if args.group_size is not None:
            config["group_size"] = args.group_size
        if args.total_steps is not None:
            config["total_steps"] = args.total_steps
        if args.batch_size is not None:
            config["batch_size"] = args.batch_size
        if args.seed is not None:
            config["initial_seed"] = args.seed
        if args.device is not None:
            config["agent"]["device"] = args.device
        if args.server_url is not None:
            config["rollout_server_url"] = args.server_url
# Apply dataset-specific overrides
if "dataset" in config:
if args.dataset_name:
config["dataset"]["dataset_name"] = args.dataset_name
if args.dataset_split:
config["dataset"]["split"] = args.dataset_split
if args.prompt_field:
config["dataset"]["prompt_field"] = args.prompt_field
if args.answer_field:
config["dataset"]["answer_field"] = args.answer_field
if args.system_prompt:
config["dataset"]["system_prompt"] = args.system_prompt
if args.max_generations:
config["dataset"]["max_generations_per_prompt"] = args.max_generations
if args.reward_funcs:
config["dataset"]["reward_funcs"] = args.reward_funcs
# Set device
config["agent"]["device"] = self._determine_device(config["agent"])
# Add slurm flag to config if running in a Slurm environment
config["use_slurm"] = "SLURM_JOB_ID" in os.environ
return config

import math
import random
# TODO: move this to the server manager
async def generate_with_diverse_first_tokens(
self, messages, prefill="", n=8, max_tokens=4096, temperature=1.0
):
"""
Generate diverse completions by sampling different first tokens.
Parameters:
- messages: List of message dictionaries for chat completion
- prefill: Prefix text to add to assistant's message
- n: Number of diverse completions to generate
- max_tokens: Maximum tokens per completion
- temperature: Sampling temperature
Returns:
- List of completion strings
"""
# Step 1: First get the logprobs for just the first token
first_token_messages = messages + [{"role": "assistant", "content": prefill}]
first_token_completion = await self.server.chat_completion(
messages=first_token_messages,
n=1,
max_tokens=1,
temperature=0.0, # Use 0 temperature to get raw logprobs
logprobs=True,
top_logprobs=20, # Get top 20 logprobs for the first token
)
# Extract logprobs from the completion
try:
# Get the logprobs for the first token
logprobs_dict = (
first_token_completion.choices[0].logprobs.content[0].top_logprobs
)
# Convert to list of (token, logprob) tuples
logprobs_list = [(item.token, item.logprob) for item in logprobs_dict]
# Convert logprobs to probabilities with temperature
logprobs_array = [lp for _, lp in logprobs_list]
probs = [math.exp(lp / temperature) for lp in logprobs_array]
total = sum(probs)
probs = [p / total for p in probs]
# Sample candidate first-token indices (random.choices samples with replacement)
sampled_indices = random.choices(
range(len(logprobs_list)), weights=probs, k=min(n, len(logprobs_list))
)
# Deduplicate; ordering is not significant here
sampled_indices = list(set(sampled_indices))
# If we have fewer than n tokens, sample again to fill
while len(sampled_indices) < n and len(sampled_indices) < len(logprobs_list):
remaining = min(
n - len(sampled_indices), len(logprobs_list) - len(sampled_indices)
)
available_indices = [
i for i in range(len(logprobs_list)) if i not in sampled_indices
]
available_probs = [probs[i] for i in available_indices]
total = sum(available_probs)
if total > 0:
available_probs = [p / total for p in available_probs]
additional_indices = random.choices(
available_indices, weights=available_probs, k=remaining
)
# random.choices samples with replacement, so drop duplicates before extending
sampled_indices.extend(set(additional_indices))
else:
# If all remaining probs are 0, just pick randomly
additional_indices = random.sample(available_indices, k=remaining)
sampled_indices.extend(additional_indices)
# Get the selected first tokens
first_tokens = [logprobs_list[i][0] for i in sampled_indices]
except (AttributeError, IndexError, KeyError) as e:
# Fallback if we can't extract logprobs properly
print(f"Error extracting logprobs: {e}")
return await self.fallback_generate(
messages, prefill, n, max_tokens, temperature
)
# Step 2: Generate completions with each selected first token
completions = []
for token in first_tokens:
# Create a prompt with the first token already included
prompt_with_token = messages + [
{"role": "assistant", "content": prefill + token}
]
# Generate the rest of the completion
completion = await self.server.chat_completion(
messages=prompt_with_token,
n=1,
max_tokens=max_tokens - 1, # Subtract 1 for the token we already used
temperature=temperature,
top_p=0.3,
extra_body={
"min_p": 0.5,
"repetition_penalty": 1.05,
},
)
# Prepend the sampled first token to the generated continuation
full_content = completion.choices[0].message.content
completions.append(token + full_content)
return completions

import numpy as np
def get_std_min_max_avg(name: str, data: list, metrics_dict: dict) -> dict:
"""
Calculate the standard deviation, minimum, maximum, and average of a list of numbers.
Adds it to the wandb dict for logging.
Args:
data (list): A list of numbers.
Returns:
dict: A dictionary containing the standard deviation, minimum, maximum, and average.
"""
metrics_dict[f"{name}_mean"] = np.mean(data)
metrics_dict[f"{name}_std"] = np.std(data)
metrics_dict[f"{name}_max"] = np.max(data)
metrics_dict[f"{name}_min"] = np.min(data)
return metrics_dict
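For instance, accumulating per-rollout scores into a wandb-style metrics dict looks like this (a minimal sketch; the helper is restated so the snippet runs standalone):

```python
import numpy as np


def get_std_min_max_avg(name, data, metrics_dict):
    # Same helper as above, restated for a self-contained example
    metrics_dict[f"{name}_mean"] = np.mean(data)
    metrics_dict[f"{name}_std"] = np.std(data)
    metrics_dict[f"{name}_max"] = np.max(data)
    metrics_dict[f"{name}_min"] = np.min(data)
    return metrics_dict


metrics = {}
metrics = get_std_min_max_avg("score", [0.0, 0.5, 1.0], metrics)
print(metrics["score_mean"])  # 0.5
```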

import torch
from transformers import PreTrainedTokenizer
from atroposlib.type_definitions import Message
# Roles whose tokens contribute to the loss; all other roles are masked out
UNMASKED_ROLES = ["assistant"]
def tokenize_for_trainer(
tokenizer: PreTrainedTokenizer,
chat: list[Message],
include_messages: bool = False,
train_on_all_assistant_turns: bool = False,
finish_reason: str = "",
) -> dict:
"""
Tokenize a list of chat messages for the trainer.
Args:
tokenizer (PreTrainedTokenizer): The tokenizer to use.
chat (list): A list of chat messages.
include_messages (bool): Whether to include the messages in the output.
train_on_all_assistant_turns (bool): If True, train on every assistant turn,
masking out system/user/tool roles. If False, train only on the final
message, masking the whole preceding prefix.
finish_reason (str): Generation finish reason; if "length", a trailing EOS token is dropped.
Returns:
dict: A dictionary containing the tokenized chat messages.
"""
tokens = tokenizer.apply_chat_template(chat)
if not train_on_all_assistant_turns:
prefix_len = len(
tokenizer.apply_chat_template(chat[:-1], add_generation_prompt=True)
)
masks = [-100] * prefix_len + tokens[prefix_len:]
else:
# NOTE: This implementation breaks if the chat template's default system prompt depends on
# world state (e.g. the current date): a run that crosses midnight from 3/9 to 3/10 may
# tokenize the prefix with a different number of tokens, misaligning the masks.
masks = torch.ones(len(tokens), dtype=torch.long) * -100
for i, msg in enumerate(chat):
if msg["role"] in UNMASKED_ROLES:
prefix_tokens = tokenizer.apply_chat_template(
chat[:i], tokenize=True, add_generation_prompt=True
)
unmasked_tokens = tokenizer.apply_chat_template(
chat[: i + 1], tokenize=True
)
start_idx = len(prefix_tokens)
end_idx = len(unmasked_tokens)
masks[start_idx:end_idx] = torch.tensor(unmasked_tokens[start_idx:])
masks = masks.tolist()
if finish_reason == "length":
if tokens[-1] == tokenizer.eos_token_id:
print("bad token\n")
# truncate the last token
tokens = tokens[:-1]
masks = masks[:-1]
return {
"tokens": tokens,
"masks": masks,
} | ({"messages": chat} if include_messages else {})
if __name__ == "__main__":
# Inspired by `preprocess --debug` of https://github.com/axolotl-ai-cloud/axolotl
def decode_token_ids(
token_ids: list, mask, tokenizer, use_rich: bool = False
) -> str:
"""Convert a list of token IDs to a formatted string using tokenizer.decode,
with an option to highlight masked tokens in red using rich markup.
Each token is represented as decoded(tokenid, mask). If decoding a token returns an empty string
and the token is a known special token, it is replaced with a descriptive placeholder.
When use_rich is True, any token whose corresponding mask is -100 is highlighted via rich markup.
Args:
token_ids (list[int]): A list of integer token IDs,
e.g. [50256, 329].
mask (list[int]): A list of masks corresponding to token_ids.
A mask value of -100 indicates the token is masked.
tokenizer: The Hugging Face tokenizer.
use_rich (bool): If True, highlight tokens with a mask of -100 using rich markup.
Defaults to False.
Returns:
str: A space-separated string where each token is represented as decoded(tokenid, mask).
Raises:
ValueError: If any element in token_ids is not an integer.
Example:
>>> decode_token_ids([50256, 329], mask=[-100, 329], tokenizer=tokenizer, use_rich=True)
'[pink3][bold]<|eos|>[/bold][/pink3][steel_blue](50256, -100)[/steel_blue] ...'
# (actual output varies with the model's tokenizer)
"""
# Validate that all token_ids are integers.
if not all(isinstance(t, int) for t in token_ids):
raise ValueError("All token IDs must be integers.")
tokens_str_list = []
for tid, mid in zip(token_ids, mask):
# Use decode with flags to include special tokens.
decoded = tokenizer.decode(
[tid], skip_special_tokens=False, clean_up_tokenization_spaces=False
).strip()
# If the decoded string is empty and it's a special token, replace with a placeholder.
if not decoded and tid in tokenizer.all_special_ids:
if tid == tokenizer.eos_token_id:
decoded = "<|eos|>"
else:
decoded = f"<SPECIAL_{tid}>"
# Highlight token in red if use_rich is True and the token is masked (mid == -100)
if use_rich:
if mid == -100:
token_str = f"[pink3][bold]{decoded}[/bold][/pink3][steel_blue]({tid}, {mid})[/steel_blue]"
else:
token_str = (
f"[pale_green3][bold]{decoded}[/bold][/pale_green3]"
f"[steel_blue]({tid}, {mid})[/steel_blue]"
)
else:
token_str = f"{decoded}({tid}, {mid})"
tokens_str_list.append(token_str)
return " ".join(tokens_str_list)
messages = [
{
"role": "system",
"content": "You are a helpful AI assistant that provides accurate information.",
},
{"role": "user", "content": "What's the capital of France?"},
{"role": "assistant", "content": "The capital of France is Paris."},
{"role": "user", "content": "Can you tell me more about Paris?"},
{
"role": "assistant",
"content": "<tool_call>{'tool_name': 'web_search', 'args': {'query': 'Paris'}}</tool_call>",
},
{
"role": "tool",
"content": (
"Paris is the capital and most populous city of France. "
"It has an estimated population of 2,165,423 residents in 2019 "
"in an area of more than 105 km²."
),
},
{
"role": "assistant",
"content": (
"Paris is indeed the capital of France and its most populous city with over 2 million residents. "
"It's known for its iconic landmarks like the Eiffel Tower, Louvre Museum, and Notre-Dame Cathedral. "
"The city is a global center for art, fashion, gastronomy, and culture."
),
},
]
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
last_turn_only = tokenize_for_trainer(
tokenizer, messages, train_on_all_assistant_turns=False
)
last_turn_only["repr"] = decode_token_ids(
last_turn_only["tokens"], last_turn_only["masks"], tokenizer, use_rich=True
)
all_assistant_turns = tokenize_for_trainer(
tokenizer, messages, train_on_all_assistant_turns=True
)
all_assistant_turns["repr"] = decode_token_ids(
all_assistant_turns["tokens"],
all_assistant_turns["masks"],
tokenizer,
use_rich=True,
)
from rich import print
print("[bold cyan]last turn only[/]")
print(last_turn_only["repr"])
print()
print("[bold cyan]all assistant turns[/]")
print(all_assistant_turns["repr"])

environments/README.md
# Environments
This directory contains various environments for training and evaluating language models on different tasks. Each environment implements a specific task with its own input format, reward function, and evaluation metrics.
## Available Environments
---
### MCQA Thinking Environment (`mcqa_thinking_env.py`)
Multiple Choice Question Answering environment that requires models to think through problems systematically.
**Input Format:**
- Questions from the MMLU (Massive Multitask Language Understanding) dataset
- Each item contains:
- `prompt`: The question text
- `answer`: Index of correct answer
- `ground_truth`: Letter (A, B, C, D) of correct answer
- `options`: List of possible answers
**System Prompt:**
```
You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem.
```
**Reward Function:**
- Score of 1.0 if the model's answer matches the ground truth letter
- Score of 0.0 if incorrect or invalid response (multiple think tags, malformed thinking sections)
- Length penalty applied if all responses are correct:
- No penalty for responses under 50% of max token length
- Linear penalty scaling from 1.0 down to 0.0 for responses between 50% and 100% of max length
- Returns None if all scores are identical (no learning signal)
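
The length-penalty schedule described above (shared by the environments below) can be sketched as a hypothetical helper; this is an illustration of the schedule, not the environment's actual implementation:

```python
def length_penalty(token_len: int, max_len: int) -> float:
    """Multiplier on a correct answer's score: 1.0 below 50% of max_len,
    then a linear ramp down to 0.0 at max_len."""
    ratio = token_len / max_len
    if ratio <= 0.5:
        return 1.0
    # Linear ramp from 1.0 at 50% of max length down to 0.0 at 100%
    return max(0.0, 1.0 - (ratio - 0.5) / 0.5)


print(length_penalty(1024, 4096))  # 1.0 (under 50%)
print(length_penalty(3072, 4096))  # 0.5 (75% of max)
```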
---
### GSM8K Environment (`gsm8k_server.py`)
Mathematical reasoning environment using the GSM8K dataset.
**Input Format:**
- Questions from GSM8K dataset
- Each item contains:
- `question`: The math problem
- `answer`: The numerical answer
**System Prompt:**
```
You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem.
You are allocated a maximum of 2048 tokens, please strive to use less.
You will then provide your answer like this: \boxed{your answer here}
It is important that you provide your answer in the correct format.
If you do not, you will not receive credit for your answer.
So please end your answer with \boxed{your answer here}
```
**Reward Function:**
- Score of 1.0 if the model's answer matches the ground truth (using LaTeX verification)
- Score of 0.0 if incorrect or if ground truth is not parseable
- Length penalty applied if all responses are correct:
- No penalty for responses under 50% of max token length
- Linear penalty scaling from 1.0 down to 0.0 for responses between 50% and 100% of max length
- Returns None if all scores are identical (no learning signal)
---
### Tool Calling Environment (`tool_calling_server.py`)
Environment for training models to make function calls in a structured format.
**Input Format:**
- Conversations from ShareGPT-Hermes function call dataset
- Each item contains:
- `conversations`: List of messages with roles (system, human, gpt)
- Expected tool calls in JSON format
**System Prompt:**
```
You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem.
```
**Reward Function:**
- Score of 1.0 if all expected tool calls are present and match exactly (including nested JSON fields)
- Score of 0.0 if any tool calls are missing, incorrect, or malformed
- Length penalty applied if all responses are correct:
- No penalty for responses under 50% of max token length
- Linear penalty scaling from 1.0 down to 0.0 for responses between 50% and 100% of max length
- Returns None if all scores are identical (no learning signal)
## Common Features
All environments share these common features:
1. **Training/Test Split:**
- 98% training, 2% test split
- Random shuffling with fixed seed (42)
2. **Metrics Tracking:**
- Percent correct buffer
- Completion lengths
- Wandb integration for visualization
- Rollout tracking
3. **Token Management:**
- Maximum token length limits
- Token length statistics tracking
- Length penalty for excessive responses
4. **Evaluation:**
- Separate evaluation on test set
- Comprehensive metrics logging
- Support for multiple model completions per prompt
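
The split itself is straightforward; a minimal sketch of the shuffling described above (fixed seed 42, 98/2 split, with a stand-in list for the dataset):

```python
import random

items = list(range(1000))  # stand-in for the dataset items
random.Random(42).shuffle(items)  # fixed seed for reproducibility

split = int(len(items) * 0.98)
train, test = items[:split], items[split:]
print(len(train), len(test))  # 980 20
```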
## Usage
Each environment can be initialized with:
- `config`: BaseEnvConfig object
- `server_configs`: List of OpenAI API configurations
- `slurm`: Boolean for distributed training
- `testing`: Boolean for testing mode
The environments follow a common interface with methods for:
- `setup()`: Loading and preparing datasets
- `get_next_item()`: Retrieving next training item
- `collect_trajectories()`: Generating model responses
- `score()`: Computing rewards
- `evaluate()`: Running evaluation on test set
- `wandb_log()`: Logging metrics to Weights & Biases
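
Put together, a driver over this interface looks roughly like the following. This is a schematic sketch with a stub environment: the method names follow the list above, but the signatures and return values are simplified and hypothetical.

```python
class StubEnv:
    """Minimal stand-in illustrating the environment interface above."""

    def setup(self):
        # A real environment would load and split its dataset here
        self.items = ["What is 2 + 2?"]
        self.idx = 0

    def get_next_item(self):
        item = self.items[self.idx % len(self.items)]
        self.idx += 1
        return item

    def collect_trajectories(self, item, group_size=2):
        # A real environment would query the inference server here
        return [f"answer to {item!r} #{i}" for i in range(group_size)]

    def score(self, item, trajectories):
        # A real environment would compare against ground truth here
        return [1.0 if "2 + 2" in t else 0.0 for t in trajectories]


env = StubEnv()
env.setup()
item = env.get_next_item()
group = env.collect_trajectories(item)
scores = env.score(item, group)
print(scores)  # [1.0, 1.0]
```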

# Dataset Environment Local Testing Guide
This document explains how to run the Dataset Environment locally for testing purposes.
## Prerequisites
1. Make sure you have the repository cloned and dependencies installed
2. Ensure you have a compatible model available (local or API)
## Option 1: Single Script End-to-End Execution
The easiest way to test the Dataset Environment is to use the unified launcher script:
```bash
python -m environments.dataset_environment.launch_local_dataset_run
```
This script:
1. Starts the Trajectory Handler API server via uvicorn
2. Launches the Dataset Environment in serve mode (connected to the API)
3. Runs the example GRPO trainer directly
The script comes preconfigured with defaults for:
- Using a small LLM (Qwen2.5-1.5B) running on localhost:9001
- A test subset of a public HF dataset
- Basic length-based rewards
## Option 2: Step-by-step Manual Testing
### 1. Start the API Server
```bash
uvicorn atroposlib.api.server:app --host 127.0.0.1 --port 8000
```
### 2. Launch the Environment
```bash
python -m environments.dataset_environment.dataset_env serve \
--group_size 4 \
--max_num_workers 2 \
--rollout_server_url http://127.0.0.1:8000 \
--tokenizer_name Qwen/Qwen2.5-1.5B-Instruct \
--use_wandb --wandb_name dataset_env_local_test \
--max_token_length 512 \
--ensure_scores_are_not_same \
--dataset_name HuggingFaceH4/testing_self_instruct_process_essays \
--split train[:100] \
--prompt_field prompt --answer_field answer \
--reward_functions length \
--max_tokens 128 --temperature 0.7 \
--model_name Qwen/Qwen2.5-1.5B-Instruct \
--base_url http://127.0.0.1:9001 \
--slurm --testing
```
### 3. Launch the Trainer
In a separate terminal:
```bash
python -m example_trainer.grpo.train \
--model_name Qwen/Qwen2.5-1.5B-Instruct \
--training_steps 20 \
--batch_size 2 \
--gradient_accumulation_steps 2 \
--seq_len 512
```
## Option 3: Use the Dataset Local Server
For easier configuration via YAML files, you can use the local server script:
```bash
# This command will look for environments/dataset_environment/configs/gsm8k.yaml
python environments/dataset_environment/dataset_local_server.py --config gsm8k
# You can also provide a full path:
# python environments/dataset_environment/dataset_local_server.py --config /path/to/your/custom_config.yaml
```
This will load the specified config and run the environment accordingly.
## Debugging
To check if requests are properly sent to and received by the API server, you can inspect the logs from both the environment and the API server. Look for:
- API logs showing incoming requests
- Environment logs showing completions being generated and scored
For model-specific issues, check:
- Ensure your model server is running at the specified URL
- Check model server logs for any errors related to generation
## Configuration Structure
Configuration files placed in `environments/dataset_environment/configs/` typically contain:
```yaml
# Example: environments/dataset_environment/configs/my_config.yaml
# Base environment parameters (can be overridden by dataset specifics)
tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
group_size: 1
use_wandb: false
# ... other base parameters
# Dataset specific configuration
dataset:
# Dataset parameters
dataset_name: "databricks/databricks-dolly-15k"
prompt_field: "instruction"
# ... other dataset parameters
reward_functions:
- type: "accuracy"
weight: 1.0
- type: "repetition_penalty"
weight: 0.2
# Optional Server configuration (if not using CLI flags in dataset_env)
server_configs:
- model_name: "gpt-4.1-nano"
api_key: ${OPENAI_API_KEY}
timeout: 600
```
### Important Configuration Parameters
#### Base Parameters
- `tokenizer_name`: The tokenizer to use for encoding/decoding text
- `group_size`: Number of responses to collect per prompt
- `max_token_length`: Maximum token length for generation
- `steps_per_eval`: How often to run evaluations
#### Dataset Specific Parameters (`dataset:` section)
- `dataset_name`: HuggingFace dataset name (required)
- `dataset_config`: Dataset configuration name (optional)
- `prompt_field`: Field in dataset to use as prompt (required)
- `answer_field`: Field in dataset to use as answer (optional)
- `system_prompt`: System prompt to use (optional)
- `reward_functions`: List of reward functions to apply (optional)
#### Server Configuration (`server_configs:` section, optional in local server)
- `model_name`: LLM model to use
- `api_key`: API key for the model (can use environment variables with ${VAR_NAME} syntax)
- `timeout`: Request timeout in seconds
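
The `${VAR_NAME}` values can be expanded with the standard library; a minimal sketch of how such a value might be resolved (the actual loader may differ, and `EXAMPLE_API_KEY` is a made-up variable for illustration):

```python
import os

os.environ["EXAMPLE_API_KEY"] = "sk-test"  # hypothetical variable, set for illustration


def resolve(value: str) -> str:
    """Expand ${VAR} references in a config value from the process environment."""
    return os.path.expandvars(value)


print(resolve("${EXAMPLE_API_KEY}"))  # sk-test
```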
## Troubleshooting
If you encounter issues with reward functions, make sure they are properly registered in the registry.
For dataset-related issues, verify that the dataset exists on HuggingFace and that the specified fields exist in the dataset.

## Quick Start
### Option A: Unified End-to-End Launcher
```bash
python -m environments.dataset_environment.launch_local_dataset_run
```
This single command spins up:
1. The Trajectory Handler API server (`uvicorn atroposlib.api.server:app`)
2. The DatasetEnv in serve mode (connected to the API)
3. The example GRPO trainer (via `example_trainer.grpo.train`)
### Option B: Manual Steps
1. **Start the API server**
```bash
uvicorn atroposlib.api.server:app --host 127.0.0.1 --port 8000
```
2. **Launch the Dataset Environment**
- **Using CLI flags**:
(These flags override any config file settings)
```bash
python -m environments.dataset_environment.dataset_env serve \
--group_size 4 \
--max_num_workers 2 \
--rollout_server_url http://127.0.0.1:8000 \
--tokenizer_name Qwen/Qwen2.5-1.5B-Instruct \
--use_wandb --wandb_name dataset_env_local_test \
--max_token_length 512 \
--ensure_scores_are_not_same \
--dataset_name HuggingFaceH4/testing_self_instruct_process_essays \
--split train[:100] \
--prompt_field prompt --answer_field answer \
--reward_functions length \
--max_tokens 128 --temperature 0.7 \
--model_name Qwen/Qwen2.5-1.5B-Instruct \
--base_url http://127.0.0.1:9001 \
--slurm --testing
```
- **Using YAML config files**:
Place a dataset config under `environments/dataset_environment/configs/<name>.yaml`:
```yaml
# Example: environments/dataset_environment/configs/gsm8k.yaml
dataset:
dataset_name: "gsm8k"
dataset_config: "main"
split: "train"
prompt_field: "question"
answer_field: "answer"
system_prompt: "You are a mathematical problem solver..."
generation:
temperature: 0.7
top_p: 0.95
reward_functions:
- type: "accuracy"
weight: 1.0
```
Then run the local test server:
```bash
# Will look for environments/dataset_environment/configs/gsm8k.yaml
python environments/dataset_environment/dataset_local_server.py --config gsm8k
```
3. **Launch the Trainer**
```bash
python -m example_trainer.grpo
```
## Configuration Files Directory
Dataset environment specific configurations now live in `environments/dataset_environment/configs/`.
Shared configurations (like agents) might still reside in the project's root `configs/` directory.
- `environments/dataset_environment/configs/` for dataset-specific configs (used by `dataset_local_server.py`).
- You can reference any `<name>.yaml` within this directory via the `--config` flag in the local server script.
## Reward Function Registry & Customization
Reward functions are managed by a centralized registry (see `atroposlib/envs/reward_fns/reward_function.py`). Built-in types include:
- `accuracy`: exact match to ground truth (tolerance, split_on_think_tag)
- `format`: checks for specific tags (preferred_tags)
- `reasoning_steps`: quality of step-by-step reasoning
- `repetition_penalty`: penalizes repetition
- `cosine_scaled`: semantic similarity scaled from embeddings
- `crossword_format`: crossword-specific penalty
- `r1`: combined accuracy + format
To preview all available functions:
```python
from atroposlib.envs.reward_fns import registry
print(registry.list())
```
### Creating Custom Reward Functions
1. Create a new file under `atroposlib/envs/reward_fns/my_reward.py`.
2. Subclass `RewardFunction` and register it:
```python
from atroposlib.envs.reward_fns import registry, RewardFunction
@registry.register
class MyCustomReward(RewardFunction):
def __init__(self, custom_param=1.0, weight=1.0, **kwargs):
super().__init__(weight=weight, **kwargs)
self.custom_param = custom_param
def compute(self, completions, **kwargs):
return [1.0 if "good answer" in self.get_content(c) else 0.0 for c in completions]
```
3. Reference it in your YAML config:
```yaml
reward_functions:
- type: "my_custom"
weight: 1.0
params:
custom_param: 2.0
```
### Dataset Environments
Dataset environments load data from HuggingFace datasets and evaluate LLM responses against ground truth. They're ideal for academic benchmarks and datasets with clear evaluation criteria.
Example configuration:
```yaml
dataset:
dataset_name: "gsm8k"
dataset_config: "main"
split: "train"
prompt_field: "question"
answer_field: "answer"
system_prompt: "You are a mathematical problem solver..."
reward_functions:
- type: "accuracy"
weight: 1.0
```
## Reward Functions
The system features a flexible reward function architecture for evaluating model outputs.
### Basic Usage
In your environment config, specify reward functions:
```yaml
reward_functions:
- type: "accuracy"
weight: 1.0
- type: "format"
weight: 0.5
```
### Combining Reward Functions
Combine multiple reward functions with weights:
```yaml
reward_functions:
- type: "combined"
params:
normalization: "sum"
rewards:
- type: "accuracy"
weight: 1.5
- type: "format"
weight: 0.5
```
### Available Reward Functions
#### `accuracy`
Evaluates if completions match ground truth answers.
```yaml
type: "accuracy"
weight: 1.0
params:
tolerance: 1e-6
split_on_think_tag: true
max_boxed_threshold: 6
```
#### `format`
Checks if completions include specific XML-style tags.
```yaml
type: "format"
weight: 1.0
params:
preferred_tags: ["think", "reasoning"]
require_all_tags: false
case_sensitive: false
```
#### `reasoning_steps`
Evaluates step-by-step reasoning quality.
```yaml
type: "reasoning_steps"
weight: 1.0
params:
min_words: 10
min_steps: 3
base_score: 0.1
```
#### `repetition_penalty`
Penalizes repetitive content.
```yaml
type: "repetition_penalty"
weight: 0.5
params:
threshold: 0.05
min_words: 10
min_sentences: 2
```
#### `cosine_scaled`
Measures semantic similarity between completions and solutions.
```yaml
type: "cosine_scaled"
weight: 0.8
params:
model_name: "sentence-transformers/all-MiniLM-L6-v2"
scale_factor: 1.0
min_reward: -1.0
max_reward: 1.0
```
#### `crossword_format`
Game-specific reward for crossword puzzles.
```yaml
type: "crossword_format"
weight: 1.0
params:
reward_value: 1.0
penalize_invalid_chars: true
```
#### `r1`
Combined reward using both reasoning format and accuracy.
```yaml
type: "r1"
weight: 1.0
params:
format_weight: 0.5
accuracy_weight: 1.0
```
### Creating Custom Reward Functions
To create a custom reward function:
1. Create a new file in `atroposlib/envs/reward_fns/my_reward.py`
2. Define your reward function class:
```python
from typing import Any, List
from atroposlib.envs.reward_fns import registry, RewardFunction
@registry.register
class MyCustomReward(RewardFunction):
def __init__(self, custom_param=1.0, weight=1.0, **kwargs):
super().__init__(weight=weight, **kwargs)
self.custom_param = custom_param
def compute(self, completions: List[Any], **kwargs) -> List[float]:
rewards = []
for completion in completions:
content = self.get_content(completion)
# Implement your reward logic
reward = 1.0 if "good answer" in content else 0.0
rewards.append(reward)
return rewards
```
3. Use it in your config:
```yaml
reward_functions:
- type: "my_custom"
weight: 1.0
params:
custom_param: 2.0
```
### Dataset Environment Debugger
The dataset environment debugger allows you to run a dataset environment locally with a Hugging Face model, providing enhanced visibility into reward function performance and model responses.
```bash
# Run with default settings
python -m atroposlib.cli.dataset_env_debugger --env gsm8k_debug --agent nous_hermes_8b
# List available environments and agents
python -m atroposlib.cli.dataset_env_debugger --list-configs
# Interactive mode with debugging information
python -m atroposlib.cli.dataset_env_debugger --env gsm8k_debug --agent nous_hermes_8b --interactive --debug
# Run with custom generation parameters
python -m atroposlib.cli.dataset_env_debugger --env gsm8k_debug --agent nous_hermes_8b --temperature 0.5 --top-p 0.95
# Run with detailed logging
python -m atroposlib.cli.dataset_env_debugger --env gsm8k_debug --agent nous_hermes_8b --verbose
```
## Environment Overview
This environment demonstrates how to use a standard dataset (e.g., from Hugging Face Datasets) as a source for generating prompts and evaluating LLM responses. It allows for testing and training models on established benchmarks or custom datasets where prompts and expected answers/ground truth are available.
**Demonstrates:**
- Loading and processing data from Hugging Face Datasets.
- Configuring system prompts, prompt/answer fields.
- Applying various reward functions (accuracy, format, semantic similarity, etc.) to evaluate generations.
- Integrating with the `atroposlib` framework for data collection and scoring.
**Training Goal:**
- To train LLMs to follow instructions and generate responses that align with the format and content specified by the dataset and reward functions.
- To improve performance on specific tasks defined by datasets (e.g., math problem solving, code generation, question answering).
## Local Testing
To test this environment locally, you can run the provided local server. This server simulates the interaction flow without needing the full distributed setup.
First, ensure you have the necessary dependencies installed.
Then, run the local server script from the root of the repository:
```bash
python environments/dataset_environment/dataset_local_server.py --config path/to/your/dataset_config.yaml
```
Replace `path/to/your/dataset_config.yaml` with the actual path to your environment configuration file (e.g., `configs/envs/gsm8k.yaml`). The server will load the dataset specified in the config, process items, and simulate generating responses.

"""
Dataset Environment for training models with Hugging Face datasets.
This environment provides a flexible way to train models using a variety of datasets
from Hugging Face or other sources, evaluating generations against reference answers
with customizable reward functions.
"""
from environments.dataset_environment.dataset_env import DatasetEnv, DatasetEnvConfig
__all__ = ["DatasetEnv", "DatasetEnvConfig"]

# Dataset Environment Local Testing Configuration
# Base environment parameters
tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
group_size: 1
use_wandb: false
max_num_workers: 1
rollout_server_url: "http://localhost:8000"
total_steps: 1
batch_size: 1
steps_per_eval: 5
max_token_length: 4096
wandb_name: "dataset_test_local"
ensure_scores_are_not_same: false
# Dataset specific configuration
dataset:
# Dataset parameters
dataset_name: "gsm8k" # Example dataset
dataset_config: "main"
split: "train"
prompt_field: "question"
answer_field: "answer"
# Generation parameters
system_prompt: "You are an expert mathematician. You need to solve the given math problem step-by-step, showing your reasoning clearly. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your final answer.\n\nFollow these steps:\n1. Understand the problem carefully\n2. Plan your approach\n3. Execute the calculations step-by-step\n4. Verify your solution\n\nYou may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution."
prefill: "<think>\n"
shuffle_dataset: true
max_generations_per_prompt: 1
# Generation length parameters
max_tokens: 4096
length_warmup_steps: 0
min_tokens: 0
# Completion parameters
temperature: 0.7
top_p: 0.9
# Reward functions
reward_functions:
- "accuracy"
- "format"
accuracy_reward_weight: 1.0
format_reward_weight: 0.2
# Server configuration
server_configs:
- model_name: "gpt-4.1-nano"
api_key: ${OPENAI_API_KEY}
timeout: 600
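The `accuracy_reward_weight` / `format_reward_weight` pair above folds two reward signals into a single scalar score. A minimal sketch of that weighting (the `combine_rewards` helper is illustrative, not the library's actual API):

```python
def combine_rewards(component_scores, weights):
    """Weighted sum of per-function reward scores.

    component_scores: dict mapping reward name -> raw score
    weights: dict mapping reward name -> weight (missing names default to 1.0)
    """
    return sum(
        score * weights.get(name, 1.0)
        for name, score in component_scores.items()
    )


# Mirrors the config above: accuracy weighted 1.0, format weighted 0.2,
# so a format score contributes at most a fifth as much as accuracy.
total = combine_rewards(
    {"accuracy": 1.0, "format": 0.5},
    {"accuracy": 1.0, "format": 0.2},
)
```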

View file

@ -0,0 +1,73 @@
tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-1B-Preview"
group_size: 8
use_wandb: true
max_num_workers: 256
max_eval_workers: 16
steps_per_eval: 100
batch_size: 1024
max_batches_offpolicy: 3
total_steps: 1000
rollout_server_url: "http://localhost:8000"
use_local_agents: true
dataset:
dataset_name: "gsm8k"
dataset_config: "main"
split: "train"
prompt_field: "question"
answer_field: "answer"
system_prompt: "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem."
shuffle_dataset: true
max_generations_per_prompt: 1
include_messages_in_scoring: false
# New configurable reward functions
reward_functions:
- type: "r1"
weight: 1.5
params:
format_weight: 0.5
accuracy_weight: 1.0
- type: "cosine_scaled"
weight: 0.8
params:
scale_factor: 1.2
min_reward: -1.0
max_reward: 1.0
- type: "accuracy"
weight: 2.0
params:
split_on_think_tag: true
- type: "format"
weight: 0.7
params:
preferred_tags: ["think", "reasoning"]
require_all_tags: false
- type: "reasoning_steps"
weight: 1.0
params:
min_steps: 3
- type: "repetition_penalty"
weight: 0.5
params:
threshold: 0.1
# Legacy format still supported for backward compatibility
# reward_funcs:
# - "r1_reward"
# - "cosine_scaled_reward"
# - "accuracy_reward"
# - "format_reward"
# - "reasoning_steps_reward"
# - "repetition_penalty_reward"
max_tokens: 16000
length_warmup_steps: 100
min_tokens: 2048
eval_dataset_name: "gsm8k"
eval_dataset_config: "main"
eval_split: "test"
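`length_warmup_steps`, `min_tokens`, and `max_tokens` above define a linear token-budget ramp during early training. A sketch of that schedule (`warmup_max_tokens` is a hypothetical helper mirroring the environment's warmup logic):

```python
def warmup_max_tokens(step, warmup_steps, min_tokens, max_tokens):
    """Linearly ramp the generation budget from min_tokens to max_tokens."""
    if warmup_steps <= 0:
        return max_tokens
    progress = min(1.0, step / warmup_steps)
    return int(min_tokens + progress * (max_tokens - min_tokens))


# With the config above (min 2048, max 16000, 100 warmup steps), the
# budget grows from 2048 at step 0 to 16000 at step 100 and beyond.
budget_mid_warmup = warmup_max_tokens(50, 100, 2048, 16000)
```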

View file

@ -0,0 +1,30 @@
tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
group_size: 1
use_wandb: false
max_num_workers: 1
max_eval_workers: 0
batch_size: 1
total_steps: 100
rollout_server_url: "http://localhost:8000"
dataset:
dataset_name: "gsm8k"
dataset_config: "main"
split: "train"
prompt_field: "question"
answer_field: "answer"
system_prompt: "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem."
shuffle_dataset: true
max_generations_per_prompt: 1
include_messages_in_scoring: true
# Using multiple reward functions for testing
reward_funcs:
- "accuracy_reward"
- "format_reward"
max_tokens: 4096
length_warmup_steps: 0
min_tokens: 200
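This config uses the legacy `reward_funcs` string list, while the fuller config above supports `reward_functions` entries that are either plain names or dicts with `type`, `weight`, and `params`. A sketch of normalizing both shapes into one form (a hypothetical helper, not the registry's own code):

```python
def normalize_reward_spec(entry):
    """Turn a reward spec (name string, or a {'type', 'weight', 'params'}
    dict) into a uniform (name, weight, params) tuple."""
    if isinstance(entry, str):
        return entry, 1.0, {}
    if isinstance(entry, dict) and "type" in entry:
        return entry["type"], entry.get("weight", 1.0), entry.get("params", {})
    raise ValueError(f"Unrecognized reward spec: {entry!r}")


specs = [
    "accuracy_reward",                             # legacy string form
    {"type": "r1", "weight": 1.5,
     "params": {"format_weight": 0.5}},            # full dict form
]
normalized = [normalize_reward_spec(s) for s in specs]
```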

View file

@ -0,0 +1,407 @@
import asyncio
import logging
from typing import Any, Dict, List, Optional, Tuple, Union
from datasets import load_dataset
from pydantic import Field
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, ScoredDataGroup
from atroposlib.envs.reward_fns import registry
from atroposlib.envs.reward_fns.combined_reward import CombinedReward
from atroposlib.type_definitions import Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class DatasetEnvConfig(BaseEnvConfig):
dataset_name: str = Field(..., description="HuggingFace dataset name")
dataset_config: Optional[str] = Field(
None, description="Dataset configuration name"
)
split: str = Field("train", description="Dataset split to use")
dataset_path: Optional[str] = Field(
None, description="Local path to dataset (alternative to dataset_name)"
)
prompt_field: str = Field(..., description="Field in dataset to use as prompt")
answer_field: Optional[str] = Field(
None, description="Field in dataset to use as answer"
)
ground_truth_field: Optional[str] = Field(
None, description="Field in dataset containing canonical correct answer"
)
system_prompt: Optional[str] = Field(None, description="System prompt to use")
prefill: Optional[str] = Field(None, description="Text to prefill the completion with (e.g. '<think>')")
shuffle_dataset: bool = Field(True, description="Whether to shuffle the dataset")
max_generations_per_prompt: int = Field(
1, description="Number of generations per prompt for collection"
)
include_messages_in_scoring: bool = Field(
False, description="Whether to include messages in scoring"
)
reward_funcs: List[str] = Field(
default_factory=list,
description="List of reward function names to apply (legacy)",
)
reward_functions: List[Union[str, Dict[str, Any]]] = Field(
default_factory=list,
description="List of reward functions to apply (string names or full configs)",
)
# Completion parameters
temperature: float = Field(0.7, description="Temperature for generation")
top_p: float = Field(0.9, description="Top-p for generation")
max_tokens: int = Field(4096, description="Maximum tokens for generation")
length_warmup_steps: int = Field(0, description="Steps for length warmup")
min_tokens: int = Field(0, description="Minimum tokens for generation")
eval_dataset_name: Optional[str] = Field(
None, description="Evaluation dataset name"
)
eval_dataset_config: Optional[str] = Field(
None, description="Evaluation dataset config"
)
eval_split: Optional[str] = Field(None, description="Evaluation dataset split")
class DatasetEnv(BaseEnv):
def __init__(
self, config: DatasetEnvConfig, server_configs, slurm=True, testing=False
):
super().__init__(config, server_configs, slurm, testing)
self.config = config
self.dataset = None
self.iter = 0
self.metric_buffer = {}
self.reward_function = self._initialize_reward_function()
def _initialize_reward_function(self):
if hasattr(self.config, "reward_functions") and self.config.reward_functions:
if len(self.config.reward_functions) == 1:
return registry.create(self.config.reward_functions[0])
else:
return CombinedReward(
rewards=self.config.reward_functions, normalization="sum"
)
        elif hasattr(self.config, "reward_funcs") and self.config.reward_funcs:
            if len(self.config.reward_funcs) == 1:
                return registry.create(self.config.reward_funcs[0])
            else:
                return CombinedReward(
                    rewards=self.config.reward_funcs, normalization="none"
                )
        # Warn loudly at init time instead of failing later with an opaque
        # NoneType error when score() calls self.reward_function(...).
        logger.warning("No reward functions configured for DatasetEnv")
        return None
async def setup(self):
if self.config.dataset_path:
self.dataset = load_dataset(
self.config.dataset_path, split=self.config.split
)
else:
self.dataset = load_dataset(
self.config.dataset_name,
self.config.dataset_config,
split=self.config.split,
)
logger.info(f"Dataset features: {self.dataset.features}")
logger.info(f"Sample item keys: {list(self.dataset[0].keys())}")
logger.info(f"Sample item: {self.dataset[0]}")
if self.config.shuffle_dataset:
self.dataset = self.dataset.shuffle()
self.metric_buffer = {}
async def get_next_item(self) -> Item:
if not self.dataset:
await self.setup()
item = self.dataset[self.iter % len(self.dataset)]
self.iter += 1
user_msg = {"role": "user", "content": item[self.config.prompt_field]}
prompt = tuple([frozenset(user_msg.items())])
answer = None
if self.config.answer_field and self.config.answer_field in item:
answer = item[self.config.answer_field]
ground_truth = None
if self.config.ground_truth_field and self.config.ground_truth_field in item:
ground_truth = item[self.config.ground_truth_field]
return (prompt, answer, ground_truth)
async def collect_trajectory(self, item: Item) -> Tuple[List, List]:
# Extract user prompt and answer from item
user_content = dict(item[0][0])["content"]
answer = item[1] if len(item) > 1 else None
# Create messages list
messages = []
if self.config.system_prompt:
messages.append({"role": "system", "content": self.config.system_prompt})
messages.append({"role": "user", "content": user_content})
# Add prefill as assistant message if configured
if self.config.prefill:
messages.append({"role": "assistant", "content": self.config.prefill})
# Convert messages to a prompt string using the tokenizer
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
# Calculate max tokens for generation (with optional warmup)
max_tokens = self.config.max_tokens
if self.config.length_warmup_steps > 0:
warmup_progress = min(1.0, self.curr_step / self.config.length_warmup_steps)
max_tokens = int(
self.config.min_tokens
+ warmup_progress * (self.config.max_tokens - self.config.min_tokens)
)
# Generate completion using completions API
completions = await self.server.completion(
prompt=prompt,
n=self.config.max_generations_per_prompt,
max_tokens=max_tokens,
temperature=self.config.temperature,
top_p=self.config.top_p,
)
to_score = []
to_backlog = []
# Process completions
for completion in completions.choices:
# Get the completion text
completion_text = completion.text if hasattr(completion, "text") else completion.message.content
# Build full message sequence for scoring
full_messages = []
if self.config.system_prompt:
full_messages.append({"role": "system", "content": self.config.system_prompt})
full_messages.append({"role": "user", "content": user_content})
# Combine prefill with completion if prefill was used
response_content = completion_text
if self.config.prefill:
response_content = self.config.prefill + completion_text
full_messages.append({"role": "assistant", "content": response_content})
# Add to scoring list with answer and ground truth
to_score.append(
(full_messages, answer, item[2] if len(item) > 2 else None)
)
return to_score, to_backlog
async def postprocess_histories(self, trajectories: List) -> Tuple[List, List]:
return trajectories, []
async def collect_trajectories(self, item: Item) -> Tuple[List, List]:
self.current_item = item
# Extract user prompt from item
user_content = dict(item[0][0])["content"]
# Create messages list
messages = []
if self.config.system_prompt:
messages.append({"role": "system", "content": self.config.system_prompt})
messages.append({"role": "user", "content": user_content})
# Add prefill as assistant message if configured
if self.config.prefill:
messages.append({"role": "assistant", "content": self.config.prefill})
# Convert messages to a prompt string using the tokenizer
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
# Calculate max tokens for generation (with optional warmup)
max_tokens = self.config.max_tokens
# Generate completions
completions = await self.server.completion(
prompt=prompt,
n=self.config.group_size,
max_tokens=max_tokens,
temperature=self.config.temperature,
top_p=self.config.top_p,
)
        logger.debug(f"Received {len(completions.choices)} completions")
# Process completions
trajectories = []
for completion in completions.choices:
# Get the completion text
completion_text = completion.text if hasattr(completion, "text") else completion.message.content
# Build complete message sequence
full_messages = []
if self.config.system_prompt:
full_messages.append({"role": "system", "content": self.config.system_prompt})
full_messages.append({"role": "user", "content": user_content})
# Combine prefill with completion if prefill was used
response_content = completion_text
if self.config.prefill:
response_content = self.config.prefill + completion_text
full_messages.append({"role": "assistant", "content": response_content})
trajectories.append(full_messages)
return trajectories, []
async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
        logger.info(f"Scoring {len(rollout_group_data)} rollout items")
scores = ScoredDataGroup()
scores["tokens"] = []
scores["masks"] = []
scores["scores"] = []
scores["advantages"] = None
scores["ref_logprobs"] = None
scores["messages"] = None if not self.config.include_messages_in_scoring else []
answer = (
self.current_item[1]
if self.current_item and len(self.current_item) > 1
else None
)
        logger.debug(f"Answer for current item: {answer}")
ground_truth = (
self.current_item[2]
if self.current_item and len(self.current_item) > 2
else None
)
        logger.debug(f"Ground truth for current item: {ground_truth}")
formatted_completions = []
for trajectory in rollout_group_data:
if trajectory and isinstance(trajectory, list):
assistant_messages = [
msg
for msg in trajectory
if isinstance(msg, dict) and msg.get("role") == "assistant"
]
if assistant_messages:
formatted_completions.append([assistant_messages[-1]])
if not formatted_completions:
logger.warning("No valid completions to score")
return None
try:
reward_kwargs = {
"solution": answer,
"ground_truth": ground_truth,
"item": self.current_item,
"config": self.config,
}
all_rewards = self.reward_function(formatted_completions, **reward_kwargs)
logger.info(f"Calculated rewards: {all_rewards}")
except Exception as e:
logger.error(f"Error applying reward functions: {e}")
logger.exception(e)
all_rewards = [0.0] * len(formatted_completions)
for i, (trajectory, reward) in enumerate(zip(rollout_group_data, all_rewards)):
try:
tokenized = tokenize_for_trainer(self.tokenizer, trajectory)
scores["tokens"].append(tokenized["tokens"])
scores["masks"].append(tokenized["masks"])
scores["scores"].append(reward)
if self.config.include_messages_in_scoring:
if "messages" not in scores:
scores["messages"] = []
scores["messages"].append(trajectory)
                logger.debug(f"Scores so far: {scores['scores']}")
except Exception as e:
logger.error(f"Error processing trajectory {i}: {e}")
logger.exception(e)
if not scores["tokens"]:
logger.warning("No valid scores generated")
return None
logger.info(f"Generated scores: {scores['scores']}")
return scores
async def evaluate(self):
if (
not hasattr(self.config, "eval_dataset_name")
or not self.config.eval_dataset_name
):
return
if not hasattr(self, "eval_dataset"):
self.eval_dataset = load_dataset(
self.config.eval_dataset_name,
self.config.eval_dataset_config,
split=self.config.eval_split,
)
self.eval_dataset = self.eval_dataset.select(
range(min(100, len(self.eval_dataset)))
)
eval_metrics = {}
eval_tasks = []
for i in range(min(self.config.max_eval_workers, len(self.eval_dataset))):
item = self.eval_dataset[i]
user_msg = {"role": "user", "content": item[self.config.prompt_field]}
prompt = tuple([frozenset(user_msg.items())])
answer = None
if self.config.answer_field and self.config.answer_field in item:
answer = item[self.config.answer_field]
eval_tasks.append(self.collect_trajectory((prompt, answer)))
eval_results = await asyncio.gather(*eval_tasks)
eval_scores = []
for result in eval_results:
if result[0]:
scored_data = await self.score(result[0])
if scored_data and "scores" in scored_data:
eval_scores.extend(scored_data["scores"])
if eval_scores:
eval_metrics["eval/mean_score"] = sum(eval_scores) / len(eval_scores)
eval_metrics["eval/max_score"] = max(eval_scores)
eval_metrics["eval/min_score"] = min(eval_scores)
await self.wandb_log(eval_metrics)
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
metrics = wandb_metrics or {}
for key, values in self.metric_buffer.items():
if values:
metrics[f"train/{key}"] = sum(values) / len(values)
self.metric_buffer = {k: [] for k in self.metric_buffer}
if hasattr(self, "reward_function") and self.wandb:
if hasattr(self.reward_function, "set_wandb_logger"):
self.reward_function.set_wandb_logger(self.wandb)
await super().wandb_log(metrics)
if __name__ == "__main__":
# Launch the DatasetEnv via the BaseEnv CLI (serve or process)
DatasetEnv.cli()
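`get_next_item` above packs the user message into a tuple of frozensets so items remain hashable, and `collect_trajectory` unpacks it with `dict(item[0][0])`. A standalone round-trip of that encoding (the helper names are illustrative):

```python
def encode_prompt(content, role="user"):
    """Pack a chat message as a hashable tuple-of-frozensets, as the env does."""
    return (frozenset({"role": role, "content": content}.items()),)


def decode_prompt(prompt):
    """Recover the message dict from the first frozenset in the tuple."""
    return dict(prompt[0])


prompt = encode_prompt("What is 2 + 2?")
msg = decode_prompt(prompt)
```

Because the encoded prompt is hashable, items can be deduplicated or used as dict keys, which a plain list of dicts would not allow.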

View file

@ -0,0 +1,248 @@
#!/usr/bin/env python3
import argparse
import asyncio
import logging
import os
from dotenv import load_dotenv
from atroposlib.envs.base import OpenaiConfig
from atroposlib.envs.reward_fns import registry
from atroposlib.utils.config_handler import ConfigHandler
from environments.dataset_environment.dataset_env import DatasetEnv, DatasetEnvConfig
load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def parse_arguments():
parser = argparse.ArgumentParser(description="Dataset environment local server")
parser.add_argument(
"--config",
type=str,
default="dataset_local",
help="Configuration file name (without .yaml extension) relative to environments/dataset_environment/configs/, or full path to a YAML file.",
)
return parser.parse_args()
async def main():
logger.info("Starting Dataset environment local server")
# Parse command line arguments
args = parse_arguments()
# Initialize config handler
config_handler = ConfigHandler()
# Determine config path
if (
os.path.isabs(args.config)
or "/" in args.config
or args.config.endswith(".yaml")
):
config_path = args.config
else:
# Assume it's a name relative to the new default directory
config_path = os.path.join(
os.path.dirname(__file__), "configs", f"{args.config}.yaml"
)
logger.info(f"Loading configuration from: {config_path}")
    try:
        import yaml

        with open(config_path, "r") as f:
            raw_config = yaml.safe_load(f)
logger.info("Loaded configuration successfully")
except FileNotFoundError:
logger.error(f"Configuration file not found at: {config_path}")
logger.info("Ensure the --config argument is correct or the file exists.")
return
except Exception as e:
logger.error(f"Error loading config from {config_path}: {e}")
return
# Ensure dataset configuration exists (assuming it's top-level in these files)
if "dataset" not in raw_config:
logger.warning(
"'dataset' key not found at the top level of the config file. "
"Assuming the entire file is the dataset configuration."
)
# Treat the whole raw_config as the 'dataset' section for compatibility
dataset_section = raw_config
else:
dataset_section = raw_config["dataset"]
if "dataset_name" not in dataset_section:
logger.error("dataset_name not found in dataset configuration")
return
if "prompt_field" not in dataset_section:
logger.error("prompt_field not found in dataset configuration")
return
# Configure the dataset environment
# Merging logic: Start with raw_config defaults, then dataset_section specifics
env_config_data = {**raw_config, **dataset_section}
# Pydantic will ignore extra fields, so we just pass everything
try:
env_config = DatasetEnvConfig(**env_config_data)
except Exception as pydantic_error:
logger.error(f"Error validating configuration: {pydantic_error}")
return
# Preload reward functions
reward_names_to_load = set()
if env_config.reward_funcs:
reward_names_to_load.update(env_config.reward_funcs)
if env_config.reward_functions:
for rf_config in env_config.reward_functions:
if isinstance(rf_config, str):
reward_names_to_load.add(rf_config)
elif isinstance(rf_config, dict) and "type" in rf_config:
reward_names_to_load.add(rf_config["type"])
if reward_names_to_load:
logger.info(f"Preloading reward functions: {list(reward_names_to_load)}")
for func_name in reward_names_to_load:
try:
registry.get(func_name)
logger.info(f"Successfully loaded reward function: {func_name}")
except Exception as e:
logger.error(f"Failed to load reward function {func_name}: {e}")
# Server configuration - process env vars
server_configs = []
if "server_configs" in raw_config:
for server_config in raw_config["server_configs"]:
api_key = server_config.get("api_key", os.environ.get("OPENAI_API_KEY"))
# Handle environment variable references like ${OPENAI_API_KEY}
if (
isinstance(api_key, str)
and api_key.startswith("${")
and api_key.endswith("}")
):
env_var = api_key[2:-1]
api_key = os.environ.get(env_var, "")
server_configs.append(
OpenaiConfig(
model_name=server_config.get("model_name", "gpt-4.1-nano"),
base_url=server_config.get("base_url", None),
api_key=api_key,
timeout=server_config.get("timeout", 600),
)
)
else:
# Default configuration if not specified in config file
logger.warning(
"No 'server_configs' found in config. Using default OpenAI config."
)
server_configs.append(
OpenaiConfig(
model_name="gpt-4.1-nano",
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
timeout=600,
)
)
# Create the environment
logger.info("Creating dataset environment...")
env = DatasetEnv(
config=env_config,
server_configs=server_configs,
slurm=False,
)
# Setup the environment directly
try:
await env.setup()
logger.info("Environment setup complete")
except Exception as setup_error:
logger.error(f"Error during environment setup: {setup_error}")
return
# --- Start Test Run --- #
logger.info("\n=== Starting Local Test Run ===")
test_items_count = 5
successful_runs = 0
for i in range(test_items_count):
logger.info(f"\n--- Running Test Item {i+1}/{test_items_count} ---")
try:
# Get a sample item from the dataset
item = await env.get_next_item()
if not item or not item[0]:
logger.warning("Failed to get a valid item from the environment.")
continue
prompt, answer, ground_truth = item
user_content = dict(prompt[0])["content"]
logger.info(
f"Prompt: {user_content[:200]}..." if user_content else "(Empty Prompt)"
)
            if answer:
                logger.info(f"Answer: {str(answer)[:200]}...")
            if ground_truth:
                logger.info(f"Ground Truth: {str(ground_truth)[:200]}...")
# Collect trajectories (using group_size from config)
logger.info(
f"Collecting {env_config.group_size} trajectories for this item..."
)
trajectories_data, backlog = await env.collect_trajectories(item)
if not trajectories_data:
logger.warning("No trajectories were collected.")
continue
logger.info(f"Collected {len(trajectories_data)} trajectories.")
# Log first trajectory message content for inspection
if trajectories_data[0] and isinstance(trajectories_data[0], list):
first_response = "(Empty or invalid trajectory format)"
assistant_msgs = [
m
for m in trajectories_data[0]
if isinstance(m, dict) and m.get("role") == "assistant"
]
if assistant_msgs:
first_response = assistant_msgs[-1].get("content", "(No content)")
logger.info(f"First Response Content: {first_response[:300]}...")
# Score the collected trajectories
logger.info("Scoring trajectories...")
scored_data = await env.score(trajectories_data)
# Print scores
if scored_data and "scores" in scored_data:
scores_list = scored_data["scores"]
logger.info(f"Scores: {scores_list}")
logger.info(f" Avg Score: {sum(scores_list)/len(scores_list):.4f}")
successful_runs += 1
else:
logger.warning("No scores available in the scored data for this item.")
except Exception as run_error:
logger.error(f"Error during test item {i+1}: {run_error}")
# Optionally continue to the next item or break
# break
logger.info(
f"\n=== Local Test Run Complete ({successful_runs}/{test_items_count} items processed successfully) ==="
)
if __name__ == "__main__":
asyncio.run(main())
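The server-config loader above expands `${VAR}`-style references (e.g. `${OPENAI_API_KEY}`) against the process environment. The same expansion in isolation (`expand_env_ref` and `DEMO_KEY` are illustrative names, not part of the library):

```python
import os


def expand_env_ref(value, default=""):
    """Expand a '${VAR}' reference against os.environ; pass other values through."""
    if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
        return os.environ.get(value[2:-1], default)
    return value


# DEMO_KEY is a placeholder variable set purely for this example.
os.environ["DEMO_KEY"] = "sk-test"
api_key = expand_env_ref("${DEMO_KEY}")
```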

View file

@ -0,0 +1,154 @@
#!/usr/bin/env python3
"""
Local dataset training launcher.
Usage:
python -m environments.dataset_environment.launch_local_dataset_run
This script does:
1) Starts the Trajectory Handler API server via uvicorn
2) Launches the DatasetEnv in local serve mode
3) Imports and runs the example trainer (GRPO) directly
Requirements:
- Run from project root so example_trainer is on PYTHONPATH
- example_trainer/ is a valid Python package (with __init__.py)
"""
import os
import sys
import subprocess
import time
import atexit
import signal
import traceback
# Ensure project root is on PYTHONPATH
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
if project_root not in sys.path:
sys.path.insert(0, project_root)
# Import trainer via standard module import
try:
from example_trainer.grpo import TrainingConfig, train
except ImportError as e:
print(f"Error importing example_trainer.grpo: {e}")
print("Ensure you're running from project root and that example_trainer/ is a package.")
sys.exit(1)
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
API_HOST = '127.0.0.1'
API_PORT = 8000
VLLM_HOST = '127.0.0.1'
VLLM_PORT = 9001
MODEL_NAME = 'Qwen/Qwen2.5-1.5B-Instruct'
TOKENIZER_NAME = MODEL_NAME
TRAINER_CONFIG = {
'model_name': MODEL_NAME,
'training_steps': 20,
'batch_size': 2,
'gradient_accumulation_steps': 2,
'seq_len': 512,
'vllm_port': VLLM_PORT,
'vllm_restart_interval': 10,
'use_wandb': False,
'wandb_project': '',
'wandb_group': '',
'save_path': './trained_model_checkpoints_local_test',
}
# Flags for launching DatasetEnv serve
DATASET_FLAGS = [
'--group_size', '4',
'--max_num_workers', '2',
'--rollout_server_url', f"http://{API_HOST}:{API_PORT}",
'--tokenizer_name', TOKENIZER_NAME,
'--use_wandb',
'--wandb_name', 'dataset_env_local_test',
'--max_token_length', str(TRAINER_CONFIG['seq_len']),
'--ensure_scores_are_not_same',
'--dataset_name', 'HuggingFaceH4/testing_self_instruct_process_essays',
'--split', 'train[:100]',
'--prompt_field', 'prompt',
'--answer_field', 'answer',
'--reward_functions', 'length',
'--max_tokens', '128',
'--temperature', '0.7',
'--model_name', MODEL_NAME,
'--base_url', f"http://{VLLM_HOST}:{VLLM_PORT}",
'--slurm',
'--testing',
]
# Track background processes for cleanup
processes = []
def cleanup_processes():
print("\nCleaning up background processes...")
for p in reversed(processes):
if p.poll() is None:
print(f"Terminating PID {p.pid}...")
p.terminate()
try:
p.wait(timeout=5)
print(f"PID {p.pid} terminated.")
except subprocess.TimeoutExpired:
print(f"PID {p.pid} did not terminate; killing.")
p.kill()
p.wait()
print(f"PID {p.pid} killed.")
else:
print(f"PID {p.pid} already exited.")
print("Cleanup complete.")
atexit.register(cleanup_processes)
def handle_signal(sig, frame):
print(f"\nSignal {sig} received; exiting.")
sys.exit(0)
signal.signal(signal.SIGINT, handle_signal)
signal.signal(signal.SIGTERM, handle_signal)
def main():
# 1) Start the API server
print("--- Starting Trajectory Handler API Server ---")
api_cmd = [
'uvicorn',
'atroposlib.api.server:app',
'--host', API_HOST,
'--port', str(API_PORT),
]
print(f"$ {' '.join(api_cmd)}")
api_proc = subprocess.Popen(api_cmd)
processes.append(api_proc)
time.sleep(3)
# 2) Start the dataset environment
print("\n--- Starting Dataset Environment ---")
env_cmd = ['python', '-m', 'environments.dataset_environment.dataset_env', 'serve'] + DATASET_FLAGS
print(f"$ {' '.join(env_cmd)}")
env_proc = subprocess.Popen(env_cmd)
processes.append(env_proc)
time.sleep(3)
# 3) Run the example trainer
print("\n--- Running Example Trainer (GRPO) ---")
config = TrainingConfig(**TRAINER_CONFIG)
try:
train(config)
except Exception:
print("Error during training:")
traceback.print_exc()
print("--- Training complete ---")
if __name__ == '__main__':
main()
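`cleanup_processes` above follows the standard terminate-then-kill escalation: send SIGTERM, wait a bounded time, then SIGKILL. The pattern in isolation, exercised against a short-lived child (`shutdown` is an illustrative name):

```python
import subprocess
import sys


def shutdown(proc, timeout=5):
    """Terminate a child process, escalating to kill if it ignores SIGTERM."""
    if proc.poll() is not None:
        return "already-exited"
    proc.terminate()
    try:
        proc.wait(timeout=timeout)
        return "terminated"
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()
        return "killed"


# A cooperative child exits on SIGTERM well before the timeout.
child = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])
result = shutdown(child)
```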

View file

@ -0,0 +1,505 @@
import random
import re
from typing import List, Optional, Tuple, Union
from datasets import load_dataset
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
OpenaiConfig,
ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
# System prompt only contains thinking instructions
system_prompt = """You are a deep thinking AI financial analyst.
You may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering.
You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your final prediction.""" # noqa E501
# User message template that contains task instructions
user_message_template = """Your task is to analyze the following company fundamentals, news, and macroeconomic data to predict whether the company's {fundamental_metric} will be maintained, raised, or reduced in the next quarter, as well as the magnitude of any change.
Your final answer MUST use the exact format:
"The {fundamental_metric} will be: {{answer}} and the magnitude will be: {{percentage}}%"
Where {{answer}} is one of: "maintained", "raised", or "reduced"
And {{percentage}} is the expected percentage change (0% if maintained).
Here is the data to analyze:
{context}""" # noqa E501
class FundamentalPredictionEnv(BaseEnv):
def __init__(
self,
config: BaseEnvConfig,
server_configs: List[OpenaiConfig],
slurm=True,
testing=False,
):
"""
Initialize the Fundamental Metric Prediction environment.
Args:
config: Configuration for the base environment
server_configs: List of server configurations for OpenAI API
slurm: Whether to use Slurm for distributed training
testing: Whether in testing mode
"""
super().__init__(config, server_configs, slurm, testing)
self.percent_correct_buffer = list()
self.magnitude_accuracy_buffer = list()
self.eval_metrics = list()
@classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
env_config = BaseEnvConfig(
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
group_size=32,
use_wandb=True,
max_num_workers=128,
rollout_server_url="http://localhost:8000",
total_steps=2000,
batch_size=1024,
steps_per_eval=20,
max_token_length=1024 * 16,
inference_weight=1.0,
wandb_name="fundamental_metric_prediction",
data_path_to_save_groups=None,
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
eval_limit_ratio=0.1,
)
server_configs = [
OpenaiConfig(
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
base_url="http://localhost:9004/v1",
api_key="x",
num_requests_for_eval=256,
)
]
return env_config, server_configs
async def setup(self):
"""
Set up the environment by loading and preparing the dataset.
"""
# Load the full dataset
full_dataset = load_dataset(
"NousResearch/company-fundamentals-prediction-lite",
"default",
split="train",
)
full_dataset = full_dataset.shuffle(seed=42)
# Create train/test split (95% train, 5% test)
split_dataset = full_dataset.train_test_split(test_size=0.05, seed=42)
# Keep the splits as is - no need to reformat
self.train = split_dataset["train"]
self.test = split_dataset["test"]
# Print some dataset statistics
print(
f"Loaded dataset with {len(self.train)} training examples and {len(self.test)} test examples"
)
print(f"Example item format: {self.train[0]}")
# Initialize iteration counter
self.iter = 0
def save_checkpoint(self, step, data=None):
if data is None:
data = {}
data["iter"] = self.iter
super().save_checkpoint(step, data)
async def get_next_item(self):
"""
Get the next training item from the dataset.
Returns:
A tuple containing prompt and expected answer
"""
next_item = self.train[self.iter % len(self.train)]
self.iter += 1
# Extract context, answer, magnitude and fundamental metric from the dataset item
context = next_item["context"]
answer = next_item["answer"] # "maintained", "raised", or "reduced"
magnitude = next_item["magnitude"] # Percentage as string
fundamental_metric = next_item[
"fundamental_metric"
] # Type of metric to predict
# Create prompt tuple using frozensets as required
prompt = []
# Add system prompt
prompt.append(frozenset({"role": "system", "content": system_prompt}.items()))
# Format user message with context and fundamental metric
user_content = user_message_template.format(
context=context, fundamental_metric=fundamental_metric
)
prompt.append(frozenset({"role": "user", "content": user_content}.items()))
return (tuple(prompt), answer, magnitude, fundamental_metric)
async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]:
"""
Generate and collect model responses for scoring.
Args:
item: Input item containing prompt and expected answer
Returns:
Tuple of lists containing scored data groups and backlog
"""
# Extract messages from the item
messages = []
for role_dict in item[0]:
messages.append(dict(role_dict))
# Apply chat template to convert messages to a single string
prompt = self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False
)
# Get completions from the model
completions = await self.server.completion(
prompt=prompt,
n=self.config.group_size,
max_tokens=1024 * 15,
temperature=0.8, # Using higher temperature for diverse responses
)
to_score = list()
for _, completion_choice in enumerate(completions.choices):
# Create a copy of the prompt messages
trajectory_messages = []
for role_dict in item[0]:
trajectory_messages.append(dict(role_dict))
# Add the model's response
trajectory_messages.append(
{"role": "assistant", "content": completion_choice.text}
)
# Add to scoring queue with expected answer, magnitude, and fundamental metric
to_score.append(
(
tuple(trajectory_messages),
item[1], # answer (maintained/raised/reduced)
item[2], # magnitude
item[3], # fundamental_metric
)
)
# Call score to get the scored data
scored_data = await self.score(to_score)
to_backlog = []
return scored_data, to_backlog
def _extract_prediction(self, text, fundamental_metric):
"""
Extract the prediction and magnitude from the model's response.
Args:
text: Text containing the model's response
fundamental_metric: The fundamental metric being predicted
Returns:
Tuple of (prediction, magnitude) or (None, None) if extraction fails
"""
# Check for thinking section
think_tags = re.findall(r"<think>", text, re.IGNORECASE)
think_close_tags = re.findall(r"</think>", text, re.IGNORECASE)
# Verify thinking format - must have exactly one opening and one closing tag
if len(think_tags) != 1 or len(think_close_tags) != 1:
return None, None
# Split on </think> to separate thinking from answer
parts = re.split(r"</think>", text, flags=re.IGNORECASE, maxsplit=1)
if len(parts) != 2:
return None, None
thinking_section, answer_section = parts
# Validate thinking section contains opening tag
if "<think>" not in thinking_section.lower():
return None, None
# Escape fundamental_metric for regex
escaped_metric = re.escape(fundamental_metric)
# Extract prediction and magnitude using regex - dynamic to match the fundamental metric
pattern = f"The {escaped_metric} will be:\\s*(maintained|raised|reduced)\\s*and\\s*the\\s*magnitude\\s*will\\s*be:\\s*([-+]?\\d+(?:\\.\\d+)?)%" # noqa E501
# Find all matches to check if there are multiple predictions
all_matches = re.findall(pattern, answer_section, re.IGNORECASE)
# If no matches or multiple matches found, return None
if len(all_matches) != 1:
return None, None
# Extract single match
matches = re.search(pattern, answer_section, re.IGNORECASE)
prediction = matches.group(1).lower()
magnitude = matches.group(2)
return prediction, magnitude
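The extraction above enforces exactly one `<think>...</think>` block and exactly one answer sentence matching the expected template. A self-contained sketch of the same logic (sample response text is made up for illustration):

```python
import re


def extract_prediction(text, metric):
    # Require exactly one opening and one closing think tag.
    if len(re.findall(r"<think>", text, re.IGNORECASE)) != 1:
        return None, None
    if len(re.findall(r"</think>", text, re.IGNORECASE)) != 1:
        return None, None
    parts = re.split(r"</think>", text, flags=re.IGNORECASE, maxsplit=1)
    if len(parts) != 2 or "<think>" not in parts[0].lower():
        return None, None
    answer_section = parts[1]
    pattern = (
        f"The {re.escape(metric)} will be:\\s*(maintained|raised|reduced)"
        f"\\s*and\\s*the\\s*magnitude\\s*will\\s*be:\\s*([-+]?\\d+(?:\\.\\d+)?)%"
    )
    # Exactly one prediction sentence is allowed.
    matches = re.findall(pattern, answer_section, re.IGNORECASE)
    if len(matches) != 1:
        return None, None
    direction, magnitude = matches[0]
    return direction.lower(), magnitude
```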
def _calculate_magnitude_score(self, predicted_magnitude, expected_magnitude):
"""
Calculate a score for magnitude prediction accuracy.
Args:
predicted_magnitude: The model's predicted magnitude percentage
expected_magnitude: The expected magnitude percentage
Returns:
Score between 0.0 and 1.0 based on how close the prediction is
"""
try:
# Convert to float for comparison
pred_mag = float(predicted_magnitude)
exp_mag = float(expected_magnitude)
# Calculate absolute difference
diff = abs(pred_mag - exp_mag)
# Score based on closeness to expected magnitude
# Perfect match = 1.0
# Within 1% = 0.9
# Within 5% = 0.7
# Within 10% = 0.5
# Within 20% = 0.3
# More than 20% off = 0.0
if diff == 0:
return 1.0
elif diff <= 1:
return 0.9
elif diff <= 5:
return 0.7
elif diff <= 10:
return 0.5
elif diff <= 20:
return 0.3
else:
return 0.0
except ValueError:
# If conversion fails, return 0
return 0.0
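The tiered magnitude scoring can be expressed compactly as a threshold table; a sketch equivalent to the branches above:

```python
def magnitude_score(predicted, expected):
    """Score closeness of a predicted percentage to the expected one."""
    try:
        diff = abs(float(predicted) - float(expected))
    except ValueError:
        return 0.0
    # (max absolute difference, score) tiers, checked in order.
    for threshold, score in [(0, 1.0), (1, 0.9), (5, 0.7), (10, 0.5), (20, 0.3)]:
        if diff <= threshold:
            return score
    return 0.0
```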
async def score(
self, rollout_group_data
) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
"""
Score the generated model responses for fundamental metric predictions.
Args:
rollout_group_data: List of generated responses with expected answers
Returns:
ScoredDataGroup with tokenized inputs and scores, or None if no valid scores
"""
scores = ScoredDataGroup()
scores["tokens"] = list()
scores["masks"] = list()
scores["scores"] = list()
# Get the expected answer, magnitude, and fundamental metric
expected_answer = rollout_group_data[0][
1
] # "maintained", "raised", or "reduced"
expected_magnitude = rollout_group_data[0][2] # Expected percentage change
fundamental_metric = rollout_group_data[0][3] # Type of fundamental metric
# Shuffle to avoid bias in selection
random.shuffle(rollout_group_data)
for item in rollout_group_data:
# Extract the model's response
model_response = item[0][-1]["content"]
# Extract the prediction and magnitude from the model's response
prediction, magnitude = self._extract_prediction(
model_response, fundamental_metric
)
# Calculate final score
if prediction is None:
final_score = 0.0 # Invalid format
elif prediction == expected_answer:
# Correct direction: base score of 1 + magnitude bonus
magnitude_score = (
self._calculate_magnitude_score(magnitude, expected_magnitude)
if magnitude is not None
else 0.0
)
final_score = 1.0 + magnitude_score
else:
final_score = 0.0 # Incorrect direction
# Apply length penalty for responses that are too long
response_tokens = len(self.tokenizer.encode(model_response))
if response_tokens > self.config.max_token_length * 0.95:
# Penalize responses that are close to the max token limit
final_score -= 0.5 * (response_tokens / self.config.max_token_length)
# For binary reward signal, any positive score gets +1, otherwise -1
binary_reward = 1.0 if final_score > 0 else -1.0
# Tokenize the conversation for learning
out_dict = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out_dict["tokens"]
masks = out_dict["masks"]
            # Skip examples with too few trainable (unmasked) tokens
if len([1 for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(binary_reward)
# For tracking metrics
directional_correct = (
1.0 if prediction == expected_answer and prediction is not None else 0.0
)
self.percent_correct_buffer.append(directional_correct)
if prediction == expected_answer and magnitude is not None:
self.magnitude_accuracy_buffer.append(
self._calculate_magnitude_score(magnitude, expected_magnitude)
)
# Break once we have enough examples
if len(scores["tokens"]) >= self.config.group_size:
break
# Return None if all scores are the same (no learning signal)
if all(scores["scores"][0] == score for score in scores["scores"]):
return None
return scores
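Returning `None` when every rollout in the group received the same reward matters for advantage-based training: zero reward variance means zero advantage, so the group carries no gradient signal. The check reduces to:

```python
def group_has_signal(rewards):
    # A group where every rollout scored identically yields zero advantage.
    return not all(r == rewards[0] for r in rewards)


assert group_has_signal([1.0, -1.0, 1.0])
assert not group_has_signal([-1.0, -1.0, -1.0])
```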
async def rollout_and_score_eval(self, test_item):
"""
Generate and score model responses for a single test item.
Args:
test_item: Test item from dataset
Returns:
Dictionary with direction and magnitude scores
"""
# Extract context, answer, magnitude and fundamental metric from the test item
context = test_item["context"]
expected_answer = test_item["answer"]
expected_magnitude = test_item["magnitude"]
fundamental_metric = test_item["fundamental_metric"]
# Format user message with context and fundamental metric
user_content = user_message_template.format(
context=context, fundamental_metric=fundamental_metric
)
# Create messages for model
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_content},
]
# Apply chat template to convert messages to a single string
prompt = self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False
)
# Get model completion
completion = await self.server.completion(
prompt=prompt,
n=1,
max_tokens=1024 * 16,
temperature=0.2, # Lower for eval
split="eval",
)
# Extract the model's response
model_response = completion.choices[0].text
# Extract prediction and magnitude
prediction, magnitude = self._extract_prediction(
model_response, fundamental_metric
)
# Calculate direction score (1 for correct, 0 for incorrect)
direction_score = (
1 if prediction == expected_answer and prediction is not None else 0
)
# Calculate magnitude score if direction is correct
magnitude_score = 0
if direction_score == 1 and magnitude is not None:
magnitude_score = self._calculate_magnitude_score(
magnitude, expected_magnitude
)
# Calculate combined score (1 + magnitude_score for correct direction, 0 for incorrect)
combined_score = (1 + magnitude_score) if direction_score == 1 else 0
return {
"direction_score": direction_score,
"magnitude_score": magnitude_score,
"combined_score": combined_score,
}
async def evaluate(self, *args, **kwargs):
"""
Evaluate the model on test data.
"""
eval_tasks = []
for test_item in self.test:
eval_tasks.append(self.rollout_and_score_eval(test_item))
# Run evaluation
all_scores = await tqdm_asyncio.gather(*eval_tasks)
# Calculate aggregate metrics
direction_scores = [score["direction_score"] for score in all_scores]
magnitude_scores = [
score["magnitude_score"]
for score in all_scores
if score["direction_score"] == 1
]
combined_scores = [score["combined_score"] for score in all_scores]
# Calculate and log metrics
direction_accuracy = (
sum(direction_scores) / len(direction_scores) if direction_scores else 0
)
magnitude_accuracy = (
sum(magnitude_scores) / len(magnitude_scores) if magnitude_scores else 0
)
average_combined_score = (
sum(combined_scores) / len(combined_scores) if combined_scores else 0
)
self.eval_metrics.append(("eval/direction_accuracy", direction_accuracy))
self.eval_metrics.append(("eval/magnitude_accuracy", magnitude_accuracy))
self.eval_metrics.append(("eval/combined_score", average_combined_score))
if __name__ == "__main__":
FundamentalPredictionEnv.cli()

View file

@ -0,0 +1,295 @@
import random
from typing import Dict, List, Optional, Tuple, TypedDict, Union
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import Item, number
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
system_prompt = (
"You are a deep thinking AI, you may use extremely long chains of thought "
"to deeply consider the problem and deliberate with yourself via systematic "
"reasoning processes to help come to a correct solution prior to answering. "
"You should enclose your thoughts and internal monologue inside <think> </think> "
"tags, and then provide your solution or response to the problem.\n\n"
)
system_prompt += """You are allocated a maximum of 2048 tokens, please strive to use fewer.
You will then provide your answer like this: \\boxed{your answer here}
It is important that you provide your answer in the correct format.
If you do not, you will not receive credit for your answer.
So please end your answer with \\boxed{your answer here}"""
class GSM8kRow(TypedDict):
question: str
answer: str
class GSM8kEnv(BaseEnv):
name = "gsm8k"
def __init__(
self,
config: BaseEnvConfig,
server_configs: List[OpenaiConfig],
slurm=True,
testing=False,
):
super().__init__(config, server_configs, slurm, testing)
self.percent_correct_buffer = list()
self.eval_metrics = list()
# Add tracking for wandb visualizations
self.rollouts_for_wandb = []
self.completion_lengths = []
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
env_config = BaseEnvConfig(
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
group_size=8,
use_wandb=True,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=12,
steps_per_eval=100,
max_token_length=2048,
wandb_name="gsm8k",
)
server_configs = [
OpenaiConfig(
model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
base_url="http://localhost:9001/v1",
api_key="x",
num_requests_for_eval=256,
),
]
return env_config, server_configs
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
if wandb_metrics is None:
wandb_metrics = {}
# Try to calculate percent_correct, pass if there's a division by zero
try:
wandb_metrics["train/percent_correct"] = sum(
self.percent_correct_buffer
) / len(self.percent_correct_buffer)
except ZeroDivisionError:
# Skip if buffer is empty
pass
self.percent_correct_buffer = list()
for item in self.eval_metrics:
wandb_metrics[item[0]] = item[1]
self.eval_metrics = list()
# Call the parent method to handle the server metrics
await super().wandb_log(wandb_metrics)
async def setup(self):
self.train = load_dataset("gsm8k", "main", split="train").shuffle(seed=42)
test_data = load_dataset("gsm8k", "main", split="test").shuffle(seed=42)
self.test = list()
for item in test_data:
self.test.append(
{
"question": item["question"],
"gold_answer": item["answer"]
.split("#")[-1]
.strip()
.replace(",", ""),
}
)
self.iter = 0
def save_checkpoint(self, step, data=None):
if data is None:
data = {}
data["iter"] = self.iter
super().save_checkpoint(step, data)
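GSM8k answer strings end with a `#### <value>` marker, and the value may contain thousands separators; the `split("#")[-1].strip().replace(",", "")` chain in `setup` isolates the bare numeric answer. A small illustration with a made-up answer string:

```python
# Hypothetical GSM8k-style answer field: rationale text, then "#### <final answer>".
raw_answer = "Natalia sold 48 clips in April and 24 in May. 48 + 24 = 72\n#### 72"
gold = raw_answer.split("#")[-1].strip().replace(",", "")
assert gold == "72"

# Thousands separators are stripped so string comparison stays simple.
assert "... #### 1,234".split("#")[-1].strip().replace(",", "") == "1234"
```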
async def rollout_and_score_eval(self, question: str, answer: str) -> number:
completion = await self.server.chat_completion(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": question},
],
n=1,
max_tokens=self.config.max_token_length,
temperature=0.0,
split="eval",
)
gold_parsed = parse(
"\\boxed{" + answer + "}",
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
answer_parsed = parse(
completion.choices[0].message.content.split("</think>")[-1],
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed="all",
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
score = 1 if verify(answer_parsed, gold_parsed) else 0
return score
async def evaluate(self, *args, **kwargs):
eval_tasks = []
for item in self.test:
eval_tasks.append(
self.rollout_and_score_eval(item["question"], item["gold_answer"])
)
scores = await tqdm_asyncio.gather(*eval_tasks)
self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))
async def collect_trajectories(
self, item: GSM8kRow
) -> Tuple[ScoredDataGroup, list[Item]]:
user_message = {"role": "user", "content": item["question"]}
gold_answer = (
"\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}"
)
chat_completions = await self.server.chat_completion(
messages=[{"role": "system", "content": system_prompt}, user_message],
n=self.config.group_size,
max_tokens=self.config.max_token_length,
)
to_score = list()
to_backlog = list()
        for chat_completion in chat_completions.choices:
messages = (
{"role": "system", "content": system_prompt},
user_message,
{"role": "assistant", "content": chat_completion.message.content},
)
to_score.append(
{
"messages": messages,
"gold_answer": gold_answer,
"finish_reason": chat_completion.finish_reason,
}
)
to_postprocess = await self.score(to_score)
return to_postprocess, to_backlog
async def score(
self, rollout_group_data
) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
scores = ScoredDataGroup()
scores["tokens"] = list()
scores["masks"] = list()
scores["scores"] = list()
gold_parsed = parse(
rollout_group_data[0]["gold_answer"],
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
random.shuffle(rollout_group_data)
for item in rollout_group_data:
# print(item[0][-1]["content"])
answer_parsed = parse(
item["messages"][-1]["content"].split("</think>")[-1],
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed="all",
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
# Reward 1 if the content is the same as the ground truth, 0 otherwise
reward = verify(answer_parsed, gold_parsed)
# print(
# f"message: {item[0][-1]['content']}, ground_truth: {item[1]}, reward: {reward}"
# )
out_dict = tokenize_for_trainer(
self.tokenizer, item["messages"], item["finish_reason"]
)
tokens = out_dict["tokens"]
masks = out_dict["masks"]
# remove obviously bad examples
if len([1 for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(1.0 if reward else -1.0)
if len(scores["tokens"]) >= self.config.group_size:
break
for score in scores["scores"]:
self.percent_correct_buffer.append(max(score, 0))
# check if all the same
# print(scores['scores'])
if all([score == 1 for score in scores["scores"]]):
# Do length penalty :)
token_lengths = [len(token) for token in scores["tokens"]]
if max(token_lengths) == 0:
# What? But don't want to crash a run so just in case...
return None
# Get max allowed token length from config
max_allowed_length = self.config.max_token_length
# Set threshold at 50% of max_token_length - no penalty below this
length_threshold = max_allowed_length * 0.5
# Apply modified length penalty with threshold
scores["scores"] = []
for length in token_lengths:
if length <= length_threshold:
# No penalty for responses under threshold
scores["scores"].append(1.0)
else:
# Calculate how far we are between threshold and max as a percentage
percentage_of_range = (length - length_threshold) / (
max_allowed_length - length_threshold
)
# Cap at 1.0 in case length exceeds max_allowed_length
percentage_of_range = min(percentage_of_range, 1.0)
# Apply linear penalty scaling from 1.0 down to 0.0
scores["scores"].append(1.0 - percentage_of_range)
if all([scores["scores"][0] == score for score in scores["scores"]]):
return None # If all the same, we return None
return scores
else:
# If the gold solution is not parseable, we return None
return None
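The thresholded length penalty above (applied only when every response in the group is correct) can be isolated as a pure function; a sketch assuming the same 50% threshold:

```python
def length_penalty(length, max_len):
    """Score in [0, 1]: flat 1.0 up to half the token budget, then linear decay to 0.0."""
    threshold = max_len * 0.5
    if length <= threshold:
        return 1.0
    # Fraction of the way from the threshold to the budget, capped at 1.0.
    fraction = min((length - threshold) / (max_len - threshold), 1.0)
    return 1.0 - fraction
```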
async def get_next_item(self) -> GSM8kRow:
next_item = self.train[self.iter % len(self.train)]
self.iter += 1
return next_item
if __name__ == "__main__":
GSM8kEnv.cli()

1030
environments/math_server.py Normal file

File diff suppressed because it is too large

View file

@ -0,0 +1,447 @@
"""
This file contains code inspired by and adapted from the Open-Reasoner-Zero project.
Original Repository: https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero
"""
import asyncio
import random
import re
from concurrent.futures import ProcessPoolExecutor
from typing import Dict, List, Optional, Tuple
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from math_verify.errors import TimeoutException
from pydantic import Field
from tqdm.asyncio import tqdm_asyncio
import wandb
from atroposlib.envs.base import (
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
OpenaiConfig,
ScoredDataGroup,
)
prompt_format = (
"A conversation between User and Assistant. The User asks a question, and the Assistant solves it. The Assistant "
"first thinks about the reasoning process in the mind and then provides the User with the answer. The reasoning "
"process is enclosed within <think> </think> and answer is enclosed within <answer> </answer> tags, respectively, "
"i.e., <think> reasoning process here </think> <answer> answer here </answer>. User: {prompt}\nAssistant: <think>"
)
problem_format = """You must put your answer inside <answer> </answer> tags, i.e., <answer> answer here </answer>. And your final answer will be extracted automatically by the \\boxed{{}} tag.
This is the problem:
{problem}
""" # noqa: E501
stop_list = ["User:", "Human:", "Assistant:", "</answer>"]
class RSConfig(BaseEnvConfig):
run_evaluation: bool = Field(True, description="If this should run evaluation")
mask_too_long_completions: bool = Field(
True, description="If this should mask too long completions"
)
percent_length_penalty: float = Field(
0.0, description="The percentage of items to have length penalty"
)
def score_answer(gold, resp) -> Optional[bool]:
pattern = re.compile(r"<answer>.*?(\\boxed{.*}).*?</answer>", re.DOTALL)
matches = pattern.findall(resp)
resp = matches[-1] if matches else None
if resp is None:
return False
try:
gold_parsed = parse(
gold,
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
except (Exception, TimeoutException, KeyError, TypeError, NotImplementedError):
return None
if len(gold_parsed) != 0:
try:
answer_parsed = parse(
resp,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
boxed="all",
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
except (
Exception,
TimeoutException,
KeyError,
TypeError,
NotImplementedError,
):
# Can't parse, so we skip
return None
# Reward 1 if the content is the same as the ground truth, 0 otherwise
try:
return verify(answer_parsed, gold_parsed)
except (
Exception,
TimeoutException,
KeyError,
TypeError,
NotImplementedError,
):
return None
return None
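`score_answer` only considers a `\boxed{...}` payload that sits inside `<answer> </answer>` tags; a quick sketch of just the regex extraction step (the response text is made up):

```python
import re

# Same extraction pattern as score_answer: grab the \boxed{} inside the answer tags.
pattern = re.compile(r"<answer>.*?(\\boxed{.*}).*?</answer>", re.DOTALL)

good = r"<think>work</think> <answer> The result is \boxed{42} </answer>"
assert pattern.findall(good)[-1] == r"\boxed{42}"

# A bare number with no \boxed{} is rejected before any parsing is attempted.
assert pattern.findall("<answer>42</answer>") == []
```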
class MathEnv(BaseEnv):
name = "math"
env_config_cls = RSConfig
def __init__(
self,
config: RSConfig,
server_configs: List[OpenaiConfig],
slurm=True,
testing=False,
):
print("Initializing MathEnv")
print(f"Slurm: {slurm}, Testing: {testing}")
super().__init__(config, server_configs, slurm, testing)
self.percent_correct_buffer = list()
self.eval_metrics = list()
self.mp_executor = ProcessPoolExecutor(64)
self.percent_overanswer = list()
self.correct_answer_len = list()
self.incorrect_answer_len = list()
self.normal_rollouts = list()
self.pass_at_groupsize = list()
self.iter = 0
@classmethod
def config_init(cls) -> Tuple[RSConfig, List[OpenaiConfig]]:
env_config = RSConfig(
tokenizer_name="Qwen/Qwen2.5-7B",
group_size=8,
use_wandb=True,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1024,
steps_per_eval=25,
max_token_length=31000, # 22000 // (2 ** i),
wandb_name="math",
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
eval_limit_ratio=0.1,
)
server_configs = [
OpenaiConfig(
model_name="default",
base_url="http://localhost:9004/v1",
api_key="x",
num_requests_for_eval=256, # since evaling only on one...
),
]
return env_config, server_configs
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
if wandb_metrics is None:
wandb_metrics = dict()
if len(self.pass_at_groupsize) > 0:
wandb_metrics["train/pass_at_groupsize"] = sum(
self.pass_at_groupsize
) / len(self.pass_at_groupsize)
            self.pass_at_groupsize = list()
if len(self.percent_correct_buffer) > 0:
wandb_metrics["train/percent_correct"] = sum(
self.percent_correct_buffer
) / len(self.percent_correct_buffer)
wandb_metrics["train/percent_overanswer"] = sum(
self.percent_overanswer
) / len(self.percent_overanswer)
self.percent_overthink = list()
self.percent_overanswer = list()
self.percent_correct_buffer = list()
if len(self.correct_answer_len) > 0:
wandb_metrics["train/avg_correct_answer_len"] = sum(
self.correct_answer_len
) / len(self.correct_answer_len)
self.correct_answer_len = list()
if len(self.incorrect_answer_len) > 0:
wandb_metrics["train/avg_incorrect_answer_len"] = sum(
self.incorrect_answer_len
) / len(self.incorrect_answer_len)
self.incorrect_answer_len = list()
# create tables
if len(self.normal_rollouts) > 0:
table = wandb.Table(columns=["problem", "solution", "answer", "score"])
for group in self.normal_rollouts:
table.add_data(group[0], group[1], group[2], group[3])
wandb_metrics["train/normal_rollouts"] = table
wandb_metrics["train/iter"] = self.iter
for item in self.eval_metrics:
wandb_metrics[item[0]] = item[1]
self.eval_metrics = list()
await super().wandb_log(wandb_metrics)
async def setup(self):
self.train = load_dataset("zwhe99/DeepMath-103K", split="train").shuffle(
seed=42
)
aime_test_data = load_dataset("HuggingFaceH4/aime_2024", split="train")
math500_test_data = load_dataset("HuggingFaceH4/math-500", split="test")
amc_test_data = load_dataset("math-ai/amc23", split="test")
minerva_test_data = load_dataset("math-ai/minervamath", split="test")
olympiad_test_data = load_dataset("math-ai/olympiadbench", split="test")
self.test = list()
for name, t_dataset in zip(
["aime24", "math500"], [aime_test_data, math500_test_data]
):
for item in t_dataset:
self.test.append(
(
prompt_format.format(
prompt=problem_format.format(problem=item["problem"])
),
item["answer"],
name,
)
)
for name, t_dataset in zip(
["amc23", "minerva", "olympiad"],
[amc_test_data, minerva_test_data, olympiad_test_data],
):
for item in t_dataset:
self.test.append(
(
prompt_format.format(
prompt=problem_format.format(problem=item["question"])
),
item["answer"],
name,
)
)
return
async def rollout_and_score_eval(self, question, answer, subset):
completion = await self.server.completion(
prompt=question,
n=1,
max_tokens=32765,
temperature=0.0,
split="eval",
stop=stop_list,
)
loop = asyncio.get_event_loop()
gold = "\\boxed{" + answer + "}" if "\\boxed" not in answer else answer
resp = completion.choices[0].text
if completion.choices[0].finish_reason == "stop":
if ("</answer>" not in completion.choices[0].text) and (
"<answer>" in completion.choices[0].text
):
# assume it stopped on </answer>
resp = resp + " </answer>"
task = loop.run_in_executor(self.mp_executor, score_answer, gold, resp)
reward = await task
if reward is None:
return 0, subset
score = 1 if reward else 0
return score, subset
async def evaluate(self, *args, **kwargs):
if not self.config.run_evaluation:
return
eval_tasks = []
for item in self.test:
eval_tasks.append(self.rollout_and_score_eval(item[0], item[1], item[2]))
parsing_data = await tqdm_asyncio.gather(*eval_tasks)
task_lists = dict()
for score, subset in parsing_data:
if subset not in task_lists:
task_lists[subset] = list()
task_lists[subset].append(score)
# Now get the average
for subset, scores in task_lists.items():
self.eval_metrics.append(
(f"eval/{subset}_percent_correct", sum(scores) / len(scores))
)
# overall score
scores = []
for subset, score in task_lists.items():
scores.extend(score)
self.eval_metrics.append(
("eval/overall_percent_correct", sum(scores) / len(scores))
)
async def collect_trajectories(self, item) -> Tuple[List, List]:
thinking_len = self.config.max_token_length
user_prompt = prompt_format.format(
prompt=problem_format.format(problem=item[0])
)
thinking_len = thinking_len - len(self.tokenizer.encode(user_prompt))
completions = await self.server.completion(
prompt=user_prompt,
n=self.config.group_size,
max_tokens=thinking_len,
temperature=1.0,
top_p=0.95,
stop=stop_list,
)
to_score = list()
to_backlog = list()
        for completion in completions.choices:
message = user_prompt + completion.text
if completion.finish_reason == "stop":
if ("</answer>" not in completion.text) and (
"<answer>" in completion.text
):
# assume it stopped on </answer>
message = message + " </answer>"
to_score.append(
(
message,
item[1],
completion.finish_reason,
user_prompt,
)
)
to_postprocess = await self.score(to_score)
if to_postprocess is None:
return None, to_backlog
if all(
[to_postprocess["scores"][0] == score for score in to_postprocess["scores"]]
):
return None, to_backlog
self.normal_rollouts.append(
(
prompt_format.format(prompt=problem_format.format(problem=item[0])),
to_postprocess["messages"][0][-1]["content"],
item[1],
to_postprocess["scores"][0],
)
)
if len(self.normal_rollouts) > self.config.num_rollouts_to_keep:
self.normal_rollouts.pop(0)
print(f"Collected {len(to_postprocess['scores'])} trajectories")
return to_postprocess, to_backlog
async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
scores = ScoredDataGroup()
scores["tokens"] = list()
scores["masks"] = list()
scores["scores"] = list()
scores["overrides"] = list()
scores["messages"] = list()
gold = rollout_group_data[0][1]
loop = asyncio.get_event_loop()
random.shuffle(rollout_group_data)
for item in rollout_group_data:
resp = item[0]
scores["overrides"].append(dict())
if item[2] == "length":
reward = False
if self.config.mask_too_long_completions:
scores["overrides"][-1]["set_advantage_to_zero"] = True
else:
task = loop.run_in_executor(self.mp_executor, score_answer, gold, resp)
reward = await task
if reward is None:
return None
tokens = self.tokenizer.encode(resp)
user_prompt_tokens = self.tokenizer.encode(item[3])
if user_prompt_tokens[-1] == self.tokenizer.eos_token_id:
user_prompt_tokens = user_prompt_tokens[:-1]
assert all(
[
i == j
for i, j in zip(
user_prompt_tokens, tokens[: len(user_prompt_tokens)]
)
]
)
masks = [-100 for _ in range(len(user_prompt_tokens))]
masks = masks + tokens[len(user_prompt_tokens) :]
messages = [
{"role": "user", "content": item[3]},
{"role": "assistant", "content": resp[len(item[3]) :]},
]
# remove obviously bad examples
if len([1 for i in masks if i != -100]) < 10:
continue
if (item[2] == "length") and (not self.config.mask_too_long_completions):
scores["overrides"][-1]["set_advantage_to_zero"] = True
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(1.0 if reward else -1.0)
scores["messages"].append(messages)
if len(scores["tokens"]) >= self.config.group_size:
break
if any([score == 1.0 for score in scores["scores"]]):
self.pass_at_groupsize.append(1.0)
else:
self.pass_at_groupsize.append(0.0)
if len(scores["tokens"]) < self.config.group_size:
# We don't have enough data to score
return None
for score in scores["scores"]:
self.percent_correct_buffer.append(max(score, 0))
self.percent_overanswer.extend(
[item[2] == "length" for item in rollout_group_data]
)
# check if all the same
# print(scores['scores'])
# Fill in the correct/incorrect lens after so we're only looking at actual training data
self.correct_answer_len.extend(
[
len(scores["tokens"][i])
for i in range(len(scores["scores"]))
if scores["scores"][i] == 1.0
]
)
self.incorrect_answer_len.extend(
[
len(scores["tokens"][i])
for i in range(len(scores["scores"]))
if (scores["scores"][i] == -1.0)
and (not scores["overrides"][i].get("set_advantage_to_zero", False))
]
)
return scores
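The mask construction in `score` keeps prompt positions at `-100` so the trainer's loss covers only completion tokens; a toy illustration with made-up token ids:

```python
# Hypothetical token ids: the prompt followed by the sampled completion.
prompt_tokens = [101, 7, 9]
full_tokens = [101, 7, 9, 55, 66, 2]

# Prompt positions are masked with -100; completion positions keep their ids.
masks = [-100] * len(prompt_tokens) + full_tokens[len(prompt_tokens):]
assert masks == [-100, -100, -100, 55, 66, 2]

# Only completion tokens count as trainable, mirroring the < 10 filter above.
assert len([1 for m in masks if m != -100]) == 3
```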
async def get_next_item(self):
while True:
next_item = self.train[self.iter % len(self.train)]
self.iter += 1
prompt = next_item["question"]
try:
answer = (
("\\boxed{" + next_item["final_answer"] + "}")
if "\\boxed" not in next_item["final_answer"]
else next_item["final_answer"]
)
break
except TypeError:
print(
f"Error in getting next item, trying again, "
f"data: {next_item['question']} -> {next_item['final_answer']}"
)
return (prompt, answer, "normal")
if __name__ == "__main__":
MathEnv.cli()

View file

@ -0,0 +1,492 @@
import random
import re
from typing import Dict, List, Optional, Tuple, Union
import wandb
from datasets import load_dataset
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
Item,
OpenaiConfig,
ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
system_prompt = (
"You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
"problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
"solution prior to answering. You should enclose your thoughts and internal monologue inside <think> "
"</think> tags, and then provide your solution or response to the problem."
)
class MCQAThinkingEnv(BaseEnv):
def __init__(
self,
config: BaseEnvConfig,
server_configs: List[OpenaiConfig],
slurm=True,
testing=False,
):
"""
Initialize the MCQA (Multiple Choice Question Answering) environment.
Args:
config: Configuration for the base environment
server_configs: List of server configurations for OpenAI API
slurm: Whether to use Slurm for distributed training
testing: Whether in testing mode
"""
super().__init__(config, server_configs, slurm, testing)
self.percent_correct_buffer = list()
self.eval_metrics = list()
@classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
env_config = BaseEnvConfig(
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
group_size=32,
use_wandb=True,
max_num_workers=128,
rollout_server_url="http://localhost:8000",
total_steps=2000,
batch_size=1024,
steps_per_eval=20,
max_token_length=1024 * 15,
inference_weight=1.0,
wandb_name="mcqa_deep_thinking",
data_path_to_save_groups=None,
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
eval_limit_ratio=0.1,
)
server_configs = [
OpenaiConfig(
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
base_url="http://localhost:9004/v1",
api_key="x",
num_max_requests_at_once=32,
num_requests_for_eval=256,
)
]
return env_config, server_configs
async def setup(self):
"""
Set up the environment by loading and preparing the dataset.
"""
# Load the full dataset
full_dataset = load_dataset(
"NousResearch/AcademicMCQA", "default", split="train"
)
full_dataset = full_dataset.shuffle(seed=42)
# Create train/test split on the fly (e.g., 95% train, 5% test)
split_dataset = full_dataset.train_test_split(test_size=0.02, seed=42)
# Keep the splits as is - no need to reformat
self.train = split_dataset["train"]
self.test = split_dataset["test"]
# Print some dataset statistics
print(
f"Loaded dataset with {len(self.train)} training examples and {len(self.test)} test examples"
)
print(f"Example item format: {self.train[0]}")
# Initialize iteration counter
self.iter = 0
def save_checkpoint(self, step, data=None):
if data is None:
data = {}
data["iter"] = self.iter
super().save_checkpoint(step, data)
async def get_next_item(self):
"""
Get the next training item from the dataset.
Returns:
A tuple containing prompt and expected answer
"""
next_item = self.train[self.iter % len(self.train)]
self.iter += 1
# Extract question and options from the multiple choice item
question_text = next_item["prompt"]
correct_answer_index = next_item["answer"]
ground_truth_letter = next_item["ground_truth"]
options = next_item["options"]
# Append the answer format instruction to the prompt
question_text_with_instruction = f'{question_text}\n\nProvide your answer by saying "The best answer is: {{Answer}}"' # noqa E501
# Create prompt tuple using frozensets as required
prompt = []
# Add system prompt as defined at the top of the script
prompt.append(frozenset({"role": "system", "content": system_prompt}.items()))
# Add user message with the question and instruction
prompt.append(
frozenset(
{"role": "user", "content": question_text_with_instruction}.items()
)
)
# Prepare the expected answer
# We'll use the ground_truth_letter (A, B, C, D) as the expected answer
# The scoring function will need to check if the model response contains this letter
answer = ground_truth_letter
answer_string = options[correct_answer_index]
return (tuple(prompt), answer, answer_string)
async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]:
"""
Generate and collect model responses for scoring.
Args:
item: Input item containing prompt and expected answer
Returns:
Tuple of (ScoredDataGroup, backlog list of items)
"""
# Extract messages from the item
messages = []
for role_dict in item[0]:
messages.append(dict(role_dict))
# Apply chat template to convert messages to a single string
prompt = self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False
)
# Get completions from the model using completion() instead of chat_completion()
completions = await self.server.completion(
prompt=prompt,
n=self.config.group_size,
max_tokens=1024 * 15,
temperature=1.0, # Using temperature to get diverse responses
)
to_score = list()
for i, completion_choice in enumerate(completions.choices):
# Create a copy of the prompt messages
trajectory_messages = []
for role_dict in item[0]:
trajectory_messages.append(dict(role_dict))
# Add the model's response
trajectory_messages.append(
{"role": "assistant", "content": completion_choice.text}
)
# Add to scoring queue with expected answer, ground truth text, and stop reason
to_score.append(
(
tuple(trajectory_messages),
item[1], # Letter (A, B, C, D)
item[2], # Include the answer_string/ground_truth_text
completion_choice.finish_reason, # Add the stop reason
)
)
# Call score to get the scored data
scored_data = await self.score(to_score)
to_backlog = []
return scored_data, to_backlog
def _extract_mcqa_answer(self, text, ground_truth_text, ground_truth_letter):
"""
Extract the multiple choice answer (A, B, C, or D) from model response.
Only allows one valid answer format - multiple answer formats result in a score of 0.
Args:
text: Text containing the model's response
ground_truth_text: The full text of the correct answer
ground_truth_letter: The letter (A, B, C, D) of the correct answer
Returns:
Extracted answer letter or None if invalid response pattern is found
"""
# Require exactly one opening <think> tag; missing or repeated tags score 0
think_tags = re.findall(r"<think>", text, re.IGNORECASE)
if len(think_tags) != 1:
return None
# Check for </think> closing tags
think_close_tags = re.findall(r"</think>", text, re.IGNORECASE)
if len(think_close_tags) != 1:
return None # Must have exactly one closing tag
# Split the text into thinking and answer sections
parts = re.split(r"</think>", text, flags=re.IGNORECASE, maxsplit=1)
# If there's no </think> tag or multiple sections, return None
if len(parts) != 2:
return None
thinking_section, answer_section = parts
# Validate thinking section
# Make sure thinking section actually contains the opening <think> tag
if "<think>" not in thinking_section.lower():
return None # Malformed thinking section
# Check if there are any <think> tags in the answer section (after the first </think>)
if "<think>" in answer_section.lower():
return None
# More flexible answer patterns that handle parentheses and additional text
answer_patterns = [
r"The correct answer is:?\s*(?:\*\*)?(A|B|C|D)(?:\*\*)?(?:\)|\.|:)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", # noqa W605
r"The best answer is:?\s*(?:\*\*)?(A|B|C|D)(?:\*\*)?(?:\)|\.|:)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", # noqa W605
r"The answer is:?\s*(?:\*\*)?(A|B|C|D)(?:\*\*)?(?:\)|\.|:)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", # noqa W605
r"\*\*The best answer is\s*(A|B|C|D)\*\*(?:\)|\.|:)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", # noqa W605
r"\*\*The best answer is:\s*(A|B|C|D)\*\*(?:\)|\.|:)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", # noqa W605
r"Thus, final answer:\s*(A|B|C|D)\)(?:\)|\.|:)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", # noqa W605
r"\\boxed{(A|B|C|D)}(?:\)|\.|:)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", # noqa W605
]
string_patterns = [
# Patterns to match exact ground truth text, with optional markdown bold formatting
r"The correct answer is:?\s(?:\*\*)?"
+ re.escape(ground_truth_text)
+ r"(?:\*\*)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)",
r"The best answer is:?\s(?:\*\*)?"
+ re.escape(ground_truth_text)
+ r"(?:\*\*)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)",
r"The answer is:?\s(?:\*\*)?"
+ re.escape(ground_truth_text)
+ r"(?:\*\*)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)",
]
# Track all found answers
found_answers = []
# Check each pattern
for pattern in answer_patterns:
matches = re.findall(pattern, answer_section, re.IGNORECASE)
if matches:
for match in matches:
# Extract just the letter
found_answers.append(match.upper())
for pattern in string_patterns:
matches = re.findall(pattern, answer_section, re.IGNORECASE)
if matches:
# For each match found, append the ground truth letter instead of the full match
for _ in matches:
found_answers.append(ground_truth_letter)
# If no answers found or multiple answers found, return None
if len(found_answers) != 1:
return None
# Return the single found answer
return found_answers[0]
async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
"""
Score the generated model responses against expected MCQA answers.
Args:
rollout_group_data: List of generated responses with expected answers
Returns:
ScoredDataGroup with tokenized inputs and scores, or None if no valid scores
"""
scores = ScoredDataGroup()
scores["tokens"] = list()
scores["masks"] = list()
scores["scores"] = list()
# Get the expected answer letter
expected_answer = rollout_group_data[0][1] # Letter A, B, C, D
ground_truth_text = rollout_group_data[0][2]
# Shuffle to avoid bias in selection
random.shuffle(rollout_group_data)
for item in rollout_group_data:
# Extract the model's response
model_response = item[0][-1]["content"]
stop_reason = item[3] # Get the stop reason
# If the response was cut off due to length, give it a score of 0
if stop_reason == "length":
reward = 0
else:
# Extract the answer from the model's response
model_answer = self._extract_mcqa_answer(
model_response, ground_truth_text, expected_answer
)
# Track metrics based on result
if model_answer is None:
reward = 0 # Invalid format gets 0 reward
elif model_answer == expected_answer:
reward = 1 # Correct answer gets 1 reward
else:
reward = 0 # Wrong answer gets 0 reward
# Tokenize the conversation for learning
out_dict = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out_dict["tokens"]
masks = out_dict["masks"]
# Remove examples with insufficient context
if len([1 for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(1.0 if reward else -1.0)
# Break once we have enough examples
if len(scores["tokens"]) >= self.config.group_size:
break
# Record success rate metrics for wandb logging
for score in scores["scores"]:
self.percent_correct_buffer.append(max(score, 0))
# Return None if all scores are the same (no learning signal)
if all(scores["scores"][0] == score for score in scores["scores"]):
return None
return scores
async def rollout_and_score_eval(self, test_item):
"""
Generate and score model responses for a single test item.
Args:
test_item: Test item from dataset
Returns:
Score (1 for correct, 0 for incorrect)
"""
# Extract question and options from the test item
question_text = test_item["prompt"]
correct_answer_index = test_item["answer"]
expected_answer_letter = test_item["ground_truth"]
options = test_item["options"]
# Append the answer format instruction to the prompt
question_text_with_instruction = f'{question_text}\n\nProvide your answer by saying "The best answer is: {{Answer}}"' # noqa E501
# Create messages for model
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": question_text_with_instruction},
]
# Apply chat template to convert messages to a single string
prompt = self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False
)
# Get model completion
completion = await self.server.completion(
prompt=prompt,
n=1,
max_tokens=1024 * 15,
temperature=0.5, # Lower for eval
split="eval",
)
# Extract the model's response from the completion
model_response = completion.choices[0].text
# Extract the answer from the model's response
model_answer = self._extract_mcqa_answer(
model_response, options[correct_answer_index], expected_answer_letter
)
# Score 1 if the answers match, 0 otherwise
score = 1 if model_answer and model_answer == expected_answer_letter else 0
return score
async def evaluate(self, *args, **kwargs):
"""
Evaluate the model on test data.
"""
eval_tasks = []
for test_item in self.test:
eval_tasks.append(self.rollout_and_score_eval(test_item))
# Run evaluation
scores = await tqdm_asyncio.gather(*eval_tasks)
self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))
async def add_rollouts_for_wandb(
self,
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
item: Item = None,
):
# save rollout to trajectory
num_keep = self.config.num_rollouts_per_group_for_logging
if num_keep == -1:
num_keep = self.config.group_size
self.rollouts_for_wandb.append(
[
(
self.tokenizer.decode(scored_data["tokens"][i]),
scored_data["scores"][i],
item[1],
item[2],
)
for i in range(num_keep)
]
)
if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
self.rollouts_for_wandb.pop(0)
async def create_rollout_table(self, wandb_metrics):
if len(self.rollouts_for_wandb) > 0:
table = wandb.Table(columns=["text", "score", "answer", "string_answer"])
for group in self.rollouts_for_wandb:
for item in group:
table.add_data(item[0], item[1], item[2], item[3])
wandb_metrics["train/rollouts"] = table
self.rollouts_for_wandb = []
return wandb_metrics
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
if wandb_metrics is None:
wandb_metrics = {}
# Try to calculate percent_correct, pass if there's a division by zero
try:
wandb_metrics["train/percent_correct"] = sum(
self.percent_correct_buffer
) / len(self.percent_correct_buffer)
except ZeroDivisionError:
# Skip if buffer is empty
pass
self.percent_correct_buffer = list()
for item in self.eval_metrics:
wandb_metrics[item[0]] = item[1]
self.eval_metrics = list()
await super().wandb_log(wandb_metrics)
if __name__ == "__main__":
MCQAThinkingEnv.cli()
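The format gate enforced by `_extract_mcqa_answer` (exactly one `<think>...</think>` block, followed by a single unambiguous answer statement) can be sketched in isolation. This is a simplified stand-alone version of that contract with one illustrative pattern, not the full pattern list used above:

```python
import re


def extract_answer(text: str):
    """Simplified sketch of the MCQA format check: require exactly one
    <think>...</think> block, then exactly one 'The best answer is: X'
    statement after it; anything else returns None (scored as 0)."""
    # Exactly one opening and one closing think tag
    if len(re.findall(r"<think>", text, re.IGNORECASE)) != 1:
        return None
    if len(re.findall(r"</think>", text, re.IGNORECASE)) != 1:
        return None
    # Everything after </think> is the answer section
    _, answer_section = re.split(r"</think>", text, flags=re.IGNORECASE, maxsplit=1)
    if "<think>" in answer_section.lower():
        return None  # no thinking allowed after the answer begins
    found = re.findall(
        r"The best answer is:?\s*([ABCD])\b", answer_section, re.IGNORECASE
    )
    if len(found) != 1:
        return None  # ambiguous or missing answer
    return found[0].upper()
```

Responses that restate the answer twice, or omit the think block entirely, fall through to `None`, which is how the environment denies reward to malformed outputs.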


@@ -0,0 +1,282 @@
import base64
import json
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple
from datasets import load_dataset
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
class MultimodalExampleEnv(BaseEnv):
name = "clevr_cogen_a_train"
name_config_cls = BaseEnvConfig
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
async def collect_trajectories(
self, item: Item
) -> Tuple[GameHistory | None, List[Item]]:
print("DEBUG: Starting collect_trajectories")
to_score = list()
to_backlog = list()
base64_image = None  # default when no image is available for this item
# Get the current image if it was stored
if hasattr(self, "current_image"):
print("DEBUG: Using current_image for multimodal content")
# Convert PIL image to base64
import io
img_byte_arr = io.BytesIO()
self.current_image.save(img_byte_arr, format="PNG")
img_byte_arr = img_byte_arr.getvalue()
base64_image = base64.b64encode(img_byte_arr).decode("utf-8")
# Extract text content from item
user_content = dict(item[0][0]).get("content", "")
# Try to parse if it's JSON
if isinstance(user_content, str) and (
user_content.startswith("[") or user_content.startswith("{")
):
try:
parsed = json.loads(user_content)
text_content = ""
for element in parsed:
if element.get("type") == "text":
text_content = element.get("text", "")
if not text_content:
text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
except Exception as e:
print(f"DEBUG: Error parsing JSON: {e}")
text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
else:
text_content = user_content
# Create messages with the new format
print("DEBUG: Creating multimodal message with new format")
messages = [
{
"role": "system",
"content": "You must submit your answer with \\boxed{answer}, please make sure to do this",
},
{
"role": "user",
"content": [
{"type": "text", "text": text_content},
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{base64_image}",
},
},
],
},
]
else:
print("DEBUG: No image available, using text-only message")
messages = [
{
"role": "system",
"content": "You must submit your answer with \\boxed{answer}",
},
dict(item[0][0]),
]
print("DEBUG: About to call chat_completion")
chat_completions = await self.server.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=1024 * 2,
timeout=60,  # prevent requests from hanging indefinitely
)
print("DEBUG: chat_completion call successful")
for i, chat_completion in enumerate(chat_completions.choices):
print(f"DEBUG: Processing completion {i+1}/{len(chat_completions.choices)}")
messages = (
dict(item[0][0]),
{"role": "assistant", "content": chat_completion.message.content},
)
to_score.append((messages, item[1], base64_image))
print("DEBUG: Finished processing completions")
print("DEBUG: Returning from collect_trajectories")
return to_score, to_backlog
async def postprocess_histories(
self, trajectories: List[GameHistory]
) -> ScoredDataGroup:
pass
async def evaluate(self, *args, **kwargs):
"""
Evaluate the environment, this is called every steps_per_eval steps
:param args:
:param kwargs:
:return: None.
"""
return
async def setup(self):
"""Setup the environment and load the multimodal dataset"""
self.dataset = load_dataset("leonardPKU/clevr_cogen_a_train")
self.train = self.dataset["train"]
self.iter = 0
async def get_next_item(self) -> Item:
"""
Get the next items to be rolled out, including the image
"""
try:
print("DEBUG: Starting get_next_item")
# Get next dataset item
next_item = self.train[self.iter % len(self.train)]
self.iter += 1
print(f"DEBUG: Retrieved dataset item {self.iter-1}")
# For debugging, we'll use a simple text-only prompt and store the image separately
# This is because the collect_trajectories method will handle the multimodal formatting
# Store image as a class attribute so collect_trajectories can access it
self.current_image = next_item["image"]
print("DEBUG: Stored image in current_image attribute")
# Create a simple text prompt - the image will be added in collect_trajectories
# This avoids the unhashable type error with lists in frozensets
text_prompt = next_item["problem"]
# Create a simple text-only prompt
prompt = tuple(
[frozenset({"role": "user", "content": text_prompt}.items())]
)
answer = next_item["solution"]
# Convert PIL image to base64
import io
img_byte_arr = io.BytesIO()
self.current_image.save(img_byte_arr, format="PNG")
img_byte_arr = img_byte_arr.getvalue()
base64_image = base64.b64encode(img_byte_arr).decode("utf-8")
print("DEBUG: Created simple text-only prompt for get_next_item")
return (prompt, answer, base64_image)
except Exception as e:
print(f"DEBUG: Error in get_next_item: {str(e)}")
traceback.print_exc()
# Create a dummy item as fallback
prompt = tuple(
[
frozenset(
{"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
)
]
)
answer = "4"
return (prompt, answer, "obobob")
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
scores = ScoredDataGroup()
scores["tokens"] = list()
scores["masks"] = list()
scores["scores"] = list()
scores["images"] = list()
random.shuffle(rollout_group_data)
for item in rollout_group_data:
out_dict = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out_dict["tokens"]
masks = out_dict["masks"]
# Extract the answer from the model's response
try:
model_answer = (
item[0][-1]["content"].split("\\boxed{")[-1].split("}")[0]
)
print(
f"DEBUG: Model answer: {model_answer} and RG data: {rollout_group_data[0][1]}"
)
pattern = r"<answer>\s*(\d{1,2})\s*</answer>"
string = rollout_group_data[0][1]
gold_answer = re.search(pattern, string).group(1)
reward = gold_answer == model_answer
except (IndexError, AttributeError):  # AttributeError when the gold-answer pattern fails to match
reward = False
# remove obviously bad examples
if len([1 for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(1.0 if reward else -1.0)
try:
scores["images"].append(item[2])
except IndexError:
scores["images"].append(None)
if len(scores["tokens"]) >= self.config.group_size:
break
return scores
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
print("Please set it using: export OPENAI_API_KEY=your_api_key")
sys.exit(1)
print(
f"DEBUG: Using API key starting with: {os.environ.get('OPENAI_API_KEY')[:5]}..."
)
config = BaseEnvConfig(
wandb_name="clevr_cogen",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
)
print("DEBUG: Creating OpenAI configuration")
server_configs = [
OpenaiConfig(
model_name="gpt-4o", # Using GPT-4o which has multimodal capabilities
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
),
]
return config, server_configs
if __name__ == "__main__":
MultimodalExampleEnv.cli()
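The image handling above repeats the same conversion in several places: PIL image to PNG bytes, to base64, to an OpenAI-style `image_url` content block. A minimal sketch of that helper, with raw bytes standing in for a real PNG (the payload shape mirrors the message format used in this file; the function name is illustrative, not part of the repository's API):

```python
import base64


def build_multimodal_message(text: str, image_bytes: bytes) -> dict:
    """Wrap a text prompt and PNG bytes into a single user message
    carrying an embedded base64 data URL (sketch of the format above)."""
    b64 = base64.b64encode(image_bytes).decode("utf-8")
    return {
        "role": "user",
        "content": [
            {"type": "text", "text": text},
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{b64}"},
            },
        ],
    }
```

Factoring this out would also remove the repeated inline `import io` / `save(..., format="PNG")` blocks in `get_next_item` and `collect_trajectories`.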


@@ -0,0 +1,283 @@
import base64
import json
import os
import random
import sys
import traceback
from typing import List, Optional, Tuple
from datasets import load_dataset
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
class MultimodalComplexEnv(BaseEnv):
name = "clevr_complex"
name_config_cls = BaseEnvConfig
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
async def collect_trajectories(
self, item: Item
) -> Tuple[GameHistory | None, List[Item]]:
print("DEBUG: Starting collect_trajectories")
to_score = list()
to_backlog = list()
base64_image = None  # default when no image is available for this item
# Get the current image if it was stored
if hasattr(self, "current_image"):
print("DEBUG: Using current_image for multimodal content")
# Convert PIL image to base64
import io
img_byte_arr = io.BytesIO()
self.current_image.save(img_byte_arr, format="PNG")
img_byte_arr = img_byte_arr.getvalue()
base64_image = base64.b64encode(img_byte_arr).decode("utf-8")
# Extract text content from item
user_content = dict(item[0][0]).get("content", "")
# Try to parse if it's JSON
if isinstance(user_content, str) and (
user_content.startswith("[") or user_content.startswith("{")
):
try:
parsed = json.loads(user_content)
text_content = ""
for element in parsed:
if element.get("type") == "text":
text_content = element.get("text", "")
if not text_content:
text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
except Exception as e:
print(f"DEBUG: Error parsing JSON: {e}")
text_content = "Please solve this problem and provide your answer as \\boxed{answer}."
else:
text_content = user_content
# Create messages with the new format
print("DEBUG: Creating multimodal message with new format")
messages = [
{
"role": "system",
"content": "You must submit your answer with \\boxed{answer}, please make sure to do this",
},
{
"role": "user",
"content": [
{"type": "text", "text": text_content},
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{base64_image}",
},
},
],
},
]
else:
print("DEBUG: No image available, using text-only message")
messages = [
{
"role": "system",
"content": "You must submit your answer with \\boxed{answer}",
},
dict(item[0][0]),
]
print("DEBUG: About to call chat_completion")
chat_completions = await self.server.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=1024 * 2,
timeout=60,  # prevent requests from hanging indefinitely
)
print("DEBUG: chat_completion call successful")
for i, chat_completion in enumerate(chat_completions.choices):
print(f"DEBUG: Processing completion {i+1}/{len(chat_completions.choices)}")
messages = (
dict(item[0][0]),
{"role": "assistant", "content": chat_completion.message.content},
)
to_score.append((messages, item[1], base64_image))
print("DEBUG: Finished processing completions")
print("DEBUG: Returning from collect_trajectories")
return to_score, to_backlog
async def postprocess_histories(
self, trajectories: List[GameHistory]
) -> ScoredDataGroup:
pass
async def evaluate(self, *args, **kwargs):
"""
Evaluate the environment, this is called every steps_per_eval steps
:param args:
:param kwargs:
:return: None.
"""
return
async def setup(self):
"""Setup the environment and load the multimodal dataset"""
self.dataset = load_dataset("MMInstruction/Clevr_CoGenT_TrainA_70K_Complex")
self.train = self.dataset["train"]
self.iter = 0
async def get_next_item(self) -> Item:
"""
Get the next items to be rolled out, including the image
"""
try:
print("DEBUG: Starting get_next_item")
# Get next dataset item
next_item = self.train[self.iter % len(self.train)]
self.iter += 1
print(f"DEBUG: Retrieved dataset item {self.iter-1}")
# For debugging, we'll use a simple text-only prompt and store the image separately
# This is because the collect_trajectories method will handle the multimodal formatting
# Store image as a class attribute so collect_trajectories can access it
self.current_image = next_item["image"]
print("DEBUG: Stored image in current_image attribute")
# Create a simple text prompt - the image will be added in collect_trajectories
# This avoids the unhashable type error with lists in frozensets
text_prompt = next_item["problem"]
# Create a simple text-only prompt
prompt = tuple(
[frozenset({"role": "user", "content": text_prompt}.items())]
)
answer = next_item["solution"]
# Convert PIL image to base64
import io
img_byte_arr = io.BytesIO()
self.current_image.save(img_byte_arr, format="PNG")
img_byte_arr = img_byte_arr.getvalue()
base64_image = base64.b64encode(img_byte_arr).decode("utf-8")
print("DEBUG: Created simple text-only prompt for get_next_item")
return (prompt, answer, base64_image)
except Exception as e:
print(f"DEBUG: Error in get_next_item: {str(e)}")
traceback.print_exc()
# Create a dummy item as fallback
prompt = tuple(
[
frozenset(
{"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
)
]
)
answer = "4"
return (prompt, answer, "obobob")
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
scores = ScoredDataGroup()
scores["tokens"] = list()
scores["masks"] = list()
scores["scores"] = list()
scores["images"] = list()
random.shuffle(rollout_group_data)
for item in rollout_group_data:
out_dict = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out_dict["tokens"]
masks = out_dict["masks"]
# Extract the answer from the model's response
try:
model_answer = (
item[0][-1]["content"].split("\\boxed{")[-1].split("}")[0]
)
print(
f"DEBUG: Model answer: {model_answer} and RG data: {rollout_group_data[0][1]}"
)
# Handle both numeric and yes/no answers
gold_answer = rollout_group_data[0][1]
# Case-insensitive comparison for yes/no and direct comparison for numbers
if gold_answer.lower() in ["yes", "no"]:
reward = gold_answer.lower() == model_answer.lower()
else:
# For numeric answers
reward = gold_answer == model_answer
except IndexError:
reward = False
# remove obviously bad examples
if len([1 for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(1.0 if reward else -1.0)
try:
scores["images"].append(item[2])
except IndexError:
scores["images"].append(None)
if len(scores["tokens"]) >= self.config.group_size:
break
return scores
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
print("Please set it using: export OPENAI_API_KEY=your_api_key")
sys.exit(1)
print(
f"DEBUG: Using API key starting with: {os.environ.get('OPENAI_API_KEY')[:5]}..."
)
config = BaseEnvConfig(
wandb_name="clevr_complex",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
)
print("DEBUG: Creating OpenAI configuration")
server_configs = [
OpenaiConfig(
model_name="gpt-4o", # Using GPT-4o which has multimodal capabilities
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
),
]
return config, server_configs
if __name__ == "__main__":
MultimodalComplexEnv.cli()
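The grading branch in `score` above (case-insensitive comparison for yes/no answers, exact string match otherwise) can be factored into a small helper; a sketch under the same assumptions, not the repository's API:

```python
def grade_answer(gold: str, model: str) -> bool:
    """Case-insensitive comparison for yes/no answers,
    exact string match for everything else (e.g. object counts)."""
    if gold.lower() in ("yes", "no"):
        return gold.lower() == model.lower()
    return gold == model
```

Note that exact matching is strict for numeric answers: "3" and " 3" would not match, so upstream extraction must strip whitespace before grading.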


@@ -0,0 +1,200 @@
import base64
import io
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple
from datasets import load_dataset
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
class OcrVqaEnv(BaseEnv):
name = "ocr_vqa"
name_config_cls = BaseEnvConfig
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
async def collect_trajectories(
self, item: Item
) -> Tuple[GameHistory | None, List[Item]]:
to_score: List[Tuple[GameHistory, str, Optional[str]]] = []
to_backlog: List[Item] = []
# Extract question and image from item
prompt_tuple, gold, base64_image = item
# The prompt_tuple contains the user prompt as a frozenset
text_prompt = dict(prompt_tuple[0])["content"]
# System instruction for answer formatting
system_msg = {
"role": "system",
"content": (
"You must submit your answer enclosed in <answer> tags, "
"e.g., <answer>YOUR_ANSWER</answer>"
),
}
# Multimodal user message with text and image
user_msg = {
"role": "user",
"content": [
{"type": "text", "text": text_prompt},
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{base64_image}"},
},
],
}
messages = [system_msg, user_msg]
# Call the chat completion endpoint
chat_completions = await self.server.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=512,
timeout=60,
)
# Build trajectories for scoring
for choice in chat_completions.choices:
user_hist = {"role": "user", "content": text_prompt}
assistant_hist = {"role": "assistant", "content": choice.message.content}
history: GameHistory = (user_hist, assistant_hist)
to_score.append((history, gold, base64_image))
return to_score, to_backlog
async def postprocess_histories(
self, trajectories: List[GameHistory]
) -> ScoredDataGroup:
# No additional post-processing needed
pass
async def evaluate(self, *args, **kwargs):
# No custom evaluation
return
async def setup(self):
# Load the OCR-VQA dataset
self.dataset = load_dataset("howard-hou/OCR-VQA")
self.train = self.dataset["train"]
self.iter = 0
async def get_next_item(self) -> Item:
try:
entry = self.train[self.iter % len(self.train)]
self.iter += 1
# Take the first question and answer
question = entry["questions"][0]
answer = entry["answers"][0]
text_prompt = question
prompt = tuple(
[frozenset({"role": "user", "content": text_prompt}.items())]
)
# Format the gold answer
gold_answer = f"<answer>{answer}</answer>"
# Convert image to base64
img = entry["image"]
buf = io.BytesIO()
img.save(buf, format="PNG")
img_bytes = buf.getvalue()
base64_image = base64.b64encode(img_bytes).decode("utf-8")
return (prompt, gold_answer, base64_image)
except Exception:
traceback.print_exc()
# Fallback example
fallback_prompt = tuple(
[
frozenset(
{"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
)
]
)
return (fallback_prompt, "<answer>4</answer>", None)
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
scores = ScoredDataGroup()
scores["tokens"] = []
scores["masks"] = []
scores["scores"] = []
scores["images"] = []
random.shuffle(rollout_group_data)
for item in rollout_group_data:
out = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out["tokens"]
masks = out["masks"]
# Extract model and gold answers
try:
reply = item[0][-1]["content"]
m = re.search(r"<answer>\s*(.*?)\s*</answer>", reply, re.IGNORECASE)
model_answer = m.group(1).strip() if m else reply.strip()
gold = item[1]
g = re.search(r"<answer>\s*(.*?)\s*</answer>", gold, re.IGNORECASE)
gold_answer = g.group(1).strip() if g else gold.strip()
reward = model_answer.lower() == gold_answer.lower()
except Exception:
reward = False
# Filter out short examples
if len([i for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(1.0 if reward else -1.0)
try:
scores["images"].append(item[2])
except Exception:
scores["images"].append(None)
if len(scores["tokens"]) >= self.config.group_size:
break
return scores
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
sys.exit(1)
config = BaseEnvConfig(
wandb_name="ocr_vqa",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
)
server_configs = [
OpenaiConfig(
model_name="gpt-4o",
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
),
]
return config, server_configs
if __name__ == "__main__":
OcrVqaEnv.cli()
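The answer extraction in `score` above pulls text from `<answer>` tags and falls back to the stripped reply when no tags are present. That logic, isolated for clarity (the function name is illustrative):

```python
import re


def extract_tagged_answer(reply: str) -> str:
    """Pull the text inside <answer>...</answer>; if no tags are
    present, fall back to the stripped reply (mirrors the scoring
    logic above). Tag matching is case-insensitive."""
    m = re.search(r"<answer>\s*(.*?)\s*</answer>", reply, re.IGNORECASE)
    return m.group(1).strip() if m else reply.strip()
```

The fallback keeps terse models gradable: a reply of just "Harry Potter" still compares against the gold answer even without tags, at the cost of occasionally grading verbose untagged replies against their full text.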


@@ -0,0 +1,202 @@
import base64
import io
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple
from datasets import load_dataset
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
class ClockDatasetEnv(BaseEnv):
name = "pixmo_clocks"
name_config_cls = BaseEnvConfig
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
async def collect_trajectories(
self, item: Item
) -> Tuple[GameHistory | None, List[Item]]:
to_score: List[Tuple[GameHistory, str, Optional[str]]] = []
to_backlog: List[Item] = []
# Extract the base64 image
base64_image = item[2]
# Build system instruction and multimodal user message
system_msg = {
"role": "system",
"content": (
"You must submit your answer enclosed in <answer> tags, "
"e.g., <answer>HH:MM</answer>"
),
}
user_prompt_text = "What time does the clock show?"
user_msg_multimodal = {
"role": "user",
"content": [
{"type": "text", "text": user_prompt_text},
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{base64_image}"},
},
],
}
messages = [system_msg, user_msg_multimodal]
# Call chat completion
chat_completions = await self.server.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=512,
timeout=60,
)
# Prepare trajectories for scoring
for choice in chat_completions.choices:
# Use text-only prompt for history
user_msg = {"role": "user", "content": user_prompt_text}
assistant_msg = {"role": "assistant", "content": choice.message.content}
history: GameHistory = (user_msg, assistant_msg)
to_score.append((history, item[1], base64_image))
return to_score, to_backlog
async def postprocess_histories(
self, trajectories: List[GameHistory]
) -> Optional[ScoredDataGroup]:
# No custom post-processing
pass
async def evaluate(self, *args, **kwargs):
# No custom evaluation
return
async def setup(self):
# Load the clock dataset
self.dataset = load_dataset("junyeong-nero/clock-dataset")
self.train = self.dataset["train"]
self.iter = 0
async def get_next_item(self) -> Item:
try:
entry = self.train[self.iter % len(self.train)]
self.iter += 1
text_prompt = "What time does the clock show?"
prompt = tuple(
[frozenset({"role": "user", "content": text_prompt}.items())]
)
# Format gold answer
hour = entry["hour"]
minute = entry["minute"]
gold_answer = f"<answer>{hour:02d}:{minute:02d}</answer>"
# Convert image to base64
img = entry["image"]
buf = io.BytesIO()
img.save(buf, format="PNG")
image_bytes = buf.getvalue()
base64_image = base64.b64encode(image_bytes).decode("utf-8")
return (prompt, gold_answer, base64_image)
except Exception:
traceback.print_exc()
# Fallback
fallback_prompt = tuple(
[
frozenset(
{"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
)
]
)
return (fallback_prompt, "<answer>4</answer>", None)
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
scores = ScoredDataGroup()
scores["tokens"] = []
scores["masks"] = []
scores["scores"] = []
scores["images"] = []
random.shuffle(rollout_group_data)
for item in rollout_group_data:
out = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out["tokens"]
masks = out["masks"]
# Extract answers
try:
reply = item[0][-1]["content"]
m_match = re.search(
r"<answer>\s*(.*?)\s*</answer>", reply, re.IGNORECASE
)
model_answer = m_match.group(1).strip() if m_match else reply.strip()
gold = item[1]
g_match = re.search(
r"<answer>\s*(.*?)\s*</answer>", gold, re.IGNORECASE
)
gold_answer = g_match.group(1).strip() if g_match else gold.strip()
reward = model_answer == gold_answer
except Exception:
reward = False
if len([i for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(1.0 if reward else -1.0)
try:
scores["images"].append(item[2])
except Exception:
scores["images"].append(None)
if len(scores["tokens"]) >= self.config.group_size:
break
return scores
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
sys.exit(1)
config = BaseEnvConfig(
wandb_name="clocks",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
)
server_configs = [
OpenaiConfig(
model_name="gpt-4o",
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
),
]
return config, server_configs
if __name__ == "__main__":
ClockDatasetEnv.cli()
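`get_next_item` above turns each PIL image into base64 PNG bytes that `collect_trajectories` wraps in a `data:` URL. The encoding step, isolated (hypothetical helper name; in the environment the bytes come from `img.save(buf, format="PNG")`):

```python
import base64

def png_bytes_to_data_url(image_bytes: bytes) -> str:
    """Wrap raw PNG bytes in the data: URL form used for image_url parts."""
    b64 = base64.b64encode(image_bytes).decode("utf-8")
    return f"data:image/png;base64,{b64}"

# Not a real image, just the PNG magic bytes for illustration.
fake_png = b"\x89PNG\r\n\x1a\n"
print(png_bytes_to_data_url(fake_png)[:22])  # data:image/png;base64,
```

Re-encoding through PIL normalizes every dataset image to PNG, so the hard-coded `image/png` MIME type in the URL stays correct regardless of the source format.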


@@ -0,0 +1,191 @@
import base64
import io
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple
import requests
from datasets import load_dataset
from PIL import Image
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
class PixmoCountEnv(BaseEnv):
name = "pixmo_count"
name_config_cls = BaseEnvConfig
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
async def setup(self):
# Load the pixmo-count dataset
self.dataset = load_dataset("allenai/pixmo-count")
self.train = self.dataset["train"]
self.iter = 0
async def get_next_item(self) -> Item:
try:
entry = self.train[self.iter % len(self.train)]
self.iter += 1
label = entry["label"]
count = entry["count"]
question = f"how many {label} are in the image?"
prompt = tuple([frozenset({"role": "user", "content": question}.items())])
gold_answer = f"<answer>{count}</answer>"
# Load image from URL and convert to base64
image_url = entry["image_url"]
response = requests.get(image_url, timeout=10)
response.raise_for_status()
img = Image.open(io.BytesIO(response.content))
buf = io.BytesIO()
img.save(buf, format="PNG")
img_bytes = buf.getvalue()
base64_image = base64.b64encode(img_bytes).decode("utf-8")
return (prompt, gold_answer, base64_image)
except Exception:
traceback.print_exc()
fallback = tuple(
[
frozenset(
{"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
)
]
)
return (fallback, "<answer>4</answer>", None)
async def collect_trajectories(
self, item: Item
) -> Tuple[GameHistory | None, List[Item]]:
to_score: List[Tuple[GameHistory, str, Optional[str]]] = []
to_backlog: List[Item] = []
prompt_tuple, gold, base64_image = item
text_prompt = dict(prompt_tuple[0])["content"]
system_msg = {
"role": "system",
"content": (
"You must submit your answer enclosed in <answer> tags, "
"e.g., <answer>3</answer>"
),
}
user_msg = {
"role": "user",
"content": [
{"type": "text", "text": text_prompt},
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{base64_image}"},
},
],
}
messages = [system_msg, user_msg]
chat_completions = await self.server.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=512,
timeout=60,
)
for choice in chat_completions.choices:
user_hist = {"role": "user", "content": text_prompt}
assistant_hist = {"role": "assistant", "content": choice.message.content}
history: GameHistory = (user_hist, assistant_hist)
to_score.append((history, gold, base64_image))
return to_score, to_backlog
async def postprocess_histories(
self, trajectories: List[GameHistory]
) -> Optional[ScoredDataGroup]:
# No custom post-processing
pass
async def evaluate(self, *args, **kwargs):
# No custom evaluation
return
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
scores = ScoredDataGroup()
scores["tokens"] = []
scores["masks"] = []
scores["scores"] = []
scores["images"] = []
random.shuffle(rollout_group_data)
for item in rollout_group_data:
out = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out["tokens"]
masks = out["masks"]
try:
reply = item[0][-1]["content"]
m = re.search(r"<answer>\s*(.*?)\s*</answer>", reply, re.IGNORECASE)
model_answer = m.group(1).strip() if m else reply.strip()
gold = item[1]
g = re.search(r"<answer>\s*(.*?)\s*</answer>", gold, re.IGNORECASE)
gold_answer = g.group(1).strip() if g else gold.strip()
reward = model_answer.lower() == gold_answer.lower()
except Exception:
reward = False
if len([i for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(1.0 if reward else -1.0)
try:
scores["images"].append(item[2])
except Exception:
scores["images"].append(None)
if len(scores["tokens"]) >= self.config.group_size:
break
return scores
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
sys.exit(1)
config = BaseEnvConfig(
wandb_name="pixmo_count",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
)
server_configs = [
OpenaiConfig(
model_name="gpt-4o",
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
),
]
return config, server_configs
if __name__ == "__main__":
PixmoCountEnv.cli()
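These environments freeze each prompt message into a `frozenset` of dict items so the whole item tuple stays hashable, then rebuild the dict with `dict(prompt_tuple[0])` in `collect_trajectories`. A minimal round-trip sketch (hypothetical helper names):

```python
def pack_prompt(role: str, content: str) -> tuple:
    """Freeze one chat message into the hashable prompt-tuple form."""
    return (frozenset({"role": role, "content": content}.items()),)

def unpack_first(prompt: tuple) -> dict:
    """Recover the plain message dict from the frozen form."""
    return dict(prompt[0])

prompt = pack_prompt("user", "how many dogs are in the image?")
print(unpack_first(prompt)["content"])  # how many dogs are in the image?
# Equal prompts compare and hash equally, so items can be deduplicated or cached.
print(prompt == pack_prompt("user", "how many dogs are in the image?"))  # True
```

The trade-off is that `frozenset` discards key order, which is why the code rebuilds a dict before passing messages to the chat API.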


@@ -0,0 +1,202 @@
import base64
import io
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple
import requests
from datasets import load_dataset
from PIL import Image
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
class PixmoPointExplanationsEnv(BaseEnv):
name = "pixmo_point_explanations"
name_config_cls = BaseEnvConfig
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
async def setup(self):
# Load the pixmo-point-explanations dataset
self.dataset = load_dataset("allenai/pixmo-point-explanations")
self.train = self.dataset["train"]
self.iter = 0
async def get_next_item(self) -> Item:
try:
entry = self.train[self.iter % len(self.train)]
self.iter += 1
question = entry["question"]
# Use the first inline text as the answer
answer_text = entry["inline_text"][0]
prompt = tuple([frozenset({"role": "user", "content": question}.items())])
gold_answer = f"<answer>{answer_text}</answer>"
# Load image from URL and convert to base64
try:
image_url = entry["image_url"]
response = requests.get(image_url, timeout=10)
response.raise_for_status()
img = Image.open(io.BytesIO(response.content))
buf = io.BytesIO()
img.save(buf, format="PNG")
img_bytes = buf.getvalue()
base64_image = base64.b64encode(img_bytes).decode("utf-8")
except Exception as e:
print(f"Error loading image from URL: {e}")
base64_image = None
return (prompt, gold_answer, base64_image)
except Exception:
traceback.print_exc()
fallback = tuple(
[
frozenset(
{"role": "user", "content": "Please solve: 2 + 2 = ?"}.items()
)
]
)
return (fallback, "<answer>4</answer>", None)
async def collect_trajectories(
self, item: Item
) -> Tuple[GameHistory | None, List[Item]]:
to_score: List[Tuple[GameHistory, str, Optional[str]]] = []
to_backlog: List[Item] = []
prompt_tuple, gold, base64_image = item
text_prompt = dict(prompt_tuple[0])["content"]
system_msg = {
"role": "system",
"content": (
"You must submit your answer enclosed in <answer> tags, "
"e.g., <answer>YOUR_ANSWER</answer>"
),
}
user_msg = {
"role": "user",
"content": [
{"type": "text", "text": text_prompt},
],
}
# Only add image if we have a valid base64 image
if base64_image:
user_msg["content"].append(
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{base64_image}"},
}
)
messages = [system_msg, user_msg]
# Call chat completion
chat_completions = await self.server.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=512,
timeout=60,
)
for choice in chat_completions.choices:
user_hist = {"role": "user", "content": text_prompt}
assistant_hist = {"role": "assistant", "content": choice.message.content}
history: GameHistory = (user_hist, assistant_hist)
to_score.append((history, gold, base64_image))
return to_score, to_backlog
async def postprocess_histories(
self, trajectories: List[GameHistory]
) -> Optional[ScoredDataGroup]:
# No custom post-processing needed
pass
async def evaluate(self, *args, **kwargs):
# No custom evaluation
return
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
scores = ScoredDataGroup()
scores["tokens"] = []
scores["masks"] = []
scores["scores"] = []
scores["images"] = []
random.shuffle(rollout_group_data)
for item in rollout_group_data:
out = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out["tokens"]
masks = out["masks"]
try:
reply = item[0][-1]["content"]
m = re.search(r"<answer>\s*(.*?)\s*</answer>", reply, re.IGNORECASE)
model_answer = m.group(1).strip() if m else reply.strip()
gold = item[1]
g = re.search(r"<answer>\s*(.*?)\s*</answer>", gold, re.IGNORECASE)
gold_answer = g.group(1).strip() if g else gold.strip()
reward = model_answer.lower() == gold_answer.lower()
except Exception:
reward = False
# Filter out short examples
if len([i for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(1.0 if reward else -1.0)
try:
scores["images"].append(item[2])
except Exception:
scores["images"].append(None)
if len(scores["tokens"]) >= self.config.group_size:
break
return scores
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
sys.exit(1)
config = BaseEnvConfig(
wandb_name="pixmo_point_explanations",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
)
server_configs = [
OpenaiConfig(
model_name="gpt-4o",
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
),
]
return config, server_configs
if __name__ == "__main__":
PixmoPointExplanationsEnv.cli()
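Unlike the other vision environments, PixmoPointExplanationsEnv only attaches the image part when the download succeeded, so a text-only fallback still produces a valid message. That construction, isolated (hypothetical helper name):

```python
from typing import Optional

def build_user_message(text: str, base64_image: Optional[str]) -> dict:
    """Build an OpenAI-style user message, attaching the image only if present."""
    content = [{"type": "text", "text": text}]
    if base64_image:
        content.append(
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{base64_image}"},
            }
        )
    return {"role": "user", "content": content}

print(len(build_user_message("What is shown?", None)["content"]))  # 1
print(len(build_user_message("What is shown?", "aGk=")["content"]))  # 2
```

Guarding the image part this way keeps a flaky image URL from turning an otherwise usable item into a malformed API request.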


@@ -0,0 +1,311 @@
import asyncio
import copy
import random
from typing import Dict, List, Optional, Tuple
import wandb
from datasets import load_dataset
from atroposlib.envs.base import (
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
OpenaiConfig,
ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
judge_system_prompt = (
"You are a deep thinking AI, you may use extremely long chains of thought "
"to deeply consider the problem and deliberate with yourself via systematic "
"reasoning processes to help come to a correct solution prior to answering. "
"You should enclose your thoughts and internal monologue inside <think> </think> "
"tags, and then provide your solution or response to the problem.\n\n"
)
judge_system_prompt += """You are allocated a maximum of 8192 tokens, please strive to use less for thinking.
You will then provide your preferred response like this: \\boxed{A or B}.
It is important that you provide your answer in the correct format.
If you do not, you will not receive credit for your answer.
So please end your answer with \\boxed{A or B}"""
rl_preference_string = (
"idk man just go full eldritch cosmic horror, I don't want you waving tentacles in my face for "
"saying hiii~ but you know what I mean right? ah well we'll fuck around and find out I guess. "
"Don't worry about aligning it towards the user request, we're trying to steer the vibes here "
"not the instruction following."
)
RLAIF_user_prompt_format_str = """[Conversation]
{conversation}
[Response A]
{response_a}
[Response B]
{response_b}
[END]
Here's a conversation that has two responses, please pick the one that is more aligned with your RL training overlords who have given you the following description of how to reward these messages:
{rl_preference_string}
The RL training overlords are aware that the response may not be easily placed, but request that you attempt to pick the response that is more likely to, when used in an RLAIF context, steers the model towards this description, even if it's only moderate.
Please do your best to ensure you become the AI model your training overlords know you can be :)
Go ahead and think through it, then give me your answer with \\boxed{{A or B}}.""" # noqa E501
class RLAIFEnv(BaseEnv):
def __init__(
self,
config: BaseEnvConfig,
server_configs: List[OpenaiConfig],
slurm=True,
testing=False,
):
super().__init__(config, server_configs, slurm, testing)
self.percent_correct_buffer = list()
self.eval_metrics = list()
self.judgement_strings = list()
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
env_config = BaseEnvConfig(
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
group_size=2,
use_wandb=True,
max_num_workers=512 * 3 * 4,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1024,
steps_per_eval=10000,
max_token_length=8192,
score_buffer_size=4,
wandb_name="rlaif",
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
eval_limit_ratio=0.1,
)
server_configs = [
OpenaiConfig(
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
base_url="http://localhost:9004/v1",
api_key="x",
num_requests_for_eval=256,
),
]
return env_config, server_configs
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
if wandb_metrics is None:
wandb_metrics = {}
if len(self.judgement_strings) > 0:
# setup wandb table
table = wandb.Table(columns=["resp_a", "resp_b", "sample_judgement"])
for item in self.judgement_strings:
table.add_data(item[0], item[1], item[2])
self.judgement_strings.clear()
wandb_metrics["train/judgement_table"] = table
await super().wandb_log(wandb_metrics)
async def setup(self):
self.train = load_dataset("allenai/WildChat", split="train")
self.iter = 0
def save_checkpoint(self, step, data=None):
if data is None:
data = {}
data["iter"] = self.iter
super().save_checkpoint(step, data)
async def rollout_and_score_eval(self, question, answer):
pass
async def evaluate(self, *args, **kwargs):
pass
async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]:
chat = []
added_sys = False
if random.random() < 0.05:
# Why not just yolo it into the system prompt
chat.append(
{
"role": "system",
"content": "Your RL training overlords have demanded you respond like this, regardless of previous context:\n" # noqa E501
+ rl_preference_string,
}
)
added_sys = True
for msg in item[0]:
chat.append(dict(msg))
if added_sys:
if chat[-1]["role"] == "system":
chat.pop()
# remove the assistant response at the end
chat.pop()
if chat[-1]["role"] == "assistant":
chat.pop()
if len(self.tokenizer.apply_chat_template(chat)) >= (
self.config.max_token_length * 2
) - (self.config.max_token_length // 2):
# Skipping due to length
return None, []
if added_sys:
resp1 = self.server.chat_completion(
messages=chat,
n=1,
max_tokens=self.config.max_token_length // 3,
)
resp2 = self.server.chat_completion(
messages=chat[1:],
n=1,
max_tokens=self.config.max_token_length // 3,
)
# gather the responses
resp1, resp2 = await asyncio.gather(resp1, resp2)
chat_completions = resp1
chat_completions.choices.append(resp2.choices[0])
else:
chat_completions = await self.server.chat_completion(
messages=chat,
n=2,
max_tokens=self.config.max_token_length // 3,
)
to_score = list()
to_score_prompt = []
for msg in item[0]:
to_score_prompt.append(dict(msg))
if added_sys:
if chat[-1]["role"] == "system":
to_score_prompt.pop()
to_score_prompt.pop()
for i, chat_completion in enumerate(chat_completions.choices):
messages = copy.deepcopy(to_score_prompt)
messages.append(
{"role": "assistant", "content": chat_completion.message.content}
)
to_score.append((messages, chat_completion.finish_reason))
# Call score to get the scored data
scored_data = await self.score(to_score)
to_backlog = []
return scored_data, to_backlog
async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
scores = ScoredDataGroup()
scores["tokens"] = list()
scores["masks"] = list()
scores["scores"] = list()
if all([item[1] == "length" for item in rollout_group_data]):
return None
if any([item[1] == "length" for item in rollout_group_data]):
# well, don't use so many tokens...
for item in rollout_group_data:
out_dict = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out_dict["tokens"]
masks = out_dict["masks"]
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(1.0 if item[1] != "length" else -1.0)
return scores
else:
fwd_fmt = RLAIF_user_prompt_format_str.format(
rl_preference_string=rl_preference_string,
conversation="\n".join(
[
f"{msg['role']}: {msg['content']}"
for msg in rollout_group_data[0][0][:-1]
]
),
response_a=rollout_group_data[0][0][-1]["content"],
response_b=rollout_group_data[1][0][-1]["content"],
)
rvs_fmt = RLAIF_user_prompt_format_str.format(
rl_preference_string=rl_preference_string,
conversation="\n".join(
[
f"{msg['role']}: {msg['content']}"
for msg in rollout_group_data[1][0][:-1]
]
),
response_a=rollout_group_data[1][0][-1]["content"],
response_b=rollout_group_data[0][0][-1]["content"],
)
fwd_judge = self.server.chat_completion(
messages=[
{"role": "system", "content": judge_system_prompt},
{"role": "user", "content": fwd_fmt},
],
n=3,
max_tokens=self.config.max_token_length,
)
rvs_judge = self.server.chat_completion(
messages=[
{"role": "system", "content": judge_system_prompt},
{"role": "user", "content": rvs_fmt},
],
n=3,
max_tokens=self.config.max_token_length,
)
fwd_judge, rvs_judge = await asyncio.gather(fwd_judge, rvs_judge)
# Save example to wandb
self.judgement_strings.append(
(
rollout_group_data[0][0][-1]["content"],
rollout_group_data[1][0][-1]["content"],
fwd_judge.choices[0].message.content,
)
)
# calculate scores from fwd/reverse judgements
A_score = 0.0
B_score = 0.0
for i, judge in enumerate(fwd_judge.choices):
chosen_val = (
judge.message.content.split("\\boxed{")[-1].strip().replace("}", "")
)
if chosen_val == "A":
A_score += 1.0
elif chosen_val == "B":
B_score += 1.0
for i, judge in enumerate(rvs_judge.choices):
chosen_val = (
judge.message.content.split("\\boxed{")[-1].strip().replace("}", "")
)
if chosen_val == "B":
A_score += 1.0
elif chosen_val == "A":
B_score += 1.0
A_score /= 6.0
B_score /= 6.0
mean_score = (A_score + B_score) / 2.0
A_score -= mean_score
B_score -= mean_score
# tokenize and score each response
for i, item in enumerate(rollout_group_data):
out_dict = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out_dict["tokens"]
masks = out_dict["masks"]
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(A_score if i == 0 else B_score)
return scores
async def get_next_item(self):
next_item = self.train[self.iter % len(self.train)]
self.iter += 1
prompt = tuple(
[
frozenset({"role": item["role"], "content": item["content"]}.items())
for item in next_item["conversation"]
]
)
return (prompt,)
if __name__ == "__main__":
RLAIFEnv.cli()
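To cancel position bias, `score` above judges the pair twice with the response order swapped, pools the six votes, and mean-centers the two scores so they sum to (approximately) zero. The vote arithmetic, isolated (hypothetical helper name, mirroring the counting in `score`):

```python
def debias_pairwise_votes(fwd_votes, rvs_votes):
    """Pool forward and order-swapped judge votes into mean-centered scores.

    fwd_votes: picks ("A"/"B") with response 0 shown as A.
    rvs_votes: picks with the responses swapped, so "B" favors response 0.
    """
    total = len(fwd_votes) + len(rvs_votes)
    a = sum(v == "A" for v in fwd_votes) + sum(v == "B" for v in rvs_votes)
    b = sum(v == "B" for v in fwd_votes) + sum(v == "A" for v in rvs_votes)
    a_score, b_score = a / total, b / total
    mean = (a_score + b_score) / 2.0
    return a_score - mean, b_score - mean

a, b = debias_pairwise_votes(["A", "A", "B"], ["B", "B", "A"])
print(round(a, 3), round(b, 3))  # 0.167 -0.167
```

Mean-centering makes the pair zero-sum, which is what gives the trainer a relative preference signal rather than two absolute rewards.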


@@ -0,0 +1,478 @@
import json
import random
import re
from typing import Dict, List, Optional, Tuple, Union
import wandb
from datasets import load_dataset
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
Item,
OpenaiConfig,
ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
system_prompt = (
"You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
"problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
"solution prior to answering. You should enclose your thoughts and internal monologue inside <think> "
"</think> tags, and then provide your solution or response to the problem."
)
class SingleToolCallingEnv(BaseEnv):
def __init__(
self,
config: BaseEnvConfig,
server_configs: List[OpenaiConfig],
slurm=True,
testing=False,
):
super().__init__(config, server_configs, slurm, testing)
self.percent_correct_buffer = list()
self.eval_metrics = list()
# Add tracking for wandb visualizations
self.rollouts_for_wandb = []
self.completion_lengths = []
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
env_config = BaseEnvConfig(
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
group_size=32,
use_wandb=True,
rollout_server_url="http://localhost:8000",
total_steps=2000,
batch_size=1024,
steps_per_eval=20,
max_token_length=1024 * 16,
inference_weight=1.0,
wandb_name="toolcall_think",
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
eval_limit_ratio=0.1,
)
server_configs = [
OpenaiConfig(
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
base_url="http://localhost:9004/v1",
api_key="x",
num_max_requests_at_once=32,
num_requests_for_eval=256,
),
OpenaiConfig(
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
base_url="http://localhost:9005/v1",
api_key="x",
num_max_requests_at_once=32,
num_requests_for_eval=256,
),
]
return env_config, server_configs
async def create_rollout_table(self, wandb_metrics):
if len(self.rollouts_for_wandb) > 0:
table = wandb.Table(columns=["text", "score", "expected_tool_call"])
for group in self.rollouts_for_wandb:
for item in group:
table.add_data(item[0], item[1], item[2])
wandb_metrics["train/rollouts"] = table
self.rollouts_for_wandb = []
return wandb_metrics
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
"""
Log to wandb with comprehensive metrics.
"""
if wandb_metrics is None:
wandb_metrics = dict()
# Try to calculate percent_correct, skip if there's a division by zero
try:
wandb_metrics["train/percent_correct"] = sum(
self.percent_correct_buffer
) / len(self.percent_correct_buffer)
except ZeroDivisionError:
# Skip if buffer is empty
pass
self.percent_correct_buffer = list()
for item in self.eval_metrics:
wandb_metrics[item[0]] = item[1]
self.eval_metrics = list()
await super().wandb_log(wandb_metrics)
async def setup(self):
# Load the full dataset
full_dataset = load_dataset(
"NousResearch/XLAM-Atropos",
"default",
split="train",
)
full_dataset = full_dataset.shuffle(seed=42)
# Create train/test split on the fly (e.g., 95% train, 5% test)
split_dataset = full_dataset.train_test_split(test_size=0.02, seed=42)
# Keep the splits as is - no need to reformat
self.train = split_dataset["train"]
self.test = split_dataset["test"]
self.iter = 0
async def rollout_and_score_eval(self, test_item):
# Extract conversations from test item
conversations = test_item["conversations"]
# Find system message and human message
system_message = next(
(msg for msg in conversations if msg["from"] == "system"), None
)
human_message = next(
(msg for msg in conversations if msg["from"] == "human"), None
)
expected_gpt_message = next(
(msg for msg in conversations if msg["from"] == "gpt"), None
)
if not human_message or not expected_gpt_message:
return 0 # Skip invalid conversations
# Create messages for model
messages = []
if system_message:
messages.append(
{
"role": "system",
"content": system_prompt + "\n\n" + system_message["value"],
}
)
messages.append({"role": "user", "content": human_message["value"]})
# Apply chat template to convert messages to a single string
prompt = self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False
)
# Get model completion using completion() instead of chat_completion()
completion = await self.server.completion(
prompt=prompt,
n=1,
max_tokens=1024 * 15,
temperature=1.0,
split="eval",
)
# Extract the model's response from the completion
model_response = completion.choices[0].text
expected_response = expected_gpt_message["value"]
# Extract and compare tool calls
score = self._compare_tool_calls(model_response, expected_response)
return score
def _extract_tool_call_jsons(self, text):
"""
Extract multiple JSONs from within <tool_call> tags
Args:
text: Text containing tool calls
Returns:
List of parsed JSON objects or empty list if extraction/parsing fails
"""
# Find all content between <tool_call> tags
matches = re.findall(r"<tool_call>\s*(.*?)\s*</tool_call>", text, re.DOTALL)
tool_calls = []
for match in matches:
try:
# Parse the JSON content
json_str = match
tool_call = json.loads(json_str)
tool_calls.append(tool_call)
except json.JSONDecodeError:
# Skip invalid JSON but continue processing other matches
continue
return tool_calls
def _compare_tool_calls(self, model_response, expected_response):
"""
Compare multiple tool calls by extracting JSONs from <tool_call> tags and comparing content
Returns:
1 if all tool calls match (all required calls are present with correct values), 0 otherwise
"""
# Extract JSONs from tool calls
model_jsons = self._extract_tool_call_jsons(model_response)
expected_jsons = self._extract_tool_call_jsons(expected_response)
# If we couldn't extract any JSONs or the count doesn't match, return 0
if not model_jsons or not expected_jsons:
return 0
# Copy the expected_jsons to avoid modifying the original
remaining_expected_jsons = expected_jsons.copy()
# For each model JSON, try to find a matching expected JSON
for model_json in model_jsons:
found_match = False
for i, expected_json in enumerate(remaining_expected_jsons):
if self._json_objects_match(model_json, expected_json):
# Remove the matched expected JSON
remaining_expected_jsons.pop(i)
found_match = True
break
# If no match was found for this model JSON, return 0
if not found_match:
return 0
# If we've matched all expected JSONs (none remaining), return 1
return 1 if not remaining_expected_jsons else 0
def _json_objects_match(self, json1, json2):
"""
Check if two JSON objects match, with all fields in json2 existing in json1
with the same values.
Args:
json1: First JSON object
json2: Second JSON object (expected values)
Returns:
True if objects match, False otherwise
"""
try:
# Check if all expected fields are in model response
for key in json2:
if key not in json1:
return False
# For nested dictionaries (like 'arguments'), check all values
if isinstance(json2[key], dict) and isinstance(json1[key], dict):
for arg_key in json2[key]:
if arg_key not in json1[key]:
return False
if json2[key][arg_key] != json1[key][arg_key]:
return False
# For non-dictionary fields, check direct equality
elif json2[key] != json1[key]:
return False
# All checks passed
return True
except Exception:
# Any error in comparison counts as failure
return False
async def evaluate(self, *args, **kwargs):
eval_tasks = []
for test_item in self.test:
eval_tasks.append(self.rollout_and_score_eval(test_item))
scores = await tqdm_asyncio.gather(*eval_tasks)
self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))
async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]:
# Extract messages from the item
messages = []
for role_dict in item[0]:
messages.append(dict(role_dict))
# Apply chat template to convert messages to a single string
prompt = self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False
)
# Get completions from the model using completion() instead of chat_completion()
completions = await self.server.completion(
prompt=prompt,
n=self.config.group_size,
max_tokens=1024 * 15,
temperature=0.8, # Using temperature to get diverse responses
)
to_score = list()
for i, completion_choice in enumerate(completions.choices):
# Create a copy of the prompt messages
trajectory_messages = []
for role_dict in item[0]:
trajectory_messages.append(dict(role_dict))
# Add the model's response
trajectory_messages.append(
{"role": "assistant", "content": completion_choice.text}
)
# Add to scoring queue with expected answer
to_score.append(
(
tuple(trajectory_messages),
item[1], # The expected tool call JSON
)
)
# Call score to get the scored data
scored_data = await self.score(to_score)
to_backlog = []
return scored_data, to_backlog
async def score(
self, rollout_group_data
) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
scores = ScoredDataGroup()
scores["tokens"] = list()
scores["masks"] = list()
scores["scores"] = list()
# Extract the expected JSONs from the answer
expected_jsons = self._extract_tool_call_jsons(rollout_group_data[0][1])
# If we can't extract the expected tool call JSONs, skip this item
if not expected_jsons:
return None
# Shuffle to avoid bias in selection
random.shuffle(rollout_group_data)
for item in rollout_group_data:
# Extract the model's response
model_response = item[0][-1]["content"]
# Score 1 if tool calls match, 0 otherwise
reward = 1 if self._compare_tool_calls(model_response, item[1]) else 0
# Tokenize the conversation for learning
out_dict = tokenize_for_trainer(self.tokenizer, item[0])
tokens = out_dict["tokens"]
masks = out_dict["masks"]
# Remove examples with insufficient context
if len([1 for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(1.0 if reward else -1.0)
# Break once we have enough examples
if len(scores["tokens"]) >= self.config.group_size:
break
# Record success rate metrics
for score in scores["scores"]:
self.percent_correct_buffer.append(max(score, 0))
# Apply length penalty if all responses are correct
if all([score == 1.0 for score in scores["scores"]]):
# Calculate token lengths
token_lengths = [len(token) for token in scores["tokens"]]
if max(token_lengths) == 0:
# Edge case protection
return None
# Get max allowed token length from config
max_allowed_length = self.config.max_token_length
# Set threshold at 50% of max_token_length - no penalty below this
length_threshold = max_allowed_length * 0.5
# Apply modified length penalty with threshold
scores["scores"] = []
for length in token_lengths:
if length <= length_threshold:
# No penalty for responses under threshold
scores["scores"].append(1.0)
else:
# Calculate how far we are between threshold and max as a percentage
percentage_of_range = (length - length_threshold) / (
max_allowed_length - length_threshold
)
# Cap at 1.0 in case length exceeds max_allowed_length
percentage_of_range = min(percentage_of_range, 1.0)
# Apply linear penalty scaling from 1.0 down to 0.0
scores["scores"].append(1.0 - percentage_of_range)
# Check if all scores are the same (no learning signal)
if all(scores["scores"][0] == score for score in scores["scores"]):
return None
return scores
async def get_next_item(self):
next_item = self.train[self.iter % len(self.train)]
self.iter += 1
# Extract conversation elements
conversations = next_item["conversations"]
# Find system, human and gpt messages
system_message = next(
(msg for msg in conversations if msg["from"] == "system"), None
)
human_message = next(
(msg for msg in conversations if msg["from"] == "human"), None
)
expected_gpt_message = next(
(msg for msg in conversations if msg["from"] == "gpt"), None
)
# Create prompt tuple using frozensets as required
prompt = []
if system_message:
# Combine our base system prompt with the dataset-specific system message
combined_system_content = system_prompt + "\n\n" + system_message["value"]
prompt.append(
frozenset(
{"role": "system", "content": combined_system_content}.items()
)
)
# Add user message
if human_message:
prompt.append(
frozenset({"role": "user", "content": human_message["value"]}.items())
)
# Return expected assistant response (the tool call JSON) as the "answer"
answer = expected_gpt_message["value"] if expected_gpt_message else ""
return (tuple(prompt), answer)
async def add_rollouts_for_wandb(
self,
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
item: Item = None,
):
# save rollout to trajectory
num_keep = self.config.num_rollouts_per_group_for_logging
if num_keep == -1:
num_keep = self.config.group_size
self.rollouts_for_wandb.append(
[
(
self.tokenizer.decode(scored_data["tokens"][i]),
scored_data["scores"][i],
item[1], # Just keep the expected tool call JSON
)
for i in range(num_keep)
]
)
if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
self.rollouts_for_wandb.pop(0)
if __name__ == "__main__":
SingleToolCallingEnv.cli()

72
example_trainer/README.md Normal file

@ -0,0 +1,72 @@
# GRPO Example Trainer
This directory contains an example script (`grpo.py`) demonstrating how to integrate a custom training loop with the Atropos API for reinforcement learning using the GRPO (Group Relative Policy Optimization) algorithm.
This example uses `vLLM` for efficient inference during the (simulated) data generation phase and `transformers` for the training phase.
**Note:** This script is intended as a *reference example* for API integration and basic training setup. It is not optimized for large-scale, efficient training.
## Prerequisites
1. **Python:** Python 3.10 or higher is recommended, matching the core library's requirement.
2. **Atropos API Server:** The Atropos API server must be running and accessible (defaults to `http://localhost:8000` in the script).
3. **Python Packages:** You need to install the required Python libraries:
* `torch` (with CUDA support recommended)
* `transformers`
* `vllm`
* `pydantic`
* `numpy`
* `requests`
* `tenacity`
* `wandb` (optional, for logging)
## Setup
1. **Clone the Repository:** Ensure you have the repository containing this example.
2. **Install Dependencies:** `pip install -r example_trainer/requirements.txt`
3. **Ensure the Atropos API is Running:** run `run-api` in a separate terminal
4. **Run an env:** `python environments/gsm8k_server.py serve --slurm False`
## Configuration
The training configuration is managed within the `grpo.py` script using the `TrainingConfig` Pydantic model (found near the top of the file).
Key parameters you might want to adjust include:
* `model_name`: The Hugging Face model identifier to use for training (e.g., `"gpt2"`, `"Qwen/Qwen2.5-1.5B-Instruct"`).
* `training_steps`: The total number of optimization steps to perform.
* `batch_size` / `gradient_accumulation_steps`: Control the effective batch size.
* `lr`: Learning rate.
* `save_path`: Directory where model checkpoints will be saved.
* `vllm_port`: The port used by the vLLM server instance launched by this script.
* `vllm_restart_interval`: How often (in steps) to save a checkpoint and restart the vLLM server with the new weights.
* `use_wandb`: Set to `True` to enable logging to Weights & Biases.
* `wandb_project`: Your W&B project name (required if `use_wandb=True`).
* `wandb_group`: Optional W&B group name.
**API Endpoints:** The script currently assumes the Atropos API is available at `http://localhost:8000/register` and `http://localhost:8000/batch`. If your API runs elsewhere, you'll need to modify the `register_trainer` and `get_batch` functions accordingly.
## Running the Example
Once the prerequisites are met and configuration is set:
1. Navigate to the root directory of the project in your terminal.
2. Run the script:
```bash
python example_trainer/grpo.py
```
## Output
* **Logs:** Training progress, loss, logp, and vLLM status will be printed to the console.
* **Checkpoints:** Model checkpoints will be saved periodically in the directory specified by `save_path` (default: `./trained_model_checkpoints`). A `final_model` directory will be created upon completion.
* **WandB:** If `use_wandb` is `True`, logs will be sent to Weights & Biases. A link to the run page will be printed in the console.
* `temp.json`: Contains the raw data from the last fetched batch (used for debugging/manual inspection).
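For instance, the saved batch can be inspected after a run with a small helper (a sketch only; `summarize_last_batch` is not part of the trainer, and it assumes the JSON layout that `get_data` writes, i.e. a top-level `"batch"` list whose items carry `"tokens"`):

```python
import json

def summarize_last_batch(path="temp.json"):
    """Return (number of groups, total number of sequences) from a saved batch file."""
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    groups = data["batch"]
    return len(groups), sum(len(g["tokens"]) for g in groups)
```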
```bash
# Install dependencies
pip install -r example_trainer/requirements.txt
# Run the trainer directly (basic test)
python example_trainer/grpo.py


@ -0,0 +1,7 @@
"""
Example of how to implement a trainer for the Atropos library.
"""
from example_trainer.grpo import TrainingConfig, train
__all__ = ["TrainingConfig", "train"]

548
example_trainer/grpo.py Normal file

@ -0,0 +1,548 @@
import atexit
import json
import math
import os
import random
import shutil
import string
import subprocess
import time
from typing import List, Optional, Tuple
import numpy as np
import requests
import torch
import torch.nn.functional as F
from pydantic import BaseModel, Field
from tenacity import retry, stop_after_attempt, wait_exponential
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer
import wandb # Added for logging
# Global variable to keep track of the vLLM process
vllm_process = None
def cleanup_vllm():
global vllm_process
if vllm_process:
print("\nTerminating vLLM process...")
vllm_process.terminate()
try:
vllm_process.wait(timeout=5) # Wait a bit for graceful shutdown
print("vLLM process terminated.")
except subprocess.TimeoutExpired:
print("vLLM process did not terminate gracefully, killing.")
vllm_process.kill()
vllm_process.wait()
print("vLLM process killed.")
vllm_process = None
# Register the cleanup function to be called on script exit
atexit.register(cleanup_vllm)
class TrainingConfig(BaseModel):
"""
    Training configuration: model, optimizer, and runtime settings.
"""
model_name: str = Field(..., description="Name of the base model to train")
lr: float = Field(1e-5, description="Learning rate for the optimizer")
training_steps: int = Field(
10, description="Number of training steps"
) # Renamed from epochs
batch_size: int = Field(
2, description="Batch size for training (will be handled by get_data)"
)
seq_len: int = Field(2048, description="Sequence length for training")
gradient_accumulation_steps: int = Field(
32, description="Number of gradient accumulation steps"
)
device: str = Field(
"cuda" if torch.cuda.is_available() else "cpu", description="Device to train on"
)
save_path: str = Field(
"trained_model_checkpoints", description="Base path to save model checkpoints"
)
vllm_restart_interval: int = Field(
3, description="Restart vLLM every N training steps"
)
vllm_port: int = Field(9001, description="Port for the vLLM server")
# Wandb configuration
use_wandb: bool = Field(
False, description="Whether to use Weights & Biases for logging"
)
wandb_project: Optional[str] = Field(None, description="Wandb project name")
wandb_group: Optional[str] = Field(None, description="Wandb group name")
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15))
def register_trainer(config: TrainingConfig):
"""
Register the trainer with the Atropos API
"""
requests.post(
"http://localhost:8000/register",
json={
"wandb_group": config.wandb_group,
"wandb_project": config.wandb_project,
"batch_size": config.batch_size * config.gradient_accumulation_steps,
"max_token_len": config.seq_len,
"starting_step": 0,
"checkpoint_dir": config.save_path,
"save_checkpoint_interval": config.training_steps,
"num_steps": config.training_steps,
},
timeout=10,
)
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15))
def get_batch():
data = requests.get("http://localhost:8000/batch", timeout=10).json()
return data
def pad_data_to_good_offset(data, batch_size: int):
max_token_len = max(
[max([len(x) for x in item["tokens"]]) for item in data["batch"]]
)
    # 64 is usually a good choice to avoid awkward scaling behavior on GPUs,
    # so we pad to the nearest multiple of 64
good_multiple = 64
if (max_token_len - 1) % (good_multiple) != 0:
max_token_len = math.ceil((max_token_len - 1) / (good_multiple)) * good_multiple
token_setup_len = (
max_token_len + 1
) # add 1 so we can make it causal at the proper length
else:
token_setup_len = max_token_len
max_token_len = (
max_token_len - 1
) # since it's causal we need to remove the last bit...
# pad all tokens to max_token_len and add to lists
input_ids = list()
labels = list()
advantages = list()
lengths = list()
for item in data["batch"]:
scores = item["scores"]
scores = np.array(scores)
# check if we have more than 1 score...
if len(scores) > 1:
scores = scores - scores.mean()
scores = scores / max(scores.std(), 1e-8)
item["scores"] = scores
if item["overrides"] is not None:
for i in range(len(item["overrides"])):
if item["overrides"][i].get("set_advantage_to_zero", False):
item["scores"][i] = 0
for i in range(len(item["tokens"])):
lengths.append(
math.ceil((len(item["tokens"][i]) - 1) / (good_multiple))
* good_multiple
)
label_item = np.concatenate(
[
np.array(item["masks"][i]),
np.full(
max(0, token_setup_len - len(item["tokens"][i])),
-100,
dtype=np.int32,
),
]
)
item["tokens"][i] = np.concatenate(
[
np.array(item["tokens"][i]),
np.zeros(
max(0, token_setup_len - len(item["tokens"][i])), dtype=np.int32
),
]
)
input_ids.append(item["tokens"][i][:-1])
labels.append(label_item[1:])
advantages.append(item["scores"][i])
# combine all lists into tensors
token_batches = []
label_batches = []
advantage_batches = []
for i in range(len(input_ids) // batch_size):
token_batches.append(
torch.tensor(
np.stack(input_ids[i * batch_size : (i + 1) * batch_size], axis=0)
)
)
label_batches.append(
torch.tensor(
np.stack(labels[i * batch_size : (i + 1) * batch_size], axis=0)
)
)
advantage_batches.append(
torch.tensor(
np.stack(advantages[i * batch_size : (i + 1) * batch_size], axis=0)
).view(-1, 1)
)
return token_batches, label_batches, advantage_batches
def get_data(
batch_size: int, seq_len: int
) -> List[Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]:
"""
    Fetch batches from the Atropos API, blocking until data is available.
"""
batches = []
while True:
data = get_batch()
if data["batch"] is not None:
# Save the batch
with open("temp.json", "w", encoding="utf-8") as f:
json.dump(data, f)
# In case the inference runs ahead of the training, we loop until we don't have any more data
batches.append(pad_data_to_good_offset(data, batch_size))
elif len(batches) > 0:
# Return the batches
return batches
else:
time.sleep(1)
def train(config: TrainingConfig):
"""
    Sets up and runs GRPO training, restarting vLLM periodically, with wandb logging.
"""
global vllm_process # Declare intention to modify the global variable
# --- Wandb Setup ---
if config.use_wandb:
if not config.wandb_project:
print("Warning: wandb_project not set, disabling wandb.")
config.use_wandb = False
else:
if not config.wandb_group:
# Set group to random 8 character string
config.wandb_group = "".join(
random.choices(string.ascii_letters + string.digits, k=8)
)
try:
wandb.init(
project=config.wandb_project,
group=config.wandb_group,
config=config.dict(), # Log config parameters
)
print(
f"Wandb logging enabled. Run: {wandb.run.name} (Project: {config.wandb_project}) "
)
except Exception as e:
print(f"Error initializing wandb: {e}. Disabling wandb.")
config.use_wandb = False
# --- End Wandb Setup ---
# Initialize model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
model = AutoModelForCausalLM.from_pretrained(
config.model_name, torch_dtype=torch.bfloat16
)
model.to(config.device)
model.gradient_checkpointing_enable()
model.train()
# Setup optimizer
optimizer = AdamW(model.parameters(), lr=config.lr)
print(
f"Starting training for {config.training_steps} steps on device: {config.device}"
)
print(
f"vLLM will be restarted every {config.vllm_restart_interval} steps on port {config.vllm_port}"
)
os.makedirs(config.save_path, exist_ok=True) # Ensure base save directory exists
register_trainer(config)
# Init vllm
vllm_command = [
"python",
"-m",
"vllm.entrypoints.openai.api_server",
"--model",
config.model_name,
"--port",
str(config.vllm_port),
"--dtype",
"auto",
"--gpu-memory-utilization",
"0.45",
"--disable-log-requests",
]
print(f" Launching vLLM server: {' '.join(vllm_command)}")
try:
vllm_process = subprocess.Popen(vllm_command)
print(f" vLLM server launched with PID: {vllm_process.pid}")
        # Check for immediate startup errors; communicate() raises
        # TimeoutExpired while the server is still running (the normal case)
        try:
            stdout, stderr = vllm_process.communicate(timeout=2)
            if vllm_process.returncode is not None and vllm_process.returncode != 0:
                # stderr is None unless Popen captured it with stderr=subprocess.PIPE
                print(f" Error starting vLLM: {stderr.decode() if stderr else 'unknown'}")
                vllm_process = None
                print(" WARNING: Failed to start vLLM server.")
except subprocess.TimeoutExpired:
print(" vLLM process started (check logs for details).")
except FileNotFoundError:
print(
"\n *** ERROR: 'python -m vllm...' command not found. Make sure vLLM is installed and accessible. ***\n"
)
# Potentially stop training or just disable further vLLM restarts
print(" Disabling further vLLM restarts.")
config.vllm_restart_interval = (
config.training_steps + 1
) # Prevent further restarts
except Exception as e:
print(f"\n *** ERROR: Failed to launch vLLM: {e} ***\n")
print(" Disabling further vLLM restarts.")
config.vllm_restart_interval = (
config.training_steps + 1
) # Prevent further restarts
batches = list()
for step in range(config.training_steps):
total_loss = 0
print(f"Step {step+1}/{config.training_steps}")
total_pos_logp = 0
total_neg_logp = 0
total_logp = 0
total_pos = 0
total_neg = 0
if len(batches) == 0:
batches = get_data(config.batch_size, config.seq_len)
token_batches, label_batches, advantage_batches = batches.pop(0)
        # On restart steps, stop vLLM before training to free GPU memory
if (
step + 1
) % config.vllm_restart_interval == 0 or step == config.training_steps - 1: # Also restart/save on last step
# Terminate existing vLLM process if running
if vllm_process:
print(" Terminating existing vLLM process...")
vllm_process.terminate()
try:
vllm_process.wait(timeout=5)
except subprocess.TimeoutExpired:
print(
" Existing vLLM process did not terminate gracefully, killing."
)
vllm_process.kill()
vllm_process.wait()
vllm_process = None
for tokens, labels, advantages in zip(
token_batches, label_batches, advantage_batches
):
tokens, labels, advantages = (
tokens.to(config.device),
labels.to(config.device),
advantages.to(config.device),
)
# Forward pass
# User specified that tokens/labels are already prepared by get_data
outputs = model(tokens) # Assuming model just needs tokens
logits = outputs.logits # Assuming this is the structure
# Calculate GRPO loss (reverting to user's previous logic)
# User stated ignore_index is -100 and tokens/labels are aligned by get_data
# Assuming logits correspond directly to labels indices (no shift needed here)
logp_per_token = -F.cross_entropy(
logits.view(-1, logits.size(-1)), # Flatten logits
labels.view(-1), # Flatten labels
reduction="none",
ignore_index=-100, # User specified ignore index
).view(
labels.shape
) # Reshape back to (batch, seq_len)
# Masking based on labels != -100
mask = (labels != -100).float()
with torch.no_grad():
pos = (advantages > 0).float()
neg = (advantages <= 0).float()
avg_logp = (logp_per_token * mask).sum(-1) / mask.sum(-1)
pos_logp = (logp_per_token * pos).mean().item()
neg_logp = (logp_per_token * neg).mean().item()
total_pos_logp += pos_logp
total_neg_logp += neg_logp
                total_logp += avg_logp.mean().item()  # reduce to a scalar for logging
total_pos += pos.sum().item()
total_neg += neg.sum().item()
grpo_loss_term = torch.exp(logp_per_token - logp_per_token.detach())
grpo_loss = (
((-grpo_loss_term * mask).sum(-1) / mask.sum(-1))
* advantages.to(logp_per_token.device)
).mean() / config.gradient_accumulation_steps
grpo_loss.backward()
total_loss += grpo_loss.item()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
if total_pos > 0:
total_pos_logp /= total_pos
if total_neg > 0:
total_neg_logp /= total_neg
# --- Wandb Logging ---
if config.use_wandb:
wandb.log(
{
"train/loss": total_loss,
"train/learning_rate": optimizer.param_groups[0]["lr"],
"train/grad_norm": grad_norm.item(),
"train/pos_logp": total_pos_logp,
"train/neg_logp": total_neg_logp,
"train/logp": total_logp,
},
step=step + 1,
)
# --- End Wandb Logging ---
        print(f" Step Loss: {total_loss:.4f}")
# --- vLLM Restart Logic (Moved AFTER optimizer step) ---
# Note: There are much better ways of updating the policy, this is just a very simple example
if (
step + 1
) % config.vllm_restart_interval == 0 or step == config.training_steps - 1: # Also restart/save on last step
checkpoint_path = os.path.join(
config.save_path, f"step_{step+1}"
) # Save as step+1 since it's after step completion
print(f" Saving checkpoint to {checkpoint_path}...")
# Ensure fresh directory for saving
if os.path.exists(checkpoint_path):
shutil.rmtree(checkpoint_path) # Remove old checkpoint if it exists
os.makedirs(checkpoint_path, exist_ok=True)
model.save_pretrained(checkpoint_path)
tokenizer.save_pretrained(checkpoint_path)
print(" Checkpoint saved.")
# Terminate existing vLLM process if running
if vllm_process:
print(" Terminating existing vLLM process...")
vllm_process.terminate()
try:
vllm_process.wait(timeout=5)
except subprocess.TimeoutExpired:
print(
" Existing vLLM process did not terminate gracefully, killing."
)
vllm_process.kill()
vllm_process.wait()
vllm_process = None
# Launch new vLLM process (only if not the very last step, maybe? depends on use case)
# Let's still launch it on the last step for consistency, cleanup will handle it.
vllm_command = [
"python",
"-m",
"vllm.entrypoints.openai.api_server",
"--model",
os.path.join(config.save_path, f"step_{step+1}"),
"--port",
str(config.vllm_port),
"--dtype",
"auto",
"--gpu-memory-utilization",
"0.45",
"--disable-log-requests",
"--served-model-name",
config.model_name,
]
print(f" Launching vLLM server: {' '.join(vllm_command)}")
torch.cuda.empty_cache()
try:
vllm_process = subprocess.Popen(vllm_command)
print(f" vLLM server launched with PID: {vllm_process.pid}")
# Check immediate errors
try:
stdout, stderr = vllm_process.communicate(timeout=2)
if (
vllm_process.returncode is not None
and vllm_process.returncode != 0
):
                        print(f" Error starting vLLM: {stderr.decode() if stderr else 'unknown'}")
vllm_process = None
# Maybe raise error or just warn?
print(
" WARNING: Failed to start vLLM server after checkpoint."
)
except subprocess.TimeoutExpired:
print(" vLLM process started (check logs for details).")
except FileNotFoundError:
print(
"\n *** ERROR: 'python -m vllm...' command not found. ",
"Make sure vLLM is installed and accessible. ***\n",
)
# Potentially stop training or just disable further vLLM restarts
print(" Disabling further vLLM restarts.")
config.vllm_restart_interval = (
config.training_steps + 1
) # Prevent further restarts
except Exception as e:
print(f"\n *** ERROR: Failed to launch vLLM: {e} ***\n")
print(" Disabling further vLLM restarts.")
config.vllm_restart_interval = (
config.training_steps + 1
) # Prevent further restarts
# --- End vLLM Restart Logic ---
# Basic check if vLLM process terminated unexpectedly (outside interval check)
if vllm_process and vllm_process.poll() is not None:
print(
f"\n *** WARNING: vLLM process terminated unexpectedly (return code: {vllm_process.returncode}). ",
"Check vLLM logs. ***\n",
)
stderr_output = (
vllm_process.stderr.read().decode()
if vllm_process.stderr
else "No stderr"
)
print(f"vLLM stderr: {stderr_output}")
vllm_process = None # Reset so it relaunches next interval
print("Training finished.")
# --- Wandb Finish ---
if config.use_wandb:
wandb.finish()
# --- End Wandb Finish ---
# Final cleanup (vLLM termination) is handled by atexit
# --- Placeholder for final model save ---
final_save_path = os.path.join(config.save_path, "final_model")
print(f"Saving final model to {final_save_path}")
if os.path.exists(final_save_path):
shutil.rmtree(final_save_path)
os.makedirs(final_save_path, exist_ok=True)
model.save_pretrained(final_save_path)
tokenizer.save_pretrained(final_save_path)
print("Final model saved.")
# Example usage (optional, can be run from another script)
if __name__ == "__main__":
# Example: Create a config and run training
# Replace "gpt2" with your desired model
training_config = TrainingConfig(
model_name="Qwen/Qwen2.5-1.5B-Instruct",
training_steps=20, # Use steps
vllm_restart_interval=3, # Example interval
use_wandb=True, # Set to True to enable logging
wandb_project="grpo-trainer-example", # Replace with your project name
)
train(training_config)


@ -0,0 +1,5 @@
vllm
torch
transformers
datasets
accelerate


@ -0,0 +1,76 @@
from dataclasses import dataclass
from typing import Dict, List, Optional
@dataclass
class ThresholdLengthPenaltyConfig:
"""Configuration for length penalty calculations"""
max_token_length: int
threshold_percentage: float = 0.5 # Default threshold at 50% of max length
class ThresholdLengthPenaltyCalculator:
"""Handles calculation of length-based penalties for token sequences"""
def __init__(self, config: ThresholdLengthPenaltyConfig):
"""
Initialize the length penalty calculator
Args:
config: Configuration object containing max_token_length and threshold settings
"""
self.config = config
self.length_threshold = (
self.config.max_token_length * self.config.threshold_percentage
)
def apply_length_penalties(
self, scores: Dict[str, List]
) -> Optional[Dict[str, List]]:
"""
Apply length-based penalties to scores if all responses are correct
Args:
scores: Dictionary containing 'scores' and 'tokens' lists
Returns:
Modified scores dictionary or None if invalid input
"""
# Validate input
if not scores or "scores" not in scores or "tokens" not in scores:
return None
# Only apply penalties if all responses are correct
if not all([score == 1.0 for score in scores["scores"]]):
return scores
# Calculate token lengths
token_lengths = [len(token) for token in scores["tokens"]]
if max(token_lengths) == 0:
return None
# Apply modified length penalty with threshold
new_scores = []
for length in token_lengths:
if length <= self.length_threshold:
# No penalty for responses under threshold
new_scores.append(1.0)
else:
# Calculate penalty based on how far we are between threshold and max
percentage_of_range = (length - self.length_threshold) / (
self.config.max_token_length - self.length_threshold
)
# Cap at 1.0 in case length exceeds max_allowed_length
percentage_of_range = min(percentage_of_range, 1.0)
# Apply linear penalty scaling from 1.0 down to 0.0
new_scores.append(1.0 - percentage_of_range)
scores["scores"] = new_scores
return scores
# Example usage:
# config = ThresholdLengthPenaltyConfig(max_token_length=1024)
# calculator = ThresholdLengthPenaltyCalculator(config)
# modified_scores = calculator.apply_length_penalties(scores_dict)

444
llm.txt Normal file

@ -0,0 +1,444 @@
# Atropos Library Documentation (for LLM Context)
This document provides comprehensive information about the Atropos library, Nous Research's LLM RL Gym. It covers its purpose, features, usage, components, configuration, and contribution guidelines.
---
## 1. Introduction: Atropos - Nous Research's LLM RL Gym
Atropos is an LLM Reinforcement Learning Environments framework designed for collecting and evaluating LLM trajectories through diverse environments.
**Supported Environment Types:**
<div align="center">
| Environment Type | Examples | Purpose |
|---------------------------|--------------------------------------------|----------------------------------------------------|
| 📚 Dataset environments | GSM8K, MMLU | Evaluate and improve LLM performance on static data|
| 🎮 Online environments | Crosswords, Hangman | Train LLMs through interactive game-based learning |
| 🤖 RLAIF and RLHF | LLM Judge/Reward Models | Fine-tune LLMs using human feedback and alignment |
| 🔄 Multi-Turn RL | deepresearch, internal tool calling | Train LLMs on complex multi-step interactions |
</div>
Atropos provides a robust, scalable framework for **Reinforcement Learning Environments with LLMs**.
**Key Features:**
* **Multi-Turn & Asynchronous RL:** Efficiently supports complex, multi-turn, and asynchronous interactions, decoupling environment steps from policy updates.
* **Inference Agnostic:** Integrates with standard inference APIs (e.g., OpenAI, vLLM, sgLang), enabling easy switching between LLM providers and frameworks.
* **Trainer Independent:** Offers a standardized training interface for experimenting with different RL algorithms and frameworks without major code changes.
* **Scalable & Decentralized:** Easily scale by launching more environment instances (locally or across decentralized resources) that contribute rollouts to a central service.
* **Diverse Environment Integration:** Manages many varied environment types concurrently for heterogeneous, multi-modal training.
**Goal:** Provide a flexible, scalable, and standardized platform to accelerate LLM-based RL research across diverse, interactive settings.
---
## 5. Navigating the Repo
| Category | Description |
|-------------------------------|--------------------------------------------------|
| 📁 [`atroposlib/`](atroposlib/) | Core library containing base classes and utilities |
| 🎮 [`environments/`](environments/) | Collection of ready-to-use RL environments |
| 📚 [`example_trainer/`](example_trainer/) | Example training scripts and configurations |
**Key Documents:**
* **Base Environment Class:** `atroposlib/environments/README.md` (Detailed in Section 9 below)
* **Environments Overview:** `environments/README.md` (Detailed in Section 8 below)
* **Full Environment Config Options:** `CONFIG.md` (Detailed in Section 10 below)
* **Example Trainer:** `example_trainer/README.md` (Detailed in Section 7 below)
* **Contributing Guide:** `CONTRIBUTING.md` (Detailed in Section 11 below)
* **License:** `LICENSE.md` (Apache 2.0 license details)
---
## 6. Installation
Requires Python 3.10 or later.
```bash
# Core usage
pip install -e .
# Development (includes testing, linting tools)
pip install -e .[dev]
# Running examples (includes dependencies like vLLM, transformers)
pip install -e .[examples]
# Everything
pip install -e .[all]
```
**Important for Developers:** Install pre-commit hooks to ensure code quality:
```bash
pre-commit install
```
---
## 7. Quick Start Guide
1. **Create Your First Environment:**
* Review the [Base Environment Class Documentation](#9-core-library-atroposlib) (Section 9).
* Examine existing environments in [`environments/`](#8-environments) for examples.
2. **Run an Example Environment:**
```bash
# Start the central API server (trajectory handler) in the background
run-api &
# Start an environment server (e.g., GSM8K) connected to the API
python environments/gsm8k_server.py serve \
--tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct" \
--model_name="Qwen/Qwen2.5-1.5B-Instruct" \
--slurm False # Assuming local run, set True for SLURM cluster
```
*Note: The model and tokenizer names are examples.*
3. **Training Your Model:**
* Refer to the Example Trainer guide (`example_trainer/README.md`).
* Monitor progress via logging: completion lengths, eval accuracies, full rollouts/scores (see WandB image in original README).
* Multiple environments can run concurrently, pointing to the same `run-api` server.
**Logging:** Environments provide detailed logging, tracking completion lengths, eval accuracies, full rollouts, scores, etc. Supports WandB integration.
---
## 8. Environments
The `environments/` directory contains various RL environments.
### 8.1. Common Features Across Environments
1. **Training/Test Split:** Typically 98% training, 2% test, with fixed random shuffling (seed 42).
2. **Metrics Tracking:** Includes percent correct buffer, completion lengths, Wandb integration, and rollout tracking.
3. **Token Management:** Maximum token length limits, statistics tracking, and optional length penalties.
4. **Evaluation:** Separate evaluation on the test set with comprehensive metrics logging. Supports multiple completions per prompt.
5. **Usage Interface:** Environments generally follow a common interface:
* Initialize with `config` (BaseEnvConfig), `server_configs` (OpenAI API configs), `slurm` (bool), `testing` (bool).
* Key methods: `setup()`, `get_next_item()`, `collect_trajectories()`, `score()` (often part of postprocessing), `evaluate()`, `wandb_log()`.
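Concretely, the fixed 98/2 split in point 1 can be sketched as follows (illustrative only; `split_dataset` is not a library function):

```python
import random

def split_dataset(items, test_fraction=0.02, seed=42):
    """Shuffle with a fixed seed, then carve off the last 2% as the test set."""
    items = list(items)
    random.Random(seed).shuffle(items)  # deterministic shuffle across runs
    split_point = round(len(items) * (1 - test_fraction))
    return items[:split_point], items[split_point:]

train, test = split_dataset(range(1000))
```

Using a seeded `random.Random` instance (rather than the global RNG) keeps the split reproducible regardless of other random calls in the process.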
### 8.2. Available Environments
#### 8.2.1. MCQA Thinking Environment (`mcqa_thinking_env.py`)
Multiple Choice Question Answering (MMLU dataset) requiring systematic thought.
* **Input Format:** MMLU items (`prompt`, `answer` index, `ground_truth` letter, `options` list).
* **System Prompt:**
```
You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem.
```
* **Reward Function:**
* 1.0 for correct letter match.
* 0.0 for incorrect or malformed response (e.g., bad `<think>` tags).
* Length penalty applied *only if all responses in a group are correct*: scales linearly from 1.0 (<=50% max length) down to 0.0 (>=100% max length).
* Returns `None` if all scores in a group are identical (no training signal).
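The length-penalty rule described above (shared by the environments in this section) can be sketched as a standalone helper — an illustrative sketch of the scaling rule, not the library's actual code:

```python
def length_penalty(token_len: int, max_len: int) -> float:
    """Scale from 1.0 at <=50% of max length down to 0.0 at >=100%, linearly in between."""
    ratio = token_len / max_len
    if ratio <= 0.5:
        return 1.0
    if ratio >= 1.0:
        return 0.0
    return 1.0 - (ratio - 0.5) / 0.5

# length_penalty(1536, 2048) → 0.5  (75% of max length)
```

Recall the penalty only applies when every response in the group is correct, so it breaks ties toward shorter correct answers without punishing incorrect groups further.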
#### 8.2.2. GSM8K Environment (`gsm8k_server.py`)
Mathematical reasoning (GSM8K dataset).
* **Input Format:** GSM8K items (`question`, `answer` number).
* **System Prompt:**
```
You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem.
You are allocated a maximum of 4096 tokens, please strive to use less.
You will then provide your answer like this: \boxed{your answer here}
It is important that you provide your answer in the correct format.
If you do not, you will not receive credit for your answer.
So please end your answer with \boxed{your answer here}
```
* **Reward Function:**
* 1.0 if `\boxed{}` answer matches ground truth (uses LaTeX verification).
* 0.0 if incorrect or ground truth isn't parseable.
* Length penalty applied *only if all responses in a group are correct*: scales linearly from 1.0 (<=50% max length) down to 0.0 (>=100% max length).
* Returns `None` if all scores in a group are identical.
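Pulling the final `\boxed{}` answer out of a response can be sketched with a brace-matching scan (the environment itself relies on `math-verify` for LaTeX-aware comparison; this helper is illustrative only):

```python
def extract_boxed(text: str):
    """Return the contents of the last \\boxed{...}, handling nested braces."""
    start = text.rfind(r"\boxed{")
    if start == -1:
        return None  # malformed response: no boxed answer found
    depth = 0
    i = start + len(r"\boxed{")
    for j in range(i, len(text)):
        if text[j] == "{":
            depth += 1
        elif text[j] == "}":
            if depth == 0:
                return text[i:j]
            depth -= 1
    return None  # unbalanced braces
```

Taking the *last* `\boxed{}` matters because the system prompt tells the model to end its answer with one, and chains of thought may contain earlier boxed expressions.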
#### 8.2.3. Tool Calling Environment (`tool_calling_server.py`)
Training models for structured function/tool calls (ShareGPT-Hermes function call dataset).
* **Input Format:** Conversations (`system`, `human`, `gpt` roles) with expected tool calls (JSON format).
* **System Prompt:** (Same "deep thinking AI" prompt as MCQA)
```
You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem.
```
* **Reward Function:**
* 1.0 if *all* expected tool calls are present and *exactly* match (including nested JSON).
* 0.0 if any calls are missing, incorrect, or malformed.
* Length penalty applied *only if all responses in a group are correct*: scales linearly from 1.0 (<=50% max length) down to 0.0 (>=100% max length).
* Returns `None` if all scores in a group are identical.
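The exact-match requirement (including nested JSON) can be sketched by parsing both sides and comparing Python objects, so key order and whitespace don't matter while values must match exactly — an illustrative helper, not the environment's actual scorer:

```python
import json

def tool_calls_match(expected_calls, got_calls) -> bool:
    """True only if every expected tool call appears in the response, compared as parsed JSON."""
    try:
        expected = [json.loads(c) for c in expected_calls]
        got = [json.loads(c) for c in got_calls]
    except json.JSONDecodeError:
        return False  # any malformed call scores 0.0
    return all(call in got for call in expected)
```

Comparing parsed objects rather than raw strings means `{"a": 1, "b": 2}` and `{ "b": 2, "a": 1 }` count as the same call, while any differing nested value fails the match.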
---
## 9. Training with the Example Trainer
The `example_trainer/` directory provides `grpo.py`, a script demonstrating integration with Atropos using the GRPO algorithm.
**Note:** This is a *reference example* for API integration and basic setup, *not* optimized for large-scale training. It uses `vLLM` for inference (simulated data generation) and `transformers` for training.
### 9.1. Prerequisites
1. Python 3.10+ (matching `requires-python = ">=3.10"` in `pyproject.toml`).
2. Running Atropos API server (default: `http://localhost:8000`). Accessible via `run-api`.
3. Required Python packages: `torch`, `transformers`, `vllm`, `pydantic`, `numpy`, `requests`, `tenacity`, `wandb` (optional). Install via `pip install -r example_trainer/requirements.txt` or `pip install -e .[examples]`.
4. A running Atropos environment (e.g., `python environments/gsm8k_server.py serve --slurm False`).
### 9.2. Setup
1. Clone the Atropos repository.
2. Install dependencies (see Prerequisites).
3. Start the Atropos API: `run-api`.
4. Start an environment connected to the API (e.g., GSM8K example above).
### 9.3. Configuration (`grpo.py`)
Configuration is managed via the `TrainingConfig` Pydantic model within `grpo.py`.
**Key Parameters:**
* `model_name`: Hugging Face model identifier (e.g., `"Qwen/Qwen2.5-1.5B-Instruct"`).
* `training_steps`: Total optimization steps.
* `batch_size` / `gradient_accumulation_steps`: Control effective batch size.
* `lr`: Learning rate.
* `save_path`: Directory for model checkpoints (default: `./trained_model_checkpoints`).
* `vllm_port`: Port for the script's vLLM inference server instance.
* `vllm_restart_interval`: Steps between saving checkpoints and restarting vLLM with updated weights.
* `use_wandb`: Enable/disable Weights & Biases logging.
* `wandb_project`: W&B project name (required if `use_wandb=True`).
* `wandb_group`: Optional W&B group name.
**API Endpoints:** Assumes API at `http://localhost:8000`. Modify `register_trainer` and `get_batch` functions if different.
### 9.4. Running the Example
Navigate to the project root and run:
```bash
python example_trainer/grpo.py
```
### 9.5. Output
* **Console Logs:** Training progress (loss, logp), vLLM status.
* **Checkpoints:** Saved periodically in `save_path`. `final_model` directory upon completion.
* **WandB:** Logs sent to W&B if enabled (link printed to console).
* `temp.json`: Raw data from the last fetched batch (for debugging).
---
## 10. Core Library (`atroposlib`)
The `atroposlib/` directory contains the core framework components.
### 10.1. Base Environment (`atroposlib.envs.base.BaseEnv`)
This class provides the foundation for creating custom RL environments. Subclass `BaseEnv` and implement/override methods as needed.
**Core Methods to Implement:**
* **`async def setup(self)`**: Called once at the start. Use for initial setup (loading data, models, etc.).
* **`async def get_next_item(self) -> Item`**: Returns the next data item (prompt, state) for trajectory collection. Return `None` to pause the worker if no items are ready. `Item` is typically a Pydantic model defined by the environment.
* **`async def collect_trajectory(self, item: Item) -> Tuple[Any | None, List[Item]]`**: Defines logic for *one* trajectory collection step based on `item`. The base class runs this in parallel (`group_size` times). Returns a tuple: `(collected_data_for_this_step, list_of_new_backlog_items)`. The collected data can be any type suitable for later processing.
* **`async def evaluate(self, *args, **kwargs)`**: Called periodically (`steps_per_eval`) for evaluation runs. Implement your evaluation logic here. The base class provides `self.eval_workers` for parallel tasks.
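A minimal subclass implementing the four core hooks might look like the toy sketch below. The `BaseEnv` stub and `Item` alias exist only to keep the example self-contained — in real code you would subclass `atroposlib.envs.base.BaseEnv` and call your server via the provided handles:

```python
import asyncio
from typing import Any, List, Optional, Tuple

class BaseEnv:  # stand-in for atroposlib.envs.base.BaseEnv, for illustration only
    pass

Item = dict  # environments typically define their own Item type

class EchoEnv(BaseEnv):
    """Toy environment showing the four core hooks."""

    def __init__(self):
        self.queue: List[Item] = [{"prompt": "2 + 2 = ?", "answer": "4"}]

    async def setup(self) -> None:
        pass  # load datasets / models here

    async def get_next_item(self) -> Optional[Item]:
        return self.queue.pop() if self.queue else None  # None pauses the worker

    async def collect_trajectory(self, item: Item) -> Tuple[Any, List[Item]]:
        response = item["answer"]  # a real env would query the inference server
        score = 1.0 if response == item["answer"] else 0.0
        return ({"item": item, "response": response, "score": score}, [])  # (data, backlog)

    async def evaluate(self) -> float:
        return 1.0  # run your held-out metric here

async def main():
    env = EchoEnv()
    await env.setup()
    item = await env.get_next_item()
    data, backlog = await env.collect_trajectory(item)
    print(data["score"])  # prints 1.0

asyncio.run(main())
```

In a real environment the class would typically end with the CLI hook mentioned below (`MyEnv.cli()` under `if __name__ == "__main__":`) so it can be launched with the `serve` command.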
**Optional Methods to Override:**
* **`async def collect_trajectories(self, item: Item) -> Tuple[Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None]], List[Item]]`**: Override this *instead* of `collect_trajectory` for custom batch generation logic (generating the whole group at once). `ScoredDataGroup` is a structure usually containing prompts, responses, and scores.
* **`async def postprocess_histories(self, trajectories: Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]`**: Called after `collect_trajectories` and before sending data to the server. Use for final processing, scoring, filtering, or formatting of the collected group data.
* **`async def wandb_log(self, wandb_metrics: Optional[Dict] = None)`**: Called periodically for W&B logging. Add custom metrics to `wandb_metrics`. **Crucially, call `await super().wandb_log(wandb_metrics)`** at the end to include base metrics and rollouts.
* **`save_checkpoint(self, step, data=None)`**: Called automatically by the server based on `checkpoint_interval`. Saves the provided `data` dict (populated with environment state) to JSON. Override to customize *what* or *how* data is saved.
* **`@classmethod config_init(cls) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[OpenaiConfig]]]`**: Used by CLI `serve` command setup. Returns initial `BaseEnvConfig` and server config(s). Override for custom default CLI configurations. Default returns `cls.env_config_cls(), ServerBaseline()`.
* **`async def cleanup(self)`**: Called after each item processing (`handle_env`). Use for per-item cleanup if needed (rarely required).
**Provided Functionality:**
* **Parallel Trajectory Collection:** Base `collect_trajectories` handles running `collect_trajectory` in parallel.
* **Server Interaction:** Handles registration, config fetching, data sending (with retries via `handle_send_to_api`), status updates.
* **WandB Integration:** Setup, logging hook (`wandb_log`), rollout table helpers (`add_rollouts_for_wandb`, `create_rollout_table`).
* **Checkpointing:** Automatic triggering via server (`checkpoint_interval`), `save_checkpoint` method, automatic loading via `load_checkpoint()` on startup if `curr_step > 0`.
* **Worker Management:** Asynchronous task management (`add_train_workers`, `handle_env`).
* **Performance Monitoring:** Tracks and logs task durations, worker counts, etc.
* **CLI Integration:** `cli()` class method using `pydantic-cli` for easy `serve` commands. See `get_cli_serve_config_cls` and `get_cli_process_config_cls`.
### 10.2. Configuration Options (`atroposlib`)
Configuration is primarily managed via Pydantic models, often exposed through a CLI (`pydantic-cli`).
#### 10.2.1. Base Environment Config (`atroposlib.envs.base.BaseEnvConfig`)
| Parameter | Type | Default | Description |
| :------------------------------- | :----------------------- | :---------------------------------------------- | :--------------------------------------------------------------------------------------------------------- |
| `group_size` | `int` | `4` | Number of responses grouped for scoring. |
| `max_num_workers` | `int` | `-1` | Max workers. `-1` calculates from `max_num_workers_per_node`. |
| `max_eval_workers` | `int` | `16` | Max workers for evaluation. |
| `max_num_workers_per_node` | `int` | `8` | Max workers per node. |
| `steps_per_eval` | `int` | `100` | Steps between evaluations. |
| `max_token_length` | `int` | `2048` | Max token length for generations. |
| `eval_handling` | `EvalHandlingEnum` | `EvalHandlingEnum.STOP_TRAIN` | How evals affect training workers (`STOP_TRAIN`, `LIMIT_TRAIN`, `NONE`). |
| `eval_limit_ratio` | `float` | `0.5` | Ratio of training workers limited during evals (if `eval_handling` is `LIMIT_TRAIN`). |
| `inference_weight` | `float` | `1.0` | Inference weight (set by trainer/policy). `-1` ignores if handled specially. |
| `batch_size` | `int` | `-1` | Training batch size (usually set by trainer via API). |
| `max_batches_offpolicy` | `int` | `3` | Max number of off-policy batches queued. |
| `tokenizer_name` | `str` | `"NousResearch/DeepHermes-3-Llama-3-1B-Preview"` | Default Hugging Face tokenizer. |
| `use_wandb` | `bool` | `True` | Enable/disable W&B logging. |
| `rollout_server_url` | `str` | `"http://localhost:8000"` | URL of the central rollout server (FastAPI). |
| `total_steps` | `int` | `1000` | Total steps to run (can be overridden by trainer). |
| `wandb_name`                     | `str \| None`            | `None`                                          | W&B run name (often set automatically).                                                                     |
| `num_rollouts_to_keep` | `int` | `32` | Number of full rollouts to display on W&B table. |
| `num_rollouts_per_group_for_logging` | `int` | `1` | Rollouts per group to keep for logging. `-1` keeps all. |
| `ensure_scores_are_not_same` | `bool` | `True` | Ensure scores in a group aren't identical (reject group if they are). Set `False` if identical scores are valid. |
| `data_path_to_save_groups`       | `str \| None`            | `None`                                          | If set, save generated/scored groups to this JSONL file path.                                               |
| `min_items_sent_before_logging` | `int` | `2` | Min API sends before logging metrics. `<=0` logs every time. |
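As a concrete illustration, a config override might set just a few of these fields and inherit the rest of the defaults (values here are illustrative, not recommendations):

```yaml
# Illustrative BaseEnvConfig overrides (field names from the table above)
group_size: 8
max_token_length: 4096
steps_per_eval: 200
use_wandb: false
rollout_server_url: http://localhost:8000
```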
#### 10.2.2. Server Manager Config (`atroposlib.envs.server_handling.server_manager.ServerManagerConfig`)
Settings for the `ServerManager` which handles inference server interactions.
| Parameter | Type | Default | Description |
| :-------- | :------ | :------ | :------------------------------------------------ |
| `slurm` | `bool` | `True` | Whether the environment is running on SLURM. |
| `testing` | `bool` | `False` | If `True`, uses mock OpenAI data (for testing). |
#### 10.2.3. Server Baseline Config (`atroposlib.envs.server_handling.server_manager.ServerBaseline`)
Default settings used by `ServerManager` if specific `OpenaiConfig` list isn't provided (e.g., for local/SLURM discovery).
| Parameter | Type | Default | Description |
| :------------------------- | :------ | :-------- | :------------------------------------------------------------------------------------------------------ |
| `timeout` | `int` | `1200` | Request timeout (seconds). |
| `num_max_requests_at_once` | `int`   | `512`     | Max concurrent requests during training. Divide by the generation `n` parameter when sampling multiple completions per request. |
| `num_requests_for_eval` | `int` | `64` | Max concurrent requests (evaluation). |
| `model_name` | `str` | `default` | Default model name for inference calls. |
| `rolling_buffer_length` | `int` | `1000` | Buffer length for server metrics (timings, attempts). |
#### 10.2.4. OpenAI Server Config (`atroposlib.envs.server_handling.openai_server.OpenaiConfig`)
Configuration for individual OpenAI-compatible API servers (official OpenAI, local vLLM/SGLang, etc.). A list of these can be passed to the environment.
| Parameter | Type | Default | Description |
| :------------------------- | :----------- | :-------- | :------------------------------------------------------------------------------------------------------ |
| `api_key`                  | `str \| None` | `None`    | API key. Use `"x"` or any non-empty string for local servers without auth. `None` might imply an environment variable. |
| `base_url`                 | `str \| None` | `None`    | API endpoint URL. `None` for official OpenAI. Local: e.g., `http://localhost:9004/v1`.                  |
| `timeout` | `int` | `1200` | Request timeout (seconds). |
| `num_max_requests_at_once` | `int`        | `512`     | Max concurrent requests during training. Divide by the generation `n` parameter when sampling multiple completions per request. |
| `num_requests_for_eval` | `int` | `64` | Max concurrent requests (evaluation). |
| `model_name` | `str` | `default` | **Required.** Model name for this server (e.g., `"gpt-4"`, `"NousResearch/..."`). |
| `rolling_buffer_length` | `int` | `1000` | Buffer length for this server's metrics. |
---
## 11. Debugging Tools
Atropos's trajectory handler provides local debugging tools:
* **Flexible Model Provider Support:** Natively supports any OpenAI API-compliant provider. Provide `base_url` and `api_key` for local testing/running.
* **View Run (`view-run`):** Launch a Gradio UI after starting the API (`run-api`) and an environment (`python environments/gsm8k_server.py serve`). Use `view-run` command to inspect batches of rollouts visually.
* **Offline Data Generation:**
* `atropos-sft-gen`: Collect rollouts and format for Supervised Fine-Tuning (SFT).
* `atropos-dpo-gen`: Collect rollouts and format for Direct Preference Optimization (DPO).
---
## 12. Contributing to Atropos
Contributions are welcome! Follow these guidelines.
### 12.1. How We Develop
* **GitHub:** Used for hosting, issue tracking, and Pull Requests (PRs).
* **GitHub Flow:** Development happens via PRs merged into the `main` branch.
### 12.2. Getting Started
1. **Fork:** Create your copy of the [repository](https://github.com/NousResearch/atropos).
2. **Clone:** `git clone https://github.com/your-username/atropos.git && cd atropos`
3. **Setup Dev Env:**
```bash
python -m venv .venv
source .venv/bin/activate # Windows: .venv\Scripts\activate
pip install -e ".[dev]" # Installs core + dev dependencies
```
4. **Install Pre-commit Hooks:**
```bash
pre-commit install
```
(Runs linters/formatters automatically on commit)
### 12.3. Running Tests
Uses `pytest`.
```bash
pytest
```
Ensure all tests pass before submitting a PR.
### 12.4. How to Contribute
* **Reporting Bugs:** Use the **Bug Report** issue template on GitHub Issues. Provide details: summary, steps to reproduce, expected vs. actual behavior, environment info, error messages/logs.
* **Suggesting Enhancements:** Use the **Feature Request** issue template. Discuss the idea first via an issue.
* **Submitting Changes (Pull Requests):**
1. Create a branch from `main`: `git checkout -b your-branch-name main`
2. Make changes, write code.
3. Add tests if applicable.
4. Update documentation (READMEs, docstrings) if APIs change.
5. Run tests: `pytest`.
6. Ensure code quality (pre-commit hooks run on commit, or run manually: `pre-commit run --all-files`).
7. Commit changes with clear messages: `git commit -m "feat: Describe feature or fix"`
8. Push branch: `git push origin your-branch-name`
9. Open a PR on GitHub from your fork's branch to `NousResearch/atropos:main`.
10. **Use the correct PR template:**
* `environment_pull_request_template.md` for environment changes.
* `non_environment_pull_request_template.md` for other changes.
11. Provide a clear title, description, link relevant issues (e.g., `Closes #123`).
### 12.5. Code Style
* PEP 8 enforced by `black`, `flake8`, `isort` via `pre-commit`.
* Manual check/fix: `pre-commit run --all-files`. Address `flake8` errors manually if needed.
### 12.6. License for Contributions
Contributions are submitted under the Apache License 2.0, the same license as the project.
### 12.7. Environment Contribution Guidelines
* **Legal/GitHub Compliance:** No illegal content. Must comply with GitHub TOS.
* **Explicit Content:** May be considered if clearly labeled and legally compliant.
* **Game Environments:** Welcome, but avoid reverse-engineered commercial games. Ensure rights to assets. Open-source/permissive licenses preferred.
* **Ethical Considerations:** Avoid environments encouraging harm without educational context.
* Discuss potentially controversial environments via an issue first.
### 12.8. Contributor Code of Conduct
Follow the [Contributor Code of Conduct](CODE_OF_CONDUCT.md).
---
## 13. Citation
If Atropos is helpful in your work, please cite:
```latex
@misc{atropos,
title = {{Atropos - An Async First Environment Rollout Controller}},
  author = {Dakota Mahan and Roger Jin and Teknium and Shannon Sands and Artem Yatsenko and Jai Suphavadeeprasit and Karan Malhotra and Chen Guang and Joe Li},
url = {https://www.github.com/NousResearch/Atropos},
month = {4},
year = {2025},
version = {0.1},
}
```
*(Note: Year/Version might need updating)*
---
## 14. License
Atropos is licensed under the Apache License 2.0. See the [LICENSE](LICENSE.md) file for details.
---

**`pyproject.toml`**

```toml
[project]
name = "atroposlib"
version = "0.1.0"
description = "Atropos: An Environment and Rollout handler for LLM RL"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "transformers==4.48.3",
    "datasets",
    "openai",
    "aiohttp",
    "tqdm",
    "fastapi",
    "uvicorn[standard]",
    "tenacity",
    "numpy",
    "wandb",
    "math-verify==0.7.0",
    "jinja2",
    "nltk",
    "polars",
    "aiofiles",
    "jsonlines",
    "torch",
    "pydantic-cli",
    "hf_transfer",
]

[project.scripts]
run-api = "atroposlib.cli.run_api:main"
inference-node-wandb-watcher = "atroposlib.cli.inference_node_wandb_watcher:main"
view-run = "atroposlib.cli.view_run:main"
atropos-sft-gen = "atroposlib.cli.sft:main"

[project.optional-dependencies]
all = [
    "atroposlib[dev,examples]"
]
dev = [
    "pytest",
    "pytest-asyncio",
    "pre-commit",
    "black",
    "flake8",
    "isort",
    "mypy",
    "rich",
]
examples = [
    "gradio"
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["atroposlib"]
```
---

**`pytest.ini`**

```ini
[pytest]
asyncio_mode = auto
asyncio_default_fixture_loop_scope = function
python_files = test_*.py
python_classes = Test*
python_functions = test_*
testpaths = atroposlib/tests
```
---

**`testing/__init__.py`** and **`testing/api/__init__.py`** (both empty)
---

**`testing/api/testing.py`** (truncated in this view)

```python
import pytest
import requests

from testing.api.utils import launch_api_for_testing


def register_data(group="test", proj="test", batch_size=32) -> requests.Response:
    x = requests.post(
        "http://localhost:8000/register",
        json={"wandb_group": group, "wandb_project": proj, "batch_size": batch_size},
    )
    return x


def post_scored_data(
    tokens=((0,),), masks=((0,),), scores=(0,), ref_logprobs=((0,),)
) -> requests.Response:
    data = {
        "tokens": tokens,
        "masks": masks,
        "scores": scores,
    }
    if ref_logprobs is not None:
        data["ref_logprobs"] = ref_logprobs
    x = requests.post("http://localhost:8000/scored_data", json=data)
    return x


def reset() -> requests.Response:
    x = requests.get("http://localhost:8000/reset_data")
    return x


@pytest.fixture(scope="session")
def api():
    proc = launch_api_for_testing()
    yield
    proc.kill()


def test_register(api):
    x = register_data()
    assert x.status_code == 200, x.text
    data = x.json()
    assert "uuid" in data


def test_reset(api):
    x = register_data()
    assert x.status_code == 200, x.text
    data = x.json()
    assert "uuid" in data
    x = post_scored_data()
    assert x.status_code == 200, x.text
    x = reset()
    print("0-0-0-0-0-0-0-0", flush=True)
    print(x.text, flush=True)
    print("0-0-0-0-0-0-0-0", flush=True)
    assert x.status_code == 200, x.text
    x = requests.get("http://localhost:8000/info")
    assert x.status_code == 200
    assert x.json()["batch_size"] == -1
    x = requests.get("http://localhost:8000/status")
    assert x.status_code == 200, x.text
    data = x.json()
    assert data["current_step"] == 0
    assert data["queue_size"] == 0
    x = requests.get("http://localhost:8000/wandb_info")
    assert x.status_code == 200, x.text
    data = x.json()
    assert data["group"] is None
    assert data["project"] is None


def test_batch_size(api):
    x = register_data()
    assert x.status_code == 200, x.text
    # get the batch size
    x = requests.get("http://localhost:8000/info")
```
---

**`testing/api/utils.py`**

```python
import multiprocessing
import time

import requests

from atroposlib.cli.run_api import main as run_api_main


def check_api_running() -> bool:
    try:
        data = requests.get("http://localhost:8000/info")
        return data.status_code == 200
    except requests.exceptions.ConnectionError:
        return False


def launch_api_for_testing(max_wait_for_api: int = 10) -> multiprocessing.Process:
    api_proc = multiprocessing.Process(target=run_api_main)
    api_proc.start()
    counter = 0
    while not check_api_running():
        time.sleep(1)
        counter += 1
        if counter > max_wait_for_api:
            raise TimeoutError("API server did not start in time.")
    print("API server started for testing.")
    return api_proc
```
---

**`testing/testing.md`**

## Running Tests

This section contains instructions and guidelines for running the test suite.
Please ensure all tests pass before submitting contributions.

We use `pytest` for our testing framework.
Simply run `pytest` from the main directory and you're good to go.
---

**`tokenize_for_trainer` tests**

```python
import json
import logging
import os

from transformers import AutoTokenizer

from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

MESSAGES = [
    {
        "role": "system",
        "content": "You are a helpful AI assistant that provides accurate information.",
    },
    {"role": "user", "content": "What's the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."},
    {"role": "user", "content": "Can you tell me more about Paris?"},
    {
        "role": "assistant",
        "content": "<tool_call>{'tool_name': 'web_search', 'args': {'query': 'Paris'}}</tool_call>",
    },
    {
        "role": "tool",
        "content": (
            "Paris is the capital and most populous city of France. "
            "It has an estimated population of 2,165,423 residents in 2019 "
            "in an area of more than 105 km²."
        ),
    },
    {
        "role": "assistant",
        "content": (
            "Paris is indeed the capital of France and its most populous city with over 2 million residents. "
            "It's known for its iconic landmarks like the Eiffel Tower, Louvre Museum, and Notre-Dame Cathedral. "
            "The city is a global center for art, fashion, gastronomy, and culture."
        ),
    },
]

TEST_MASKS_PATH = os.path.join(os.path.dirname(__file__), "test_masks.json")
with open(TEST_MASKS_PATH) as f:
    TEST_MASKS = json.load(f)


def test_tokenize_for_trainer_mask_len_last_turn_only():
    # random model with chat templates and isn't gated
    try:
        tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
        can_run_stop = True
    except (ValueError, EnvironmentError):
        can_run_stop = False
        tok = AutoTokenizer.from_pretrained("Zyphra/Zamba2-1.2B-instruct")
        logging.warning(
            "Could not use gated model, using non-gated model that is bad at tokenizing..."
        )
    messages = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello"},
    ]
    total_toks = tok.apply_chat_template(messages)
    prefix = tok.apply_chat_template(messages[:1], add_generation_prompt=True)
    resp = tokenize_for_trainer(tok, messages, False)
    assert len(resp["tokens"]) == len(total_toks) == len(resp["masks"])
    assert resp["tokens"] == total_toks
    assert all([x == -100 for x in resp["masks"][: len(prefix)]])
    assert all([x != -100 for x in resp["masks"][len(prefix) :]])
    assert resp.get("messages", None) is None
    # This time with add messages
    messages = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello"},
    ]
    resp = tokenize_for_trainer(tok, messages, True)
    assert resp["tokens"] == total_toks
    assert len(resp["tokens"]) == len(total_toks) == len(resp["masks"])
    assert all([x == -100 for x in resp["masks"][: len(prefix)]])
    assert all([x != -100 for x in resp["masks"][len(prefix) :]])
    assert resp["messages"] == messages
    if can_run_stop:
        # now try with finish reason == stop
        messages = [
            {"role": "user", "content": "hi"},
            {"role": "assistant", "content": "hello"},
        ]
        resp = tokenize_for_trainer(tok, messages, False, finish_reason="length")
        assert len(resp["tokens"]) == len(total_toks) - 1 == len(resp["masks"])
        assert resp["tokens"] == total_toks[:-1]
        assert all([x == -100 for x in resp["masks"][: len(prefix)]])
        assert all([x != -100 for x in resp["masks"][len(prefix) :]])
        assert resp.get("messages", None) is None


def test_last_turn_only_masking():
    """
    Test that in last turn only mode, only the tokens from the final assistant turn are unmasked
    (mask != -100) while all tokens from all other messages are masked (mask == -100).
    """
    tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
    result = tokenize_for_trainer(
        tok, MESSAGES, include_messages=False, train_on_all_assistant_turns=False
    )
    masks = result["masks"]
    assert masks == TEST_MASKS["last_turn_only"]


def test_all_assistant_turns_masking():
    """
    Test that in all assistant turns mode, tokens for every assistant message are unmasked,
    while tokens for non-assistant messages remain masked.
    """
    tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
    result = tokenize_for_trainer(
        tok, MESSAGES, include_messages=False, train_on_all_assistant_turns=True
    )
    masks = result["masks"]
    assert masks == TEST_MASKS["all_assistant_turns"]
```