From 621d00dd80da1718a64bb40d2833acc1c61e4ff5 Mon Sep 17 00:00:00 2001 From: Dakota Nous Date: Tue, 29 Apr 2025 12:10:10 -0700 Subject: [PATCH] first commit --- .github/ISSUE_TEMPLATE/bug_report.md | 60 + .github/ISSUE_TEMPLATE/feature_request.md | 20 + .github/environment_pull_request_template.md | 44 + .../non_environment_pull_request_template.md | 30 + .gitignore | 184 +++ .pre-commit-config.yaml | 25 + CODE_OF_CONDUCT.md | 44 + CONFIG.md | 67 ++ CONTRIBUTING.md | 150 +++ LICENSE | 21 + README.md | 224 ++++ SLURM.md | 175 +++ atroposlib/__init__.py | 0 atroposlib/api/README.md | 182 +++ atroposlib/api/__init__.py | 3 + atroposlib/api/env_interaction.md | 70 ++ atroposlib/api/server.py | 305 +++++ atroposlib/api/trainer_interaction.md | 58 + atroposlib/api/utils.py | 61 + atroposlib/cli/dpo.py | 322 ++++++ .../cli/inference_node_wandb_watcher.py | 66 ++ atroposlib/cli/run_api.py | 9 + atroposlib/cli/sft.py | 318 +++++ atroposlib/cli/view_run.py | 105 ++ atroposlib/envs/README.md | 63 + atroposlib/envs/__init__.py | 0 atroposlib/envs/base.py | 890 ++++++++++++++ atroposlib/envs/reward_fns/__init__.py | 30 + atroposlib/envs/reward_fns/accuracy_reward.py | 296 +++++ .../reward_fns/cascading_r1_math_reward.py | 200 ++++ atroposlib/envs/reward_fns/combined_reward.py | 91 ++ .../envs/reward_fns/cosine_scaled_reward.py | 201 ++++ .../reward_fns/crossword_format_reward.py | 121 ++ atroposlib/envs/reward_fns/format_reward.py | 105 ++ atroposlib/envs/reward_fns/r1_reward.py | 363 ++++++ .../envs/reward_fns/reasoning_steps_reward.py | 138 +++ atroposlib/envs/reward_fns/registry.py | 279 +++++ .../reward_fns/repetition_penalty_reward.py | 286 +++++ atroposlib/envs/reward_fns/reward_function.py | 125 ++ .../envs/server_handling/openai_server.py | 296 +++++ .../envs/server_handling/server_harness.py | 146 +++ .../envs/server_handling/server_manager.py | 208 ++++ atroposlib/tests/test_advantages.py | 169 +++ atroposlib/tests/test_utils/__init__.py | 0 .../test_utils/test_heterogeneous_batching.py | 28 + atroposlib/type_definitions.py | 70 ++ atroposlib/utils/__init__.py | 7 + atroposlib/utils/advantages.py | 173 +++ atroposlib/utils/config_handler.py | 184 +++ atroposlib/utils/force_diverse_samples.py | 112 ++ atroposlib/utils/metrics.py | 19 + atroposlib/utils/tokenize_for_trainer.py | 192 +++ environments/README.md | 129 +++ .../dataset_environment/LOCAL_TESTING.md | 155 +++ environments/dataset_environment/README.md | 355 ++++++ environments/dataset_environment/__init__.py | 11 + .../configs/dataset_local.yaml | 52 + .../dataset_environment/configs/gsm8k.yaml | 73 ++ .../configs/gsm8k_debug.yaml | 30 + .../dataset_environment/dataset_env.py | 407 +++++++ .../dataset_local_server.py | 248 ++++ .../launch_local_dataset_run.py | 154 +++ .../fundamental_prediction_environment.py | 505 ++++++++ environments/gsm8k_server.py | 295 +++++ environments/math_server.py | 1030 +++++++++++++++++ environments/math_server_zero.py | 447 +++++++ environments/mcqa_thinking_env.py | 492 ++++++++ .../multimodal_dpo/clevr_cogen_a_train.py | 282 +++++ environments/multimodal_dpo/clevr_complex.py | 283 +++++ environments/multimodal_dpo/ocr_vqa.py | 200 ++++ environments/multimodal_dpo/pixmo_clocks.py | 202 ++++ environments/multimodal_dpo/pixmo_count.py | 191 +++ .../pixmo_point_explanations.py | 202 ++++ environments/rlaif_server.py | 311 +++++ environments/tool_calling_server.py | 478 ++++++++ example_trainer/README.md | 72 ++ example_trainer/__init__.py | 7 + example_trainer/grpo.py | 548 +++++++++ example_trainer/requirements.txt | 5 + helpers/length_penalties.py | 76 ++ llm.txt | 444 +++++++ pyproject.toml | 58 + pytest.ini | 7 + testing/__init__.py | 0 testing/api/__init__.py | 0 testing/api/testing.py | 79 ++ testing/api/utils.py | 27 + testing/testing.md | 8 + testing/utils/test_tokenize_for_trainer.py | 117 ++ 89 files changed, 15315 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/bug_report.md create mode 100644 .github/ISSUE_TEMPLATE/feature_request.md create mode 100644 .github/environment_pull_request_template.md create mode 100644 .github/non_environment_pull_request_template.md create mode 100644 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 CODE_OF_CONDUCT.md create mode 100644 CONFIG.md create mode 100644 CONTRIBUTING.md create mode 100644 LICENSE create mode 100644 README.md create mode 100644 SLURM.md create mode 100644 atroposlib/__init__.py create mode 100644 atroposlib/api/README.md create mode 100644 atroposlib/api/__init__.py create mode 100644 atroposlib/api/env_interaction.md create mode 100644 atroposlib/api/server.py create mode 100644 atroposlib/api/trainer_interaction.md create mode 100644 atroposlib/api/utils.py create mode 100644 atroposlib/cli/dpo.py create mode 100644 atroposlib/cli/inference_node_wandb_watcher.py create mode 100644 atroposlib/cli/run_api.py create mode 100644 atroposlib/cli/sft.py create mode 100644 atroposlib/cli/view_run.py create mode 100644 atroposlib/envs/README.md create mode 100644 atroposlib/envs/__init__.py create mode 100644 atroposlib/envs/base.py create mode 100644 atroposlib/envs/reward_fns/__init__.py create mode 100644 atroposlib/envs/reward_fns/accuracy_reward.py create mode 100644 atroposlib/envs/reward_fns/cascading_r1_math_reward.py create mode 100644 atroposlib/envs/reward_fns/combined_reward.py create mode 100644 atroposlib/envs/reward_fns/cosine_scaled_reward.py create mode 100644 atroposlib/envs/reward_fns/crossword_format_reward.py create mode 100644 atroposlib/envs/reward_fns/format_reward.py create mode 100644 atroposlib/envs/reward_fns/r1_reward.py create mode 100644 atroposlib/envs/reward_fns/reasoning_steps_reward.py create mode 100644 atroposlib/envs/reward_fns/registry.py create mode 100644 atroposlib/envs/reward_fns/repetition_penalty_reward.py create mode 100644 atroposlib/envs/reward_fns/reward_function.py create mode 100644 atroposlib/envs/server_handling/openai_server.py create mode 100644 atroposlib/envs/server_handling/server_harness.py create mode 100644 atroposlib/envs/server_handling/server_manager.py create mode 100644 atroposlib/tests/test_advantages.py create mode 100644 atroposlib/tests/test_utils/__init__.py create mode 100644 atroposlib/tests/test_utils/test_heterogeneous_batching.py create mode 100644 atroposlib/type_definitions.py create mode 100644 atroposlib/utils/__init__.py create mode 100644 atroposlib/utils/advantages.py create mode 100644 atroposlib/utils/config_handler.py create mode 100644 atroposlib/utils/force_diverse_samples.py create mode 100644 atroposlib/utils/metrics.py create mode 100644 atroposlib/utils/tokenize_for_trainer.py create mode 100644 environments/README.md create mode 100644 environments/dataset_environment/LOCAL_TESTING.md create mode 100644 environments/dataset_environment/README.md create mode 100644 environments/dataset_environment/__init__.py create mode 100644 environments/dataset_environment/configs/dataset_local.yaml create mode 100644 environments/dataset_environment/configs/gsm8k.yaml create mode 100644 environments/dataset_environment/configs/gsm8k_debug.yaml create mode 100644 environments/dataset_environment/dataset_env.py create mode 100644 environments/dataset_environment/dataset_local_server.py create mode 100644 environments/dataset_environment/launch_local_dataset_run.py create mode 100644 environments/fundamental_prediction_environment.py create mode 100644 environments/gsm8k_server.py create mode 100644 environments/math_server.py create mode 100644 environments/math_server_zero.py create mode 100644 environments/mcqa_thinking_env.py create mode 100644 environments/multimodal_dpo/clevr_cogen_a_train.py create mode 100644 environments/multimodal_dpo/clevr_complex.py create mode 100644 environments/multimodal_dpo/ocr_vqa.py create mode 100644 environments/multimodal_dpo/pixmo_clocks.py create mode 100644 environments/multimodal_dpo/pixmo_count.py create mode 100644 environments/multimodal_dpo/pixmo_point_explanations.py create mode 100644 environments/rlaif_server.py create mode 100644 environments/tool_calling_server.py create mode 100644 example_trainer/README.md create mode 100644 example_trainer/__init__.py create mode 100644 example_trainer/grpo.py create mode 100644 example_trainer/requirements.txt create mode 100644 helpers/length_penalties.py create mode 100644 llm.txt create mode 100644 pyproject.toml create mode 100644 pytest.ini create mode 100644 testing/__init__.py create mode 100644 testing/api/__init__.py create mode 100644 testing/api/testing.py create mode 100644 testing/api/utils.py create mode 100644 testing/testing.md create mode 100644 testing/utils/test_tokenize_for_trainer.py diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 00000000..7d4675d3 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,60 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: '' +assignees: '' + +--- + +## Describe the Issue + + + +## Environment/API Details + +- **Environment Class/Name:** [e.g., `atroposlib.envs.MyCustomEnv`, `gsm8k_server.py`] +- **Environment Configuration (`BaseEnvConfig` or subclass):** (Optional) + ```yaml + # Paste relevant config values here + group_size: 4 + max_token_length: 2048 + # ... etc. + ``` +- **API Endpoint/Method Involved:** (If applicable) [e.g., `/register_env`, `/get_status`, `env.collect_trajectory()`] + +## Steps to Reproduce + + +1. Initialize environment `...` with config `...` +2. Call method `get_next_item()` and receive `Item` `...` +3. Call method `collect_trajectory()` or `collect_trajectories()` with `Item` `...` +4. Observe issue `...` (e.g., incorrect `ScoredDataGroup`, error during API call) + +## Interaction Details (if applicable) + + +- **Input `Item` to `collect_trajectory`:** + ```python + # Paste relevant Item details here + ``` +- **Output `ScoredDataGroup` (or error):** + ```python + # Paste relevant ScoredDataGroup or traceback here + ``` +- **Expected `ScoredDataGroup` / Behavior:** + +## Setup Details + +- **OS:** [e.g. macOS, Windows, Linux] +- **Python Version:** [e.g. 3.10, 3.11] +- **`Atropos` Version:** [e.g. output of `pip show atropos` or commit hash] +- **Relevant Libraries/Versions:** [e.g., `pydantic==2.5.0`, `aiohttp==3.9.0`, `transformers==4.35.0`] + +## Additional Context & Logs + + + +```log +# Paste relevant logs here +``` diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 00000000..bbcbbe7d --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,20 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: '' +assignees: '' + +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. diff --git a/.github/environment_pull_request_template.md b/.github/environment_pull_request_template.md new file mode 100644 index 00000000..4e9a30db --- /dev/null +++ b/.github/environment_pull_request_template.md @@ -0,0 +1,44 @@ + + +## 🔖 Environment Snapshot +| Field | Your Entry | +|-------|------------| +| **Environment Name** | | +| **Short Description** | | +| **Category** | | +| **Dataset Needed?** | | +| **External Deps** | | +| **Environmental Variables** | | +| **Expected Episode Length** | | +| **Compute Footprint Estimate** | | + +--- + +## 🧪 Zero-Training Test Results +
+ +**W&B Link:** + +**Examples of the Environment scoring a good example and a bad example:** + +
+ + +## ✅ Developer & Reviewer Checklist +- [ ] Code follows project style (black, isort, flake8 pass with pre-commit). +- [ ] I have performed a self-review of my own code +- [ ] Docstrings added for all new public classes / functions. +- [ ] If .env vars required, did you add it to the .env.example in repo root? +- [ ] Automatic rollout script (`scripts/run_smoke_test.py`) runs without training and reproduces the metrics above. +- [ ] I have commented my code, particularly in hard-to-understand areas +- [ ] I have made corresponding changes to the documentation to include my environment +- [ ] My changes generate no new warnings +- [ ] New and existing unit tests pass locally with my changes + +--- diff --git a/.github/non_environment_pull_request_template.md b/.github/non_environment_pull_request_template.md new file mode 100644 index 00000000..b6b2bfae --- /dev/null +++ b/.github/non_environment_pull_request_template.md @@ -0,0 +1,30 @@ +## Description + + + +## Related Issues + + + +## Type of Change + + + +- [ ] Bug fix (non-breaking change which fixes an issue) +- [ ] New feature (non-breaking change which adds functionality) +- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) +- [ ] This change requires a documentation update +- [ ] Code refactor (no functional changes) +- [ ] Build/CI/CD related changes +- [ ] Other (please describe): + +## Checklist + +- [ ] My code follows the style guidelines of this project +- [ ] I have performed a self-review of my own code +- [ ] I have commented my code, particularly in hard-to-understand areas +- [ ] I have made corresponding changes to the documentation +- [ ] My changes generate no new warnings +- [ ] I have added tests that prove my fix is effective or that my feature works +- [ ] New and existing unit tests pass locally with my changes +- [ ] Any dependent changes have been merged and published in downstream modules diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..2a1052c0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,184 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ + +# PyPI configuration file +.pypirc + +# data +*.json +*.jsonl + +# wandb +wandb + +# Allow MCP configuration +!configs/mcp.json + +# gradio +.gradio/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..24e98830 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,25 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + +- repo: https://github.com/psf/black + rev: 24.1.1 + hooks: + - id: black + +- repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + args: ["--profile", "black", "--filter-files"] + +- repo: https://github.com/pycqa/flake8 + rev: 7.0.0 + hooks: + - id: flake8 + args: ["--max-line-length=120", "--extend-ignore=E203,W503"] diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 00000000..974722f5 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,44 @@ +# Code of Conduct + +## Our Pledge + +We, as contributors and maintainers of the Atropos project, pledge to foster an open and welcoming environment. We aim to make participation in our project a positive, respectful, and harassment-free experience for everyone, regardless of background or experience level. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment include: + +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community and the project +* Showing empathy towards other community members + +Examples of unacceptable behavior include: + +* Unwelcome personal attacks or criticism +* Excessive trolling, insulting/derogatory comments, and personal attacks +* Public or private harassment +* Publishing others' private information without explicit permission +* Other conduct which could reasonably be considered inappropriate in a professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, issues, and other contributions that are not aligned with this Code of Conduct. + +## Scope + +This Code of Conduct applies within all project spaces, including GitHub repositories, issue trackers, and related communication channels. It also applies when an individual is representing the project or its community in public spaces. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project maintainers. All complaints will be reviewed and investigated promptly and fairly. + +Project maintainers are obligated to respect the privacy and security of the reporter of any incident. + +## Attribution + +This Code of Conduct is adapted from general open source community standards and GitHub's community guidelines. + +Remember: Respect each other, collaborate constructively, and focus on making Atropos better for everyone. \ No newline at end of file diff --git a/CONFIG.md b/CONFIG.md new file mode 100644 index 00000000..efe069ac --- /dev/null +++ b/CONFIG.md @@ -0,0 +1,67 @@ +# AtroposLib Configuration + +This document outlines the configuration options available for the `atroposlib` library, primarily defined using Pydantic models. +These configurations are often managed via a command-line interface built using `pydantic-cli`, especially when using the `serve` command provided by environment classes inheriting from `BaseEnv`. + +## Base Environment Configuration (`atroposlib.envs.base.BaseEnvConfig`) + +Basic environment configuration settings. + +| Parameter | Type | Default | Description | +| :------------------------------- | :----------------------- | :---------------------------------------------- | :--------------------------------------------------------------------------------------------------------- | +| `group_size` | `int` | `4` | How many responses are grouped together for scoring. | +| `max_num_workers` | `int` | `-1` | Maximum number of workers to use. `-1` calculates from `max_num_workers_per_node`. | +| `max_eval_workers` | `int` | `16` | Maximum number of workers to use for evaluation. | +| `max_num_workers_per_node` | `int` | `8` | Maximum number of workers to use per node. | +| `steps_per_eval` | `int` | `100` | Number of steps to take before evaluating. | +| `max_token_length` | `int` | `2048` | Maximum token length used in generations. | +| `eval_handling` | `EvalHandlingEnum` | `EvalHandlingEnum.STOP_TRAIN` | How to handle evaluations (`STOP_TRAIN`, `LIMIT_TRAIN`, `NONE`). | +| `eval_limit_ratio` | `float` | `0.5` | Ratio of training workers to limit during evals (used if `eval_handling` is `LIMIT_TRAIN`). | +| `inference_weight` | `float` | `1.0` | Inference weight. Set to `-1` to ignore if doing something special. | +| `batch_size` | `int` | `-1` | Batch size for training. Usually set by the trainer via the API. | +| `max_batches_offpolicy` | `int` | `3` | Maximum number of off-policy batches to have in the queue. | +| `tokenizer_name` | `str` | `"NousResearch/DeepHermes-3-Llama-3-1B-Preview"` | Hugging Face tokenizer to use. | +| `use_wandb` | `bool` | `True` | Whether to use Weights & Biases for logging. | +| `rollout_server_url` | `str` | `"http://localhost:8000"` | URL of the rollout server (FastAPI interface). | +| `total_steps` | `int` | `1000` | Total number of steps to run. | +| `wandb_name` | `str | None` | `None` | Name to be grouped by in WandB. | +| `num_rollouts_to_keep` | `int` | `32` | Number of rollouts to display on WandB. | +| `num_rollouts_per_group_for_logging` | `int` | `1` | Number of rollouts per group to keep for logging. `-1` keeps all. | +| `ensure_scores_are_not_same` | `bool` | `True` | Ensure that scores within a group are not identical (usually `True`). | +| `data_path_to_save_groups` | `str | None` | `None` | Path to save generated groups as a JSONL file. If set, groups will be written here. | +| `min_items_sent_before_logging` | `int` | `2` | Minimum number of items sent to the API before logging metrics. `0` or less logs every time. | + +## Server Manager Configuration (`atroposlib.envs.server_handling.server_manager.ServerManagerConfig`) + +Settings for the `ServerManager`. + +| Parameter | Type | Default | Description | +| :-------- | :------ | :------ | :------------------------------------------------ | +| `slurm` | `bool` | `True` | Whether the environment is running on SLURM. | +| `testing` | `bool` | `False` | If `True`, uses mock OpenAI data for testing. | + +## Server Baseline Configuration (`atroposlib.envs.server_handling.server_manager.ServerBaseline`) + +Baseline configuration used by `ServerManager` if a list of `OpenaiConfig` is not provided, particularly for setting up local or SLURM-based server discovery. + +| Parameter | Type | Default | Description | +| :------------------------- | :------ | :-------- | :------------------------------------------------------------------------------------------------------ | +| `timeout` | `int` | `1200` | Timeout for the request in seconds. | +| `num_max_requests_at_once` | `int` | `512` | Maximum number of concurrent requests (training). Divide this by the generation `n` parameter. | +| `num_requests_for_eval` | `int` | `64` | Maximum number of concurrent requests for evaluation. | +| `model_name` | `str` | `default` | Model name to use when calling inference servers. | +| `rolling_buffer_length` | `int` | `1000` | Length of the rolling buffer to store server metrics (like request timings, attempts). | + +## OpenAI Server Configuration (`atroposlib.envs.server_handling.openai_server.OpenaiConfig`) + +Configuration for individual OpenAI-compatible API servers (including local SGLang/vLLM instances). + +| Parameter | Type | Default | Description | +| :------------------------- | :----------- | :-------- | :------------------------------------------------------------------------------------------------------ | +| `api_key` | `str \| None` | `None` | API key for OpenAI API. Use `"x"` or any non-empty string for local servers that don't require auth. | +| `base_url` | `str \| None` | `None` | URL of the API endpoint. `None` for official OpenAI API, otherwise the local server URL (e.g., `http://localhost:9004/v1`). | +| `timeout` | `int` | `1200` | Timeout for the request in seconds. | +| `num_max_requests_at_once` | `int` | `512` | Maximum number of concurrent requests (training). Divide this by the generation `n` parameter. | +| `num_requests_for_eval` | `int` | `64` | Maximum number of concurrent requests for evaluation. | +| `model_name` | `str` | `default` | The model name to use. Required for both OpenAI and local models (e.g., `"gpt-4"`, `"NousResearch/..."`). | +| `rolling_buffer_length` | `int` | `1000` | Length of the rolling buffer to store server metrics (like request timings, attempts). | diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..0957bb8e --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,150 @@ +# Contributing to Atropos + +First off, thank you for considering contributing to Atropos! It's people like you that make open source projects such great tools. + +We welcome any type of contribution, not just code. You can help with: +* **Reporting a bug** +* **Discussing the current state of the code** +* **Submitting a fix** +* **Proposing new features** +* **Becoming a maintainer** + +## We Develop with GitHub +We use GitHub to host the code, track issues and feature requests, and accept pull requests. + +## We Use GitHub Flow +We follow the [GitHub Flow](https://guides.github.com/introduction/flow/index.html) development workflow. All code changes happen through Pull Requests. + +## Getting Started + +### Project Setup + +1. **Fork the repository:** Click the "Fork" button on the top right of the [repository page](https://github.com/NousResearch/atropos). This creates your own copy of the project. +2. **Clone your fork:** + ```bash + git clone https://github.com/your-username/atropos.git + cd atropos + ``` +3. **Set up the development environment:** This project uses standard Python `venv` for environment creation and `pip` for dependency management. + ```bash + # Ensure you have Python 3.10+ installed + # Create and activate a virtual environment + python -m venv .venv + source .venv/bin/activate # On Windows use `.venv\Scripts\activate` + + # Install dependencies, including development dependencies + pip install -e ".[dev]" + ``` +4. **Install pre-commit hooks:** This project uses `pre-commit` for code quality checks. The hooks will run automatically when you commit changes. + ```bash + pre-commit install + ``` + +### Running Tests + +We use `pytest` for running tests. To run the test suite: + +```bash +pytest +``` + +Ensure all tests pass before submitting a pull request. + +## How to Contribute + +### Reporting Bugs + +We use GitHub issues to track public bugs. Report a bug by [opening a new issue](https://github.com/NousResearch/atropos/issues) (replace with the actual link if different). + +When opening a bug report, please use the **Bug Report** issue template. This template is designed to gather the information we need to efficiently understand and resolve the issue. + +**Great Bug Reports** tend to have: + +* A quick summary and/or background. +* Steps to reproduce the bug: + * Be specific! + * Provide the exact commands run or a minimal code snippet if possible. +* What you expected to happen. +* What actually happened (including any error messages or logs). +* Your environment details (OS, Python version, relevant package versions). +* Notes (possibly including why you think this might be happening, or stuff you tried that didn't work). + +Thorough bug reports help us address issues faster! + +### Suggesting Enhancements + +If you have an idea for a new feature or an improvement to an existing one, please open an issue first to discuss it. This allows us to coordinate efforts and ensure the suggestion aligns with the project's goals. + +When suggesting an enhancement, please use the **Feature Request** issue template. This helps structure your request and provides context for maintainers and the community to better understand your suggestion. + +### Submitting Changes (Pull Requests) + +Pull requests are the best way to propose changes to the codebase. We actively welcome your pull requests: + +1. **Fork the repo** and create your branch from `main`. + ```bash + git checkout -b your-feature-or-fix-branch main + ``` +2. **Make your changes:** Write your code. +3. **Add tests:** If you've added code that should be tested, add tests. +4. **Update documentation:** If you've changed APIs or added features, update relevant documentation (README, docstrings, etc.). +5. **Ensure tests pass:** Run `pytest`. +6. **Ensure code lints and formats:** The pre-commit hooks will run automatically on commit. You can also run them manually: `pre-commit run --all-files`. +7. **Commit your changes:** Use clear and descriptive commit messages that explain the purpose of the changes. + ```bash + git add . + git commit -m "Clearly describe the changes made in this commit" + ``` +8. **Push your branch:** + ```bash + git push origin your-feature-or-fix-branch + ``` +9. **Open a Pull Request (PR):** Go to the original repository on GitHub and open a PR from your fork's branch to the `main` branch. + * Provide a clear title and description for your PR. + * Link any relevant issues (e.g., "Closes #123"). + * Explain the changes you've made and why. + * **Follow the PR template**: We have two PR templates: + - For environment-related changes, use the `environment_pull_request_template.md` + - For all other changes, use the `non_environment_pull_request_template.md` + + Please fill out the appropriate template thoroughly to help reviewers understand your changes. + +## Code Style + +This project uses standard Python code style (PEP 8) enforced by `black`, `flake8`, and `isort` via `pre-commit`. Please ensure your code adheres to these standards. The pre-commit hooks should help automate formatting and linting. + +You can manually run the checks on all files using: +```bash +pre-commit run --all-files +``` +This command will automatically fix formatting issues found by `black` and `isort`. However, you may need to manually address any linting errors reported by `flake8`. + +## License for Contributions +Any contributions you make will be under the Apache License 2.0. In short, when you submit code changes, your submissions are understood to be under the same [Apache License 2.0](LICENSE) that covers the project. Feel free to contact the maintainers if that\'s a concern. + +## Environment Contribution Guidelines + +Since Atropos is focused on reinforcement learning environments, we encourage contributions of new training environments. However, please adhere to the following guidelines: + +* **Legal compliance**: Do not submit environments that involve illegal activities or content. + +* **GitHub compliance**: All contributions must comply with [GitHub's Terms of Service and Community Guidelines](https://docs.github.com/en/site-policy/github-terms/github-terms-of-service). + +* **Explicit content**: Explicit environments may be considered, but must be: + * Clearly labeled as such + * Comply with all legal requirements + +* **Game environments**: Game-based environments are welcome, but: + * Do not submit reverse-engineered commercial game environments that could lead to copyright or intellectual property issues + * Ensure you have the appropriate rights to any assets used + * Open-source games or games with permissive licenses are preferred + +* **Ethical considerations**: Consider the ethical implications of your environment. Environments that encourage harmful behaviors without educational context may be rejected. + +When in doubt about the appropriateness of an environment, please open an issue to discuss it before investing significant development effort. + +## Code of Conduct + +Please note that this project is released with a [Contributor Code of Conduct](CODE_OF_CONDUCT.md). By participating in this project you agree to abide by its terms. + +Thank you again for your contribution! diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..75410e73 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Nous Research + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 00000000..684b0ed7 --- /dev/null +++ b/README.md @@ -0,0 +1,224 @@ +# Atropos - Nous Research's LLM RL Gym + +![newatr-02](https://github.com/user-attachments/assets/e9b64e10-340e-48f2-835c-ae28fa14730a) + +
+ +*In Greek mythology, Atropos was the eldest of the three Fates. While her sisters spun and measured the threads of mortal lives, Atropos alone held the shears that would cut these threads, determining the final destiny of each soul. Just as Atropos guided souls to their ultimate fate, this system guides language models toward their optimal potential through reinforcement learning.* + +
+ +
+
+
+ + HuggingFace + + + Website + + + @NousResearch + +
+ +Atropos is a Language Model Reinforcement Learning Environments framework for collecting and evaluating LLM trajectories through diverse environments including: + +
+ +| Environment Type | Examples | Purpose | +|---------------------------|--------------------------------------------|----------------------------------------------------| +| 📚 Dataset environments | GSM8K, MMLU | Evaluate and improve LLM performance on static data| +| 🎮 Online environments | Crosswords, Hangman | Train LLMs through interactive game-based learning | +| 🤖 RLAIF and RLHF | LLM Judge/Reward Models | Fine-tune LLMs using human feedback and alignment | +| 🔄 Multi-Turn RL | deepresearch, internal tool calling | Train LLMs on complex multi-step interactions | + +
+ +Atropos is a robust, scalable framework for **Reinforcement Learning Environments with LLMs**. Key features: + +- **Multi-Turn & Asynchronous RL:** Efficiently supports complex, multi-turn, and asynchronous interactions, decoupling environment steps from policy updates. +- **Inference Agnostic:** Integrates with standard inference APIs (e.g., OpenAI, vLLM, SGLang), enabling easy switching between LLM providers and frameworks. +- **Trainer Independent:** Offers a standardized training interface for experimenting with different RL algorithms and frameworks without major code changes. +- **Scalable & Decentralized:** Easily scale by launching more environment instances (locally or across decentralized resources) that contribute rollouts to a central service. +- **Diverse Environment Integration:** Manages many varied environment types concurrently for heterogeneous, multi-modal training. + +The goal: provide a flexible, scalable, and standardized platform to accelerate LLM-based RL research across diverse, interactive settings. + +## 🎉 Upcoming Atropos Hackathon: LLM RL Environments + +Join us in San Francisco on May 18th, 2025 for an exciting hackathon focused on building and experimenting with LLM RL Environments! This in-person event will bring together researchers and developers interested in advancing the field of LLM reinforcement learning. + +More details coming soon! Follow us on Twitter [@NousResearch](https://x.com/NousResearch) to stay updated. + + +--- + +## Experimental results from models trained using Atropos' environments + +We have been able to achieve significant improvements on specific domains or tasks with Atropos - Below are some of the results. + +**Tool Calling Environment Results:** + +
+ +| Berkeley Function Calling Benchmark Type | Base Model | With Atropos RL | Improvement | +|---------------|------------|-----------------|-------------| +| Parallel Tasks| 10% | 46% | **4.6x** ⬆️ | +| Simple Tasks | 21% | 51.75% | **2.5x** ⬆️ | + +
+ +Model Artifact: +https://huggingface.co/NousResearch/DeepHermes-ToolCalling-Specialist-Atropos + + +Environment Used: +https://github.com/NousResearch/Atropos/environments/tool_calling_server.py + +--- + +**Financial Fundamentals Prediction Environment Results**: + +
+ +| Metric | Initial Accuracy | With Atropos RL | Improvement | +|--------|-----------------|-----------------|-------------| +| Directional Prediction Eval Accuracy | 20% | 50% | **2.5x** 📈 | + +
+ +Model Artifact: +https://huggingface.co/NousResearch/DeepHermes-Financial-Fundamentals-Prediction-Specialist-Atropos + +Environment Used: +https://github.com/NousResearch/Atropos/environments/fundamental_prediction_environment.py + +--- + +## RLAIF Experiment Artifacts +Using the RLAIF Environment to change the personality of the model, we have produced several artifacts of interesting and weird personalities. + +**DeepHermes Egregore v1 and v2 8B:** + +https://huggingface.co/NousResearch/DeepHermes-Egregore-v1-RLAIF-8b-Atropos +https://huggingface.co/NousResearch/DeepHermes-Egregore-v2-RLAIF-8b-Atropos + +**DeepHermes Ascension Maze 8B:** + +https://huggingface.co/NousResearch/DeepHermes-AscensionMaze-RLAIF-8b-Atropos + +--- + +## Navigating the Repo + +| Category | Description | +|----------|------------| +| 📁 [`atroposlib/`](atroposlib/) | Core library containing base classes and utilities | +| 🎮 [`environments/`](environments/) | Collection of ready-to-use RL environments | +| 📚 [`example_trainer/`](example_trainer/) | Example training scripts and configurations | + +Key Documents: +- [Base Environment Class](atroposlib/envs/README.md) - Documentation for creating custom environments +- [Environments Overview](environments/README.md) - Documentation for existing environments +- [Full Environment Config Options](CONFIG.md) - Documentation for creating custom environments +- [Example Trainer](example_trainer/README.md) - Getting started with training +- [Slurm Guide](SLURM.md) - Guide for using Atropos with Slurm for distributed inference +- [Contributing Guide](CONTRIBUTING.md) - Guidelines for contributors +- [License](LICENSE.md) - Apache 2.0 license details + +--- + +## Installation + +Get your Python 3.10 (or later) environment ready, then simply pip install: + +```bash +pip install atroposlib +``` + +If you're looking to get into developing the repo or using the environments: + + +```bash +pip install -e . # for using +pip install -e .[dev] # for development +pip install -e .[examples] # for running examples +pip install -e .[all] # for everything +``` + +**Important:** If you're committing to the repository, please install the pre-commit hooks: +```bash +pre-commit install +``` + +--- + +### Quick Start Guide + +1. **Create Your First Environment** + - Review our [Base Class Documentation](atroposlib/envs/README.md) to understand the core concepts + - Check out existing environments in the [`environments/`](environments) directory for examples + +2. **Run an Example Environment** + ```bash + # Start the API server and run the GSM8K environment + run-api & python environments/gsm8k_server.py serve \ + --tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct" \ + --model_name="Qwen/Qwen2.5-1.5B-Instruct" + ``` + +3. **Training Your Model** + - Follow our [training example guide](example_trainer/README.md) for detailed instructions + - Monitor progress through our built-in logging and reporting system: + - Completion lengths + - Evaluation accuracies + - Full rollouts and scores + +You can use multiple environments at once, just point them all to the same server. + +Environments come with detailed logging and reporting support, runs track completion lengths, eval accuracies, full rollouts and scores, and more: + +![image](https://github.com/user-attachments/assets/153a2932-191a-42e3-8da9-25a1b05abb8e) + +--- + +## Debugging Tools + +The trajectory-handler provides several debugging tools to help environment developers test and understand their environments locally without requiring the full distributed infrastructure. + +* **Flexible Model Provider Support:** Atropos natively supports any model provider that adheres to the OpenAI API standard. Simply provide the provider's base URL and your API key, and Atropos can integrate with their models seamlessly for testing or running environments locally. + +After launching the API and your selected environments (e.g. `run-api & python environments/gsm8k_server.py serve`), you are then able to view them to get a quick look, or try to prepare some datasets for some offline training: + +* **View Run (`view-run`):** Launch a Gradio UI to inspect batches of rollouts generated by your environment runs. This is useful for visually debugging the interactions and data flow. +* **Offline Data Generation:** Use `atropos-sft-gen` and `atropos-dpo-gen` to collect rollouts from environments and convert them into formats suitable for Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO). + +--- + +## Citation + +If you have found the library helpful in your work, you can cite this repository as: + +```latex +@misc{atropos, + title = {{Atropos - An Async First Environment Rollout Controller}}, + author = {Dakota Mahan, Roger Jin, Teknium, Shannon Sands, Artem Yatsenko, Jai Suphavadeeprasit, Karan Malhotra, Chen Guang, Joe Li}, + url = {https://www.github.com/NousResearch/Atropos}, + month = {4}, + year = {2025}, + version = {0.1}, +} +``` + +--- + +## Contributing + +Atropos is built by the open-source AI community, and relies on our amazing contributors! Please see our [contributing](CONTRIBUTING.md) guide for more details on our code formatting, testing, etc. +Please follow the [Code of Conduct](CODE_OF_CONDUCT.md). + +--- + +## License +Atropos is licensed as Apache 2.0, see the [LICENSE](LICENSE.md) file here for more information diff --git a/SLURM.md b/SLURM.md new file mode 100644 index 00000000..bd070e29 --- /dev/null +++ b/SLURM.md @@ -0,0 +1,175 @@ +## Using `ServerManager` with Slurm + +The `ServerManager` class in `atroposlib` provides built-in support for discovering and managing inference servers distributed across nodes allocated by Slurm. Here's how to use it: + +**Core Concept:** + +The setup assumes you have a Slurm job allocation where: +1. One or more nodes are designated for your main "training" or orchestrator process (the script that initializes `ServerManager`). +2. The remaining nodes in the allocation are dedicated to running the LLM inference servers (e.g., SGLang, TGI, vLLM, etc., accessible via an OpenAI-compatible API). + +**How `ServerManager` Detects Servers:** + +When you initialize `ServerManager` with `slurm=True`: +1. It reads the `SLURM_JOB_NODELIST` environment variable to get the hostnames of all allocated nodes. It uses the `scontrol show hostnames` command internally. +2. It reads the `NUM_TRAINING_NODES` environment variable. This crucial variable tells the manager how many nodes *at the beginning* of the nodelist are *reserved for the training/orchestrator process* and should **not** be treated as inference server nodes. +3. It iterates through the hostnames *after* the first `NUM_TRAINING_NODES`. These are assumed to be the inference nodes. +4. For each inference node, it constructs potential server URLs. By default, it assumes: + * Servers run on ports starting from `9000` (`9000`, `9001`, `9002`, ...). + * The number of server instances per node is determined by `8 // INFER_TP` (where `INFER_TP` is another environment variable, defaulting to 1 if not set, implying 8 servers per node). You should set `INFER_TP` according to your inference server's tensor parallelism configuration if applicable. + * The URL format is `http://{node_hostname}:{port}/v1`. +5. It uses the *first* configuration object you pass in the `configs` list as a template (for settings like `timeout`, `num_max_requests_at_once`, etc.) and creates specific `OpenaiConfig` objects for each discovered URL. +6. The `ServerManager` then load-balances requests across these automatically configured `OpenAIServer` instances. + +**Setup Steps:** + +1. **Launch Inference Servers:** In your Slurm submission script (`sbatch`), launch your inference server instances on the designated inference nodes. + * Ensure they listen on the correct hostname and the expected ports (9000, 9001, ...). + * The number of instances per node should match the `8 // INFER_TP` logic. Adjust the port range or `INFER_TP` environment variable accordingly if your setup differs. + * You might use `srun` to launch these processes on specific nodes. +2. **Set Environment Variables:** In the part of your Slurm script that launches your *main application* (the one using `ServerManager`): + * `export NUM_TRAINING_NODES=` (e.g., `export NUM_TRAINING_NODES=1` if only the first node runs the main script). + * `export INFER_TP=` (Optional, defaults to 1. Set this if your inference servers use tensor parallelism and you run fewer than 8 instances per node). +3. **Initialize `ServerManager`:** In your Python script: + ```python + from atroposlib.envs.server_handling.server_manager import ServerManager, ServerBaseline, OpenaiConfig + + # Provide at least one config object. It will be used as a template + # for Slurm-discovered servers if slurm=True. + # If you pass ServerBaseline, ensure NUM_TRAINING_NODES and potentially INFER_TP are set. + # If you pass a list of OpenaiConfig, the first one is used as the template. + base_config = ServerBaseline( + timeout=1200, + # other baseline settings... + ) + # OR + # base_config = OpenaiConfig( + # base_url="http://dummy", # This URL is ignored when slurm=True finds nodes + # api_key="dummy", + # timeout=1200, + # # other config settings... + # ) + + server_manager = ServerManager( + configs=base_config, # Or [base_config] if using OpenaiConfig + slurm=True + ) + + # Now use server_manager.chat_completion(...) or server_manager.completion(...) + ``` +4. **Submit Slurm Job:** Submit your job ensuring the necessary nodes and resources (like GPUs for inference) are requested. + +**Example Conceptual Slurm Script:** + +```bash +#!/bin/bash +#SBATCH --nodes=5 # 1 trainer node + 4 inference nodes +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=8 # Assuming 8 GPUs/node for inference +#SBATCH --job-name=atropos-rl + +# Get allocated node hostnames +nodes=$(scontrol show hostnames $SLURM_JOB_NODELIST) +nodes_array=($nodes) +training_node=${nodes_array[0]} +inference_nodes=${nodes_array[@]:1} # Nodes from index 1 onwards + +echo "Training Node: $training_node" +echo "Inference Nodes: ${inference_nodes[@]}" + +# --- Launch Inference Servers (Example using srun, adapt for your server type) --- +TP_SIZE=1 # Example: Tensor Parallelism = 1 +INSTANCES_PER_NODE=$((8 / TP_SIZE)) + +echo "Launching $INSTANCES_PER_NODE inference servers per node..." + +for node in ${inference_nodes[@]}; do + for i in $(seq 0 $((INSTANCES_PER_NODE - 1))); do + port=$((9000 + i)) + gpu_id=$i # Basic GPU assignment, might need refinement + echo "Starting server on $node:$port (GPU $gpu_id)" + srun --nodes=1 --ntasks=1 --gpus-per-task=1 --gpu-bind=map_gpu:$gpu_id --nodelist=$node \ + your_inference_server_launch_cmd --host 0.0.0.0 --port $port --tp $TP_SIZE [other_args] & + done +done + +echo "Waiting for servers to start..." +sleep 60 # Simple wait, consider a more robust check + +# --- Launch W&B Watcher on each Inference Node --- +echo "Launching W&B watchers..." +# Assume the main API server runs on the training_node at default port 8000 +TRAINER_API_ADDR="http://${training_node}:8000" + +inference_node_index=0 # Start index for node_num +for node in ${inference_nodes[@]}; do + echo "Starting watcher on $node (Node Index $inference_node_index)" + srun --nodes=1 --ntasks=1 --nodelist=$node \ + python atroposlib/cli/inference_node_wandb_watcher.py \ + --api_addr $TRAINER_API_ADDR \ + --tp $TP_SIZE \ + --node_num $inference_node_index & + inference_node_index=$((inference_node_index + 1)) +done + +# --- Launch Main Application on the Training Node --- +export NUM_TRAINING_NODES=1 +export INFER_TP=$TP_SIZE + +echo "Starting main application on $training_node..." +srun --nodes=1 --ntasks=1 --nodelist=$training_node \ + python your_main_atropos_script.py --some_arg=value + +echo "Job finished." +wait # Wait for background server processes launched with '&' +``` + +**Important Notes:** + +* This setup relies on the `scontrol` command being available in the environment where `ServerManager` is initialized. +* Ensure network connectivity and firewall rules allow the training node(s) to reach the inference nodes on ports 9000+. +* The logic assumes a specific port assignment (9000+) and server count based on `INFER_TP`. If your inference server setup differs (e.g., different ports, different discovery mechanism), you would need to modify `server_manager.py` or manually provide the correct list of `OpenaiConfig` objects instead of relying on `slurm=True`. + +## Monitoring Inference Nodes with Weights & Biases + +Atropos includes a utility script, `inference-node-wandb-watcher`, located in `atroposlib/cli/`, designed to run on each inference node alongside the inference servers. + +**Purpose:** + +* **Health Monitoring:** Periodically checks the `/health_generate` endpoint of each local inference server instance (assuming ports 9000+). +* **W&B Logging:** Logs the health status (1 for healthy, 0 for unhealthy) of each server instance to a shared Weights & Biases run group. This allows you to visualize server uptime and availability directly in your W&B dashboard alongside your training metrics. +* **Step Synchronization:** It fetches the current training step from the main Atropos API server (`run-api`) to ensure W&B logs are correctly associated with training progress. + +**Integration into Slurm Script:** + +You can launch this watcher on each inference node using `srun` similarly to how the inference servers are launched. Add the following section to the example Slurm script, **after** launching the inference servers and **before** launching the main application: + +```bash +# --- Launch W&B Watcher on each Inference Node --- +echo "Launching W&B watchers..." +# Assume the main API server runs on the training_node at default port 8000 +TRAINER_API_ADDR="http://${training_node}:8000" + +inference_node_index=0 # Start index for node_num +for node in ${inference_nodes[@]}; do + echo "Starting watcher on $node (Node Index $inference_node_index)" + srun --nodes=1 --ntasks=1 --nodelist=$node \ + python atroposlib/cli/inference_node_wandb_watcher.py \ + --api_addr $TRAINER_API_ADDR \ + --tp $TP_SIZE \ + --node_num $inference_node_index & + inference_node_index=$((inference_node_index + 1)) +done +``` + +**Explanation of Arguments:** + +* `--api_addr`: This is the address of the main Atropos API server (usually started with `run-api`). The script needs this to fetch W&B project/group info and the current training step. In the example, we construct it assuming the API runs on the `training_node` (first node in the allocation) at port `8000` (the default for `run-api`). **Ensure this port is correct and accessible from the inference nodes.** +* `--tp`: This should be the same tensor parallelism size (`TP_SIZE`) used when launching the inference servers. It tells the watcher how many server instances (ports 9000 to 9000 + `8 // TP_SIZE` - 1) to monitor on the local node. +* `--node_num`: A unique integer identifying this specific inference node within the Slurm job. This helps distinguish the metrics from different nodes in W&B (e.g., `server/server_heath_0_0`, `server/server_heath_1_0`). The example script assigns sequential indices starting from 0. + +**Important Notes:** + +* Ensure the `run-api` server is running and accessible from the inference nodes. +* The `inference-node-wandb-watcher` script should be executable and accessible from the inference nodes. +* The script assumes the default port for the `run-api` server (8000). If your setup uses a different port, you may need to modify the script or the port in the `TRAINER_API_ADDR` construction. diff --git a/atroposlib/__init__.py b/atroposlib/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/atroposlib/api/README.md b/atroposlib/api/README.md new file mode 100644 index 00000000..3e719ee3 --- /dev/null +++ b/atroposlib/api/README.md @@ -0,0 +1,182 @@ +# Trajectory Handler API + +## Overview + +The AtroposLib API is a FastAPI application designed to act as a central buffer and aggregator for reinforcement learning (RL) experience data. Its primary purpose is to decouple RL data generation (by "Rollout Handlers" or "Environments") from RL data consumption (by one or more "Trainers"), particularly in distributed online RL settings. + +This service specifically handles the **experience data pathway**: + +* Rollout Handlers connect and push trajectories (tokens, masks, scores, etc.). +* The API buffers this data in a queue. +* Trainers connect and pull processed batches of experience data for training updates. + +**Important:** This service does *not* handle the distribution of updated policies from the Trainer back to the Rollout Handlers/Inference Servers. That part of the online RL loop is assumed to be handled by a separate mechanism. + +## Features + +* Centralized, in-memory queue for RL trajectory data. +* Registration endpoints for Trainers and Rollout Handlers. +* Serves batches of aggregated experience data to Trainers. +* Supports heterogeneous environments with weighting (via `/register-env` weight and internal batching). +* Provides status endpoints for monitoring queue size and training step count. +* Basic integration with Weights & Biases (W&B) project/group info. +* Endpoints for Rollout Handlers to disconnect gracefully. +* Debug endpoint to retrieve the latest submitted data sample. + +## Architecture Context + +This API typically sits within a larger RL system: + +1. **Rollout Handlers:** Instances simulating the environment. They interact with Inference Servers to get actions based on the current policy and send resulting trajectory data (`ScoredData`) to this AtroposLib API (`/scored_data`). +2. **Inference Servers (External):** Serve the current policy (e.g., via an OpenAI-compatible API). Receive policy updates directly from the Trainer. *Not part of this service.* +3. **AtroposLib API (This Service):** Buffers and batches experience data received from Rollout Handlers. +4. **Trainer(s):** Pull batches of experience data from this API (`/batch`), compute gradients, update the policy, and push updated policies directly to the Inference Servers. + + +## Running the Server + +with the repository installed we provide a helper script to run the server: + +```bash +run-api +``` +if you need more control over the server you can run it directly with: + +```bash +uvicorn atroposlib.api.server:app --host 0.0.0.0 --port 8000 --reload +``` + +* `--host 0.0.0.0`: Makes the server accessible on your network. +* `--port 8000`: Specifies the port (change if needed). +* `--reload`: Enables auto-reloading on code changes (for development). Remove for production. + +The API documentation (Swagger UI) will be available at `http://:8000/docs`. + +## API Endpoints + +### General + +* `GET /` + * **Description:** Root endpoint for basic health check. + * **Response:** `{"message": "AtroposLib API"}` + +### Trainer Registration & Info + +* `POST /register` + * **Description:** Called once by the Trainer process to initialize the server state for a training run. Resets state if called again. + * **Request Body:** `Registration` model + ```python + class Registration(BaseModel): + wandb_group: str + wandb_project: str + batch_size: int + max_token_len: int # Max token length expected in trajectories + checkpoint_dir: str # Shared location for checkpoints + save_checkpoint_interval: int + starting_step: int + num_steps: int # Total expected training steps + ``` + * **Response:** `{"uuid": }` +* `GET /wandb_info` + * **Description:** Retrieve W&B group and project info set during registration. + * **Response:** `{"group": , "project": }` +* `GET /info` + * **Description:** Retrieve batch size and max token length set during registration. + * **Response:** `{"batch_size": , "max_token_len": }` +* `GET /status` + * **Description:** Get the current training step (based on batches served) and queue size. + * **Response:** `{"current_step": , "queue_size": }` + +### Rollout Handler Registration & Info + +* `POST /register-env` + * **Description:** Called by each Rollout Handler instance to register itself. + * **Request Body:** `RegisterEnv` model + ```python + class RegisterEnv(BaseModel): + max_token_length: int # Max length this env produces + desired_name: str # Base name for identification/logging + weight: float # Weight for sampling/batching (e.g., 1.0) + ``` + * **Response:** Provides assigned ID, unique W&B name, checkpoint info. + ```json + { + "status": "success", + "env_id": , + "wandb_name": , + "checkpoint_dir": , + "starting_step": , + "checkpoint_interval": , + "num_steps": + } + ``` +* `POST /disconnect-env` + * **Description:** Allows a Rollout Handler to signal it's disconnecting gracefully. + * **Request Body:** `EnvIdentifier` model `{"env_id": }` + * **Response:** `{"status": "success"}` or `{"status": "failure", "error": ...}` +* `GET /status-env` + * **Description:** Called by a Rollout Handler to get general status plus its calculated sampling weight relative to other connected environments. + * **Query Parameter:** Requires `env: EnvIdentifier` model (e.g., `?env_id=0` - actual implementation might differ slightly, check FastAPI docs for query parameter models). **Note:** The code shows `env: EnvIdentifier` as a body parameter for a GET request, which is non-standard. This might need adjustment or testing. Assuming it works via query or a POST instead. + * **Response:** `{"current_step": , "queue_size": , "env_weight": }` + +### Data Handling + +* `POST /scored_data` + * **Description:** Endpoint for Rollout Handlers to push a single chunk of trajectory data. + * **Request Body:** `ScoredData` model + ```python + class ScoredData(BaseModel): + tokens: List[List[int]] + masks: List[List[int]] + scores: List[float] + ref_logprobs: Optional[List[List[float]]] = None + overrides: Optional[List[dict]] = None # Per-item logging overrides + group_overrides: Optional[dict] = None # Group logging overrides + ``` + * **Response:** `{"status": "received"}` +* `POST /scored_data_list` + * **Description:** Endpoint for Rollout Handlers to push a list of `ScoredData` chunks. + * **Request Body:** `List[ScoredData]` + * **Response:** `{"status": "received", "groups_processed": }` +* `GET /batch` + * **Description:** Called by the Trainer to request a batch of data for training. The server uses internal logic (`grab_exact_from_heterogeneous_queue`) to form a batch of the configured size from the available data in the queue, potentially respecting environment weights. The server increments its internal step counter when a batch is successfully formed and returned. + * **Response:** + * Success: `{"batch": [, ..., ]}` where each `data_item` matches the structure pushed via `/scored_data`. + * Not enough data: `{"batch": null}` +* `GET /latest_example` + * **Description:** Debug endpoint to retrieve the most recently added `ScoredData` item. + * **Response:** The last `ScoredData` dictionary pushed, or empty lists if none yet. + +### Debugging + +* `GET /reset_data` + * **Description:** **Warning:** Resets all server state, including the queue, configuration, registered environments, and step count. Use with caution during development/debugging. + * **Response:** Plain text `Reset successful` with HTTP status 200. + +## Common Workflow Example + +1. **Start Server:** Launch the `AtroposLib` API server. +2. **Trainer Initialization:** The main Trainer process sends a `POST /register` request with run parameters. +3. **Rollout Handler Initialization:** Each Rollout Handler starts and sends `POST /register-env`. +4. **Data Generation:** Handlers run simulations, collect data, and send `POST /scored_data` or `POST /scored_data_list` periodically. +5. **Training Loop:** + * The Trainer (e.g., Rank 0 in distributed setup) enters a loop: + * Calls `GET /batch`. + * If `batch` is not `null`: + * (Distribute batch to other ranks if applicable). + * Perform training step. + * Optionally call `GET /status` for monitoring. + * If `batch` is `null`: + * Wait briefly (`time.sleep`) and retry `GET /batch`. + * mermaid diagram of how a trainer interacts with the api is located [here](trainer_interaction.md). + * (In distributed setups, other ranks (1..N-1) might poll `GET /status` to wait for the step counter to increment before expecting the broadcasted batch from Rank 0). + * The envs periodically poll `GET /status-env` to check their status and sampling weight. + * In asynchronous setups, they may stop at a maximum off-policy step count. + * mermaid diagram of how a rollout handler interacts with the api is located [here](env_interaction.md). +6. **Shutdown:** Handlers may call `POST /disconnect-env`. + +## Limitations & TODOs + +* **In-Memory State:** The primary limitation is that all queues, configurations, and states are stored in the FastAPI application's memory (`app.state`). + * **No Persistence:** Data is lost if the server restarts. + * **Scalability Bottleneck:** API cannot scale beyond a single server instance easily. diff --git a/atroposlib/api/__init__.py b/atroposlib/api/__init__.py new file mode 100644 index 00000000..36fb5e3c --- /dev/null +++ b/atroposlib/api/__init__.py @@ -0,0 +1,3 @@ +from .server import app + +__all__ = ["app"] diff --git a/atroposlib/api/env_interaction.md b/atroposlib/api/env_interaction.md new file mode 100644 index 00000000..c4d63c09 --- /dev/null +++ b/atroposlib/api/env_interaction.md @@ -0,0 +1,70 @@ +```mermaid +sequenceDiagram + participant RH as Rollout Handler + participant API as AtroposLib API + + %% --- Initialization --- + RH->>API: POST /register-env (Send env details) + activate API + API-->>RH: Response (env_id, starting_step, wandb_name, ...) %% wandb_name is unique to this handler + deactivate API + Note over RH: Store env_id and unique wandb_name. + + Note over RH: Fetch W&B configuration (Assumes Trainer already called /register) + RH->>API: GET /wandb_info + activate API + API-->>RH: Response {"group": wb_group, "project": wb_project} + deactivate API + Note over RH: Initialize wandb logging (e.g., wandb.init) using group=wb_group, project=wb_project, name=wandb_name. + + Note over RH: Know target batch_size (from config?). Set off_policy_tolerance (e.g., 3). Set internal state = 'Running'. + + loop Simulation Loop + + %% --- Check Pause State & Generate/Send Data --- + alt State is 'Running' + Note over RH: Generating data using internal environment logic... + %% (Internal simulation steps, action selection, etc., happen here - details are opaque to the API) + Note over RH: Trajectory chunk collected (contains tokens, masks, scores...). Log env-specific metrics to wandb (e.g., episode reward, length). + + %% --- Send Data --- + RH->>API: POST /scored_data or /scored_data_list (Send collected chunk) + activate API + API-->>RH: Ack {"status": "received", ...} + deactivate API + else State is 'Paused' + Note over RH: Currently paused, skipping data generation and sending. Will check status again. + %% Implement delay/sleep here to avoid busy-checking status when paused + end + + + %% --- Periodic Queue Size Check (Pause/Resume Logic) --- + Note over RH: Checking API queue status to decide pause/resume state. + RH->>API: GET /status-env (using stored env_id) + activate API + API-->>RH: Response {"current_step": T_current, "queue_size": Q, "env_weight": W} + deactivate API + Note over RH: T_current might be logged or used for other internal reasons by the handler. Log queue size Q? + + Note over RH: Calculate threshold = off_policy_tolerance * batch_size + alt Check if queue size exceeds threshold (Q > threshold) + Note over RH: Queue size (Q = Q) > threshold. Setting internal state to 'Paused'. + opt State was 'Running' + Note over RH: Stopping data generation. Log pause event to wandb. + end + else Queue size is acceptable (Q <= threshold) + Note over RH: Queue size (Q = Q) <= threshold. Ensuring state is 'Running'. + opt State was 'Paused' + Note over RH: Resuming data generation. Log resume event to wandb. + end + end + + end %% End Simulation Loop + + %% --- Optional Shutdown --- + RH->>API: POST /disconnect-env (using stored env_id) + activate API + API-->>RH: Ack {"status": "success"} + deactivate API + Note over RH: Finalize wandb logging (wandb.finish). +``` diff --git a/atroposlib/api/server.py b/atroposlib/api/server.py new file mode 100644 index 00000000..8cda4abd --- /dev/null +++ b/atroposlib/api/server.py @@ -0,0 +1,305 @@ +import time +import uuid +from typing import Any, List, Optional + +from fastapi import FastAPI, status +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import PlainTextResponse +from pydantic import BaseModel + +from atroposlib.api.utils import grab_exact_from_heterogeneous_queue + +app = FastAPI(title="AtroposLib API") + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +@app.get("/") +async def root(): + return {"message": "AtroposLib API"} + + +class Registration(BaseModel): + wandb_group: str + wandb_project: str + batch_size: int + max_token_len: int + checkpoint_dir: str + save_checkpoint_interval: int + starting_step: int + num_steps: int + + +class RegisterEnv(BaseModel): + max_token_length: int + desired_name: str + weight: float + + +class EnvIdentifier(BaseModel): + env_id: int + + +class ScoredData(BaseModel): + tokens: List[List[int]] + masks: List[List[int]] + scores: List[float] + ref_logprobs: Optional[List[List[float]]] = None + overrides: Optional[List[dict]] = None + group_overrides: Optional[dict] = None + images: Optional[Any] = None + + +class Status(BaseModel): + """ + basemodel for status information of the current server + """ + + current_step: int + queue_size: int + + +class Info(BaseModel): + """ + basemodel for useful information + """ + + batch_size: int = -1 + + +@app.post("/register") +async def register(registration: Registration): + try: + isinstance(app.state.queue, list) + except AttributeError: + app.state.queue = [] + app.state.group = registration.wandb_group + app.state.project = registration.wandb_project + app.state.batchsize = int(registration.batch_size) + app.state.max_token_len = int(registration.max_token_len) + app.state.status_dict = {"step": registration.starting_step} + app.state.checkpoint_dir = registration.checkpoint_dir + app.state.save_checkpoint_interval = registration.save_checkpoint_interval + app.state.num_steps = registration.num_steps + app.state.curr_batch = [] + app.state.started = False + app.state.envs = [] + try: + app.state.requesters.append(uuid.uuid4().int) + except AttributeError: + # If requesters doesn't exist, create it + app.state.requesters = [uuid.uuid4().int] + return {"uuid": app.state.requesters[-1]} + + +@app.post("/register-env") +async def register_env_url(register_env: RegisterEnv): + try: + isinstance(app.state.envs, list) + except AttributeError: + app.state.envs = [] + checkpoint_dir = "" + try: + checkpoint_dir = app.state.checkpoint_dir + except AttributeError: + pass + real_name = ( + f"{register_env.desired_name}_" + f"{len([x for x in app.state.envs if x['desired_name'] == register_env.desired_name])}" + ) + registered_id = len(app.state.envs) + app.state.envs.append( + { + "max_context_len": register_env.max_token_length, + "weight": register_env.weight if register_env.weight is not None else 1.0, + "desired_name": register_env.desired_name, + "real_name": real_name, + "registered_id": registered_id, + "last_update": time.time(), + "connected": True, + } + ) + return { + "status": "success", + "env_id": registered_id, + "wandb_name": real_name, + "checkpoint_dir": checkpoint_dir, + "starting_step": app.state.status_dict["step"], + "checkpoint_interval": app.state.save_checkpoint_interval, + "num_steps": app.state.num_steps, + } + + +@app.post("/disconnect-env") +async def disconnect_env(disconnect_env: EnvIdentifier): + try: + app.state.envs[disconnect_env.env_id]["connected"] = False + return {"status": "success"} + except (AttributeError, IndexError) as e: + return {"status": "failure", "error": str(e)} + + +@app.get("/wandb_info") +async def wandb_info(): + try: + return {"group": app.state.group, "project": app.state.project} + except AttributeError: + return {"group": None, "project": None} + + +@app.get("/info") +async def info(): + try: + return { + "batch_size": app.state.batchsize, + "max_token_len": app.state.max_token_len, + } + except AttributeError: + return {"batch_size": -1, "max_token_len": -1} + + +@app.get("/batch") +async def get_batch(): + if not app.state.started: + app.state.started = True + + if len(app.state.curr_batch) > 0: + return {"batch": app.state.curr_batch.pop()} + else: + new_batches = [] + batch, app.state.queue = grab_exact_from_heterogeneous_queue( + app.state.queue, app.state.batchsize + ) + while batch is not None: + new_batches.append(batch) + batch, app.state.queue = grab_exact_from_heterogeneous_queue( + app.state.queue, app.state.batchsize + ) + steps_to_take = len(new_batches) + if steps_to_take == 0: + return {"batch": None} + app.state.status_dict["step"] += steps_to_take + # chunk it + for batch in new_batches: + app.state.curr_batch.append(batch) + curr_batch = app.state.curr_batch.pop() + # check length before sending + print(f"Sending batch of length {sum(len(x['tokens']) for x in curr_batch)}") + return {"batch": curr_batch} + + +@app.get("/latest_example") +async def get_latest_example(): + try: + return app.state.latest + except AttributeError: + return { + "tokens": [], + "masks": [], + "scores": [], + "ref_logprobs": [], + "images": [], + } + + +@app.post("/scored_data") +async def scored_data(scored_data: ScoredData): + app.state.queue.append( + { + "tokens": scored_data.tokens, + "masks": scored_data.masks, + "scores": scored_data.scores, + "ref_logprobs": scored_data.ref_logprobs, + "overrides": scored_data.overrides, + "group_overrides": scored_data.group_overrides, + "images": scored_data.images, + } + ) + app.state.latest = app.state.queue[-1] + return {"status": "received"} + + +@app.post("/scored_data_list") +async def scored_data_list(scored_data_list: List[ScoredData]): + """Handle a list of ScoredData objects for step-based learning""" + + for idx, scored_data in enumerate(scored_data_list): + + app.state.queue.append( + { + "tokens": scored_data.tokens, + "masks": scored_data.masks, + "scores": scored_data.scores, + "ref_logprobs": scored_data.ref_logprobs, + "images": scored_data.images, + } + ) + + if scored_data_list: + app.state.latest = app.state.queue[-1] + + return {"status": "received", "groups_processed": len(scored_data_list)} + + +@app.get("/status") +async def get_status(): + try: + return { + "current_step": app.state.status_dict["step"], + "queue_size": len(app.state.queue), + } + except AttributeError: + return {"current_step": 0, "queue_size": 0} + + +@app.get("/status-env") +async def get_status_env(env: EnvIdentifier): + total = sum( + [ + x["max_context_len"] * max(0.0, x["weight"]) + for x in app.state.envs + if x["connected"] + ] + ) + env_weight = ( + app.state.envs[env.env_id]["max_context_len"] + * app.state.envs[env.env_id]["weight"] + / total + ) + env_weight = max( + 0.01, env_weight + ) # Minimum weight of 0.01 :) TODO: try to figure out a better way to do this + + try: + ret_dict = { + "current_step": app.state.status_dict["step"], + "queue_size": len(app.state.queue), + } + except AttributeError: + ret_dict = {"current_step": 0, "queue_size": 0} + ret_dict["env_weight"] = env_weight + return ret_dict + + +@app.get("/reset_data") +async def reset_data(): + try: + del app.state.queue + app.state.group = None + app.state.project = None + app.state.batchsize = -1 + app.state.num_steps = -1 + app.state.status_dict = {"step": 0} + app.state.curr_batch = [] + app.state.started = False + app.state.requesters = [] + app.state.envs = [] + except KeyError: + pass + return PlainTextResponse("Reset successful", status_code=status.HTTP_200_OK) diff --git a/atroposlib/api/trainer_interaction.md b/atroposlib/api/trainer_interaction.md new file mode 100644 index 00000000..13779ab9 --- /dev/null +++ b/atroposlib/api/trainer_interaction.md @@ -0,0 +1,58 @@ +```mermaid +sequenceDiagram + participant R0 as Trainer Rank 0 + participant R1N as Trainer Rank 1..N-1 + participant API as AtroposLib API + + R0->>API: POST /register (send Registration data) + activate API + API-->>R0: Respond with {'uuid': trainer_uuid} + deactivate API + Note over R0, R1N: Initialization complete. Trainer begins requesting data + + loop Training Steps + %% --- Phase 2: Rank 0 fetches batch, others wait/poll --- + par Fetch vs Poll + loop While Batch is Null: + R0->>API: GET /batch + activate API + + Note over API: Checks queue, potentially increments step counter if batch is formed. + + alt Batch Available + API-->>R0: {'batch': [data_item_1, ...]} + Note over R0: Received batch for step S+1. Breaking loop. + else No Batch Available + API-->>R0: {'batch': null} + Note over R0: No batch ready yet. Will retry. + end + deactivate API + end + and + Note over R1N: Poll status until step increments from S. + loop While Server Step is S + R1N->>API: GET /status + activate API + API-->>R1N: {'current_step': S_new, 'queue_size': Q_new} + deactivate API + Note over R1N: Checking if S_new > S... (Current S_new = S_new) + %% In implementation, add delay here if S_new == S to avoid busy-wait + end + Note over R1N: Detected step incremented (S_new > S). Ready for broadcast. + end + + %% --- Phase 3: Handle result --- + Note over R0: Broadcasts received batch data to Ranks 1..N-1 (External Mechanism) + Note over R1N: Receives broadcasted data from Rank 0. + Note over R0, R1N: All ranks now have the same batch for step S+1. + + %% --- Phase 4: Perform Training Step --- + par Perform Training + R0->>R0: Perform training step with batch data + and + R1N->>R1N: Perform training step with batch data + end + Note over R0, R1N: Training step S+1 complete. + + end # End Training Steps Loop +``` diff --git a/atroposlib/api/utils.py b/atroposlib/api/utils.py new file mode 100644 index 00000000..c2fef67c --- /dev/null +++ b/atroposlib/api/utils.py @@ -0,0 +1,61 @@ +from typing import Dict, List, Optional, Tuple + + +def grab_exact_from_heterogeneous_queue( + queue: List[Dict[str, List]], batch_size: int +) -> Tuple[Optional[List], List]: + """ + Grabs a batch of size batchsize from a queue of different sized items + + e.g. queue = [{"tokens": [[1, 2, 3],[4, 5, 6, 7, 8]]}, {"tokens": [[9, 10]]}] + + without going over the batchsize. This function will return a batch of size batchsize, and the new queue. + + Because all groups are a common denominator of the batchsize, and all groups are a power of 2, + we can simplify a bit by assuming we can grab groups of groups to be equal to the maximum group size. + Note that we cannot drop items from groups, so we must grab the entire group if we grab it. + + There may be a more efficient clearing mechanism by grouping these smaller groups heterogeneously, but + forcing them all into powers of two groups is a simple way to ensure we can grab a batch of the correct size. + + :param queue: + :param batch_size: + :return: batch, new_queue + """ + # check if we can even potentially grab a batch + if sum(len(item["tokens"]) for item in queue) < batch_size: + return None, queue + # Get max batch size + max_group_size = max(len(group["tokens"]) for group in queue) + group_sizes = set(len(group["tokens"]) for group in queue) + group_batching_storage = {i: [] for i in group_sizes} + # pack the groups into [max_group_size // group_size] packs + potential_batch = [] + for i, item in enumerate(queue): + key = len(item["tokens"]) + group_batching_storage[key].append({"group": item, "indx": i}) + if len(group_batching_storage[key]) * key == max_group_size: + potential_batch.extend(group_batching_storage[key]) + group_batching_storage[key] = [] + if ( + sum(len(grouped_items["group"]["tokens"]) for grouped_items in potential_batch) + < batch_size + ): + return None, queue + # we have a batch + batch = [] + indxes_to_remove_from_queue = [] + for item in potential_batch: + group = item["group"] + indx = item["indx"] + batch.append(group) + indxes_to_remove_from_queue.append(indx) + if sum(len(item["tokens"]) for item in batch) == batch_size: + break + if sum(len(item["tokens"]) for item in batch) != batch_size: + return None, queue + # remove the items from the queue + new_queue = [ + item for i, item in enumerate(queue) if i not in indxes_to_remove_from_queue + ] + return batch, new_queue diff --git a/atroposlib/cli/dpo.py b/atroposlib/cli/dpo.py new file mode 100644 index 00000000..bd409c18 --- /dev/null +++ b/atroposlib/cli/dpo.py @@ -0,0 +1,322 @@ +import argparse +import asyncio +import os +import random + +import aiohttp +import jsonlines +from tqdm.asyncio import tqdm # Import tqdm for async +from transformers import AutoTokenizer + + +def find_common_prefix(strings): + """ + Finds the longest common prefix among a list of strings. + + Args: + strings: A list of strings. + + Returns: + The longest common prefix string, or an empty string if the list is empty + or no common prefix exists. + """ + if not strings: + return "" + + prefix = strings[0] + for s in strings[1:]: + while not s.startswith(prefix): + prefix = prefix[:-1] + if not prefix: + return "" + return prefix + + +async def register_to_api(group_size, max_token_len, api_url, num_steps): + """ + Registers this data grabber instance with the Atropos API. + + This involves resetting any previous data on the server and then sending + configuration parameters for the current session. + + Args: + group_size: The number of sequences processed per group by the API. + max_token_len: The maximum token length for sequences. + api_url: The base URL of the Atropos API server. + num_steps: The number of steps to run the API for. + """ + async with aiohttp.ClientSession() as session: + # Reset data on the API server before registering + async with session.get(f"{api_url}/reset_data") as response: + print(await response.text()) + # Register this instance with its configuration + async with session.post( + f"{api_url}/register", + json={ + "wandb_group": "test", + "wandb_project": "test", + "batch_size": group_size * 8, + "max_token_len": max_token_len, + "checkpoint_dir": "checkpoints", + "save_checkpoint_interval": 10, + "starting_step": 0, + "num_steps": num_steps * 2, # For a bit of a buffer just in case + }, + ) as response: + print("output of register is") + print(await response.text()) + + +async def check_for_batch(api_url): + """ + Continuously polls the Atropos API until a batch of data is available. + + Args: + api_url: The base URL of the Atropos API server. + + Returns: + The batch data received from the API. + """ + while True: + async with aiohttp.ClientSession() as session: + async with session.get(f"{api_url}/batch") as response: + data = await response.json() + if data["batch"] is not None: + return data["batch"] + await asyncio.sleep(1) # Wait before polling again + + +def grab_group_data( + tok, + datagroup, + save_messages, + save_n_pairs_per_group, + allow_negative_scores=False, + minimum_score_diff_max_min=0.0, +): + """ + Processes a single group of data received from the API. + + This function sorts the sequences within the group by score, filters them + based on scoring criteria, and formats them for saving. + + Args: + tok: The Hugging Face tokenizer instance. + datagroup: A dictionary representing a group of sequences and their scores. + save_messages: Boolean indicating whether to save raw message structures + or decoded text completions. + save_n_pairs_per_group: The maximum number of sequences to save from this group. + allow_negative_scores: Boolean indicating whether to allow sequences with + negative scores. + minimum_score_diff_max_min: The minimum score difference required to save a pair. + + Returns: + A list of processed and filtered sequences from the group, ready to be + written to the output file. + """ + if save_messages: + chats = datagroup["messages"] + else: + chats = [tok.decode(chat) for chat in datagroup["tokens"]] + # find common prefix + prefix = find_common_prefix(chats) + chats = [(prefix, chat.split(prefix)[1]) for chat in chats] + # sort chats by scores + scores = datagroup["scores"] + sorted_chats = [ + ( + {"prefix": x[0], "pos": x[1], "score": score} + if not save_messages + else {"pos": x, "score": score} + ) + for score, x in sorted( + zip(scores, chats), key=lambda pair: pair[0], reverse=True + ) + ] + neg_sorted_chats = [ + ( + {"prefix": x[0], "completion": x[1], "score": score} + if not save_messages + else {"messages": x, "score": score} + ) + for score, x in sorted( + zip(scores, chats), key=lambda pair: pair[0], reverse=False + ) + ] + neg_sorted_chats = neg_sorted_chats[:save_n_pairs_per_group] + if not allow_negative_scores: + sorted_chats = [x for x in sorted_chats if x["score"] > 0] + total_pairs = [] + for i in range(min(save_n_pairs_per_group, len(sorted_chats))): + neg_candidates = [ + x + for x in neg_sorted_chats + if x["score"] < sorted_chats[i]["score"] - minimum_score_diff_max_min + ] + if len(neg_candidates) > 0: + if save_n_pairs_per_group > 0: + neg_candidate = random.choice(neg_candidates) + else: + neg_candidate = neg_sorted_chats[0] # worst negative candidate + # remove from neg_sorted_chats + neg_sorted_chats.remove(neg_candidate) + sorted_chats[i]["neg"] = ( + neg_candidate["completion"] + if "completion" in neg_candidate + else neg_candidate["messages"] + ) + total_pairs.append(sorted_chats[i]) + return total_pairs + + +async def dpo_data_grabber( + filepath, + api_url, + group_size, + max_token_len, + tokenizer, + save_messages, + save_n_pairs_per_group, + num_seqs_to_save, + allow_negative_scores, + minimum_score_diff_max_min, + append_to_previous, +): + """ + Main asynchronous function to grab DPO data from the Atropos API. + + It registers with the API, continuously fetches batches of data, processes + each batch, and writes the selected sequences to a JSONL file until the + desired number of sequences is saved. + + Args: + filepath: Path to the output JSONL file. + api_url: Base URL of the Atropos API server. + group_size: Number of sequences processed per group by the API. + max_token_len: Maximum token length for sequences. + tokenizer: Hugging Face tokenizer model ID. + save_messages: Whether to save raw messages or decoded text. + save_n_pairs_per_group: Max sequences to save per group. + num_seqs_to_save: Total number of sequences to save. + allow_negative_scores: Whether to allow negative scores. + minimum_score_diff_max_min: Min score difference from group minimum. + append_to_previous: Whether to append to an existing file or overwrite. + """ + tok = AutoTokenizer.from_pretrained(tokenizer) + total_count = 0 + + async def grab_batch(jsonl_writer: jsonlines.Writer): + data = await check_for_batch(api_url) + count = 0 + for group in data: + for item in grab_group_data( + tok, + group, + save_messages, + save_n_pairs_per_group, + allow_negative_scores, + minimum_score_diff_max_min, + ): + jsonl_writer.write(item) + count += 1 + return count + + await register_to_api(group_size, max_token_len, api_url) + if os.path.exists(filepath) and not append_to_previous: + raise ValueError("File already exists and append_to_previous is False.") + with open(filepath, "w" if not append_to_previous else "a") as f: + jsonl_writer = jsonlines.Writer(f) + with tqdm(total=num_seqs_to_save, desc="Grabbing DPO data", unit="seq") as pbar: + while total_count < num_seqs_to_save: + batch_count = await grab_batch(jsonl_writer) + total_count += batch_count + pbar.update(min(batch_count, num_seqs_to_save - total_count)) + + +def main(): + parser = argparse.ArgumentParser( + description="Grab SFT data from an Atropos API instance." + ) + parser.add_argument( + "filepath", + type=str, + default="sft_data.jsonl", + help="Path to the output JSONL file for SFT data.", + ) + parser.add_argument( + "--api-url", + type=str, + default="http://localhost:8000", + help="Base URL for the Atropos API server.", + ) + parser.add_argument( + "--group-size", + type=int, + default=2, + help="Number of sequences processed per group by the API.", + ) + parser.add_argument( + "--max-token-len", + type=int, + default=2048, + help="Maximum token length for sequences.", + ) + parser.add_argument( + "--tokenizer", + type=str, + default="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + help="Hugging Face tokenizer model ID (used if --save-messages is not set).", + ) + parser.add_argument( + "--save-messages", + action="store_true", + help="Save raw message structures instead of decoded text completions, if your environment supports it.", + ) + parser.add_argument( + "--save-n-pairs-per-group", + type=int, + default=3, + help="Maximum number of paired sequences to save from each group.", + ) + parser.add_argument( + "--num-seqs-to-save", + type=int, + default=100, + help="Total number of sequences to save before stopping.", + ) + parser.add_argument( + "--allow-negative-scores", + action="store_true", + help="Allow sequences with negative scores to be saved.", + ) + parser.add_argument( + "--minimum-score-diff-max-min", + type=float, + default=0.5, + help="Minimum score difference from the group minimum required to save a sequence.", + ) + parser.add_argument( + "--append-to-previous", + action="store_true", + help="Append to the previous file instead of overwriting it.", + ) + args = parser.parse_args() + asyncio.run( + dpo_data_grabber( + args.filepath, + args.api_url, + args.group_size, + args.max_token_len, + args.tokenizer, + args.save_messages, + args.save_n_pairs_per_group, + args.num_seqs_to_save, + args.allow_negative_scores, + args.minimum_score_diff_max_min, + args.append_to_previous, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/atroposlib/cli/inference_node_wandb_watcher.py b/atroposlib/cli/inference_node_wandb_watcher.py new file mode 100644 index 00000000..2ef77520 --- /dev/null +++ b/atroposlib/cli/inference_node_wandb_watcher.py @@ -0,0 +1,66 @@ +import argparse +import time + +import requests + +import wandb + + +def update_wandb(health_statuses): + wandb.log(health_statuses) + + +def run(api_addr, tp, node_num): + print(f"Starting up with {api_addr}, {tp}, {node_num}", flush=True) + while True: + try: + data = requests.get(f"{api_addr}/wandb_info").json() + wandb_group = data["group"] + wandb_project = data["project"] + except (requests.exceptions.ConnectionError, requests.exceptions.Timeout): + wandb_project = None + wandb_group = None + print("Waiting for init...") + + if wandb_project is None: + time.sleep(1) + else: + wandb.init( + project=wandb_project, group=wandb_group, name=f"inf_node_{node_num}" + ) + break + curr_step = 0 + health_statuses = { + f"server/server_heath_{node_num}_{i}": 0.0 for i in range(8 // tp) + } + while True: + data = requests.get(f"{api_addr}/status").json() + step = data["current_step"] + if step > curr_step: + wandb.log(health_statuses, step=step) + curr_step = step + time.sleep(60) + # Check on each server + for i in range(8 // tp): + try: + health_status = requests.get( + f"http://localhost:{9000 + i}/health_generate" + ).status_code + health_statuses[f"server/server_heath_{node_num}_{i}"] = ( + 1 if health_status == 200 else 0 + ) + except (requests.exceptions.ConnectionError, requests.exceptions.Timeout): + health_statuses[f"server/server_heath_{node_num}_{i}"] = 0 + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--api_addr", type=str, required=True) + parser.add_argument("--tp", type=int, required=True) + parser.add_argument("--node_num", type=int, required=True) + args = parser.parse_args() + run(args.api_addr, args.tp, args.node_num) + + +if __name__ == "__main__": + main() diff --git a/atroposlib/cli/run_api.py b/atroposlib/cli/run_api.py new file mode 100644 index 00000000..5a04aa07 --- /dev/null +++ b/atroposlib/cli/run_api.py @@ -0,0 +1,9 @@ +import uvicorn + + +def main(): + uvicorn.run("atroposlib.api:app", host="0.0.0.0", port=8000, reload=True) + + +if __name__ == "__main__": + main() diff --git a/atroposlib/cli/sft.py b/atroposlib/cli/sft.py new file mode 100644 index 00000000..b58badb1 --- /dev/null +++ b/atroposlib/cli/sft.py @@ -0,0 +1,318 @@ +import argparse +import asyncio +import os + +import aiohttp +import jsonlines +from tqdm.asyncio import tqdm # Import tqdm for async +from transformers import AutoTokenizer + + +def find_common_prefix(strings): + """ + Finds the longest common prefix among a list of strings. + + Args: + strings: A list of strings. + + Returns: + The longest common prefix string, or an empty string if the list is empty + or no common prefix exists. + """ + if not strings: + return "" + + prefix = strings[0] + for s in strings[1:]: + while not s.startswith(prefix): + prefix = prefix[:-1] + if not prefix: + return "" + return prefix + + +async def register_to_api(group_size, max_token_len, api_url, num_steps): + """ + Registers this data grabber instance with the Atropos API. + + This involves resetting any previous data on the server and then sending + configuration parameters for the current session. + + Args: + group_size: The number of sequences processed per group by the API. + max_token_len: The maximum token length for sequences. + api_url: The base URL of the Atropos API server. + num_steps: The number of steps to run the API for. + """ + async with aiohttp.ClientSession() as session: + # Reset data on the API server before registering + async with session.get(f"{api_url}/reset_data") as response: + print(await response.text()) + # Register this instance with its configuration + async with session.post( + f"{api_url}/register", + json={ + "wandb_group": "test", + "wandb_project": "test", + "batch_size": group_size * 8, + "max_token_len": max_token_len, + "checkpoint_dir": "checkpoints", + "save_checkpoint_interval": 10, + "starting_step": 0, + "num_steps": num_steps * 2, # For a bit of a buffer just in case + }, + ) as response: + print("output of register is") + print(await response.text()) + + +async def check_for_batch(api_url): + """ + Continuously polls the Atropos API until a batch of data is available. + + Args: + api_url: The base URL of the Atropos API server. + + Returns: + The batch data received from the API. + """ + while True: + async with aiohttp.ClientSession() as session: + async with session.get(f"{api_url}/batch") as response: + data = await response.json() + if data["batch"] is not None: + return data["batch"] + await asyncio.sleep(1) # Wait before polling again + + +def grab_group_data( + tok, + datagroup, + save_messages, + save_top_n_per_group, + allow_negative_scores=False, + minimum_score_diff_max_min=0.0, +): + """ + Processes a single group of data received from the API. + + This function sorts the sequences within the group by score, filters them + based on scoring criteria, and formats them for saving. + + Args: + tok: The Hugging Face tokenizer instance. + datagroup: A dictionary representing a group of sequences and their scores. + save_messages: Boolean indicating whether to save raw message structures + or decoded text completions. + save_top_n_per_group: The maximum number of sequences to save from this group. + allow_negative_scores: Boolean indicating whether to allow sequences with + negative scores. + minimum_score_diff_max_min: The minimum score difference from the group's + minimum score required to save a sequence. + + Returns: + A list of processed and filtered sequences from the group, ready to be + written to the output file. + """ + if save_messages: + # Use raw message structures if specified + chats = datagroup["messages"] + else: + # Decode tokens into text and find common prefix/completion pairs + chats = [tok.decode(chat) for chat in datagroup["tokens"]] + prefix = find_common_prefix(chats) + # Split each chat into (prefix, completion) + chats = [(prefix, chat.split(prefix)[1]) for chat in chats] + scores = datagroup["scores"] + # Sort chats by score in descending order + sorted_chats = [ + ( + {"prefix": x[0], "completion": x[1], "score": score} + if not save_messages + else {"messages": x, "score": score} + ) + for score, x in sorted( + zip(scores, chats), key=lambda pair: pair[0], reverse=True + ) + ] + + # Apply filtering based on score criteria + if not allow_negative_scores: + sorted_chats = [x for x in sorted_chats if x["score"] > 0] + if minimum_score_diff_max_min > 0: + # Ensure the score is sufficiently higher than the minimum score in the group + min_score = min(scores) if scores else 0 # Handle empty scores list + sorted_chats = [ + x + for x in sorted_chats + if x["score"] - min_score > minimum_score_diff_max_min + ] + + # Return only the top N sequences + return sorted_chats[:save_top_n_per_group] + + +async def sft_data_grabber( + filepath, + api_url, + group_size, + max_token_len, + tokenizer, + save_messages, + save_top_n_per_group, + num_seqs_to_save, + allow_negative_scores, + minimum_score_diff_max_min, + append_to_previous, +): + """ + Main asynchronous function to grab SFT data from the Atropos API. + + It registers with the API, continuously fetches batches of data, processes + each batch, and writes the selected sequences to a JSONL file until the + desired number of sequences is saved. + + Args: + filepath: Path to the output JSONL file. + api_url: Base URL of the Atropos API server. + group_size: Number of sequences processed per group by the API. + max_token_len: Maximum token length for sequences. + tokenizer: Hugging Face tokenizer model ID. + save_messages: Whether to save raw messages or decoded text. + save_top_n_per_group: Max sequences to save per group. + num_seqs_to_save: Total number of sequences to save. + allow_negative_scores: Whether to allow negative scores. + minimum_score_diff_max_min: Min score difference from group minimum. + append_to_previous: Whether to append to an existing file or overwrite. + """ + tok = AutoTokenizer.from_pretrained(tokenizer) + total_count = 0 + + async def grab_batch(jsonl_writer: jsonlines.Writer): + """Fetches and processes one batch of data, returning the count.""" + data = await check_for_batch(api_url) + count = 0 + for group in data: + for item in grab_group_data( + tok, + group, + save_messages, + save_top_n_per_group, + allow_negative_scores, + minimum_score_diff_max_min, + ): + jsonl_writer.write(item) + count += 1 + return count + + # Register with the API first + await register_to_api(group_size, max_token_len, api_url, num_steps=total_count) + + # Check for file existence before opening + if os.path.exists(filepath) and not append_to_previous: + raise ValueError( + f"File '{filepath}' already exists and --append-to-previous is False." + ) + + # Open the file in write or append mode + file_mode = "a" if append_to_previous and os.path.exists(filepath) else "w" + with open(filepath, file_mode) as f: + jsonl_writer = jsonlines.Writer(f) + # Use tqdm for progress bar + with tqdm(total=num_seqs_to_save, desc="Grabbing SFT data", unit="seq") as pbar: + while total_count < num_seqs_to_save: + batch_count = await grab_batch(jsonl_writer) + total_count += batch_count + pbar.update(min(batch_count, num_seqs_to_save - total_count)) + + +def main(): + """Parses command-line arguments and runs the SFT data grabber.""" + parser = argparse.ArgumentParser( + description="Grab SFT data from an Atropos API instance." + ) + parser.add_argument( + "filepath", + type=str, + default="sft_data.jsonl", + help="Path to the output JSONL file for SFT data.", + ) + parser.add_argument( + "--api-url", + type=str, + default="http://localhost:8000", + help="Base URL for the Atropos API server.", + ) + parser.add_argument( + "--group-size", + type=int, + default=2, + help="Number of sequences processed per group by the API.", + ) + parser.add_argument( + "--max-token-len", + type=int, + default=2048, + help="Maximum token length for sequences.", + ) + parser.add_argument( + "--tokenizer", + type=str, + default="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + help="Hugging Face tokenizer model ID (used if --save-messages is not set).", + ) + parser.add_argument( + "--save-messages", + action="store_true", + help="Save raw message structures instead of decoded text completions, if your environment supports it.", + ) + parser.add_argument( + "--save-top-n-per-group", + type=int, + default=3, + help="Maximum number of highest-scoring sequences to save from each group.", + ) + parser.add_argument( + "--num-seqs-to-save", + type=int, + default=100, + help="Total number of sequences to save before stopping.", + ) + parser.add_argument( + "--allow-negative-scores", + action="store_true", + help="Allow sequences with negative scores to be saved.", + ) + parser.add_argument( + "--minimum-score-diff-max-min", + type=float, + default=0.0, + help="Minimum score difference from the group minimum required to save a sequence.", + ) + parser.add_argument( + "--append-to-previous", + action="store_true", + help="Append to the previous file instead of overwriting it.", + ) + args = parser.parse_args() + + # Run the main async function + asyncio.run( + sft_data_grabber( + args.filepath, + args.api_url, + args.group_size, + args.max_token_len, + args.tokenizer, + args.save_messages, + args.save_top_n_per_group, + args.num_seqs_to_save, + args.allow_negative_scores, + args.minimum_score_diff_max_min, + args.append_to_previous, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/atroposlib/cli/view_run.py b/atroposlib/cli/view_run.py new file mode 100644 index 00000000..42462355 --- /dev/null +++ b/atroposlib/cli/view_run.py @@ -0,0 +1,105 @@ +import argparse +import asyncio + +import aiohttp +import gradio as gr +from transformers import AutoTokenizer + + +def find_common_prefix(strings): + if not strings: + return "" + + prefix = strings[0] + for s in strings[1:]: + while not s.startswith(prefix): + prefix = prefix[:-1] + if not prefix: + return "" + return prefix + + +async def register_to_api(group_size, max_token_len): + async with aiohttp.ClientSession() as session: + async with session.get("http://localhost:8000/reset_data") as response: + print(await response.text()) + print(group_size) + async with session.post( + "http://localhost:8000/register", + json={ + "wandb_group": "test", + "wandb_project": "test", + "batch_size": group_size + * 8, # * 8 just in case you want to just sample from a large group + "max_token_len": max_token_len, + "checkpoint_dir": "checkpoints", + "save_checkpoint_interval": 10, + "starting_step": 0, + "num_steps": 69, + }, + ) as response: + print("output of register is") + print(await response.text()) + + +async def check_for_batch(): + while True: + async with aiohttp.ClientSession() as session: + async with session.get("http://localhost:8000/batch") as response: + data = await response.json() + print(data) + if data["batch"] is not None: + return data["batch"] + await asyncio.sleep(1) + + +async def build_interface(group_size, max_token_len, tokenizer, port): + async def grab_batch(): + tok = AutoTokenizer.from_pretrained(tokenizer) + data = await check_for_batch() + print(data) + chats = [tok.decode(chat) for chat in data[0]["tokens"]] + + # find common prefix + prefix = find_common_prefix(chats) + return ( + (prefix,) + + tuple([chat.split(prefix)[1] for chat in chats[:group_size]]) + + tuple(data[0]["scores"][:group_size]) + ) + + with gr.Blocks() as demo: + prefix_blk = gr.Textbox(label="Prefix") + with gr.Row(): + score_blks = [gr.Textbox(label=f"Score_{i+1}") for i in range(group_size)] + with gr.Row(): + outputs_blks = [ + gr.Textbox(label=f"Output_{i+1}") for i in range(group_size) + ] + with gr.Row(): + grab_next = gr.Button(value="Grab Next Batch") + grab_next.click( + fn=grab_batch, + outputs=[prefix_blk] + outputs_blks + score_blks, + api_name="get_batch", + ) + await register_to_api(group_size, max_token_len) + demo.launch(server_port=port, share=True) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--port", type=int, default=9001) + parser.add_argument("--group-size", type=int, default=2) + parser.add_argument("--max-token-len", type=int, default=2048) + parser.add_argument( + "--tokenizer", type=str, default="NousResearch/DeepHermes-3-Llama-3-8B-Preview" + ) + args = parser.parse_args() + asyncio.run( + build_interface(args.group_size, args.max_token_len, args.tokenizer, args.port) + ) + + +if __name__ == "__main__": + main() diff --git a/atroposlib/envs/README.md b/atroposlib/envs/README.md new file mode 100644 index 00000000..1d4d0c68 --- /dev/null +++ b/atroposlib/envs/README.md @@ -0,0 +1,63 @@ +# Base Environment (`BaseEnv`) + +The `BaseEnv` class (located in `trajectoryhandler/envs/base.py`) provides a foundation for creating custom reinforcement learning environments that interact with Atropos. When creating your own environment, you will typically subclass `BaseEnv` and implement several key methods. + +## Core Methods to Implement + +These methods **must** be implemented in your subclass: + +* **`async def setup(self)`**: This method is called once at the beginning of the environment's lifecycle (`env_manager`). Use it for any initial setup required for your specific environment, such as loading datasets, initializing models, or connecting to external resources. + +* **`async def get_next_item(self) -> Item`**: This method is responsible for generating or retrieving the next piece of data (prompt, state, etc.) that will be used to start a new trajectory collection. If no more items are available or should be generated, it can return `None` to signal the worker to pause. + +* **`async def collect_trajectory(self, item: Item) -> Tuple[Any | None, List[Item]]`**: This method defines the logic for a *single* logical trajectory collection step based on the input `item`. \ + * **How it relates to multiple generations**: The `BaseEnv` uses `collect_trajectories` to run this method multiple times in parallel (controlled by `group_size`) to gather a batch of trajectories. \ + * **Your implementation**: You can implement this method to generate *one* response/trajectory per call.\ + * **Return value**: It returns a tuple containing:\ + 1. The collected data for this step (one trajectory). This data can be processed further in `postprocess_histories`, if you require additional filtering right before sending to the API.\ + 2. A list of new `Item` objects to be added to the backlog for future processing (e.g., follow-up prompts).\ + +* **`async def evaluate(self, *args, **kwargs)`**: This method is called periodically (controlled by `steps_per_eval` in the config) to perform evaluation runs. You define the evaluation logic here. The base class provides an example using `self.eval_workers` for parallel evaluation tasks, but you can implement any evaluation procedure suitable for your environment. + +## Optional Methods to Override + +These methods have default implementations or are optional based on your needs: + +* **`async def collect_trajectories(self, item: Item) -> Tuple[Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None]], List[Item]]`**: The default implementation of this method runs `collect_trajectory` multiple times in parallel (controlled by `group_size`). You can override this *instead* of `collect_trajectory` if you have a more efficient way to generate the entire group of responses/trajectories at once based on the input `item`. It should return the collected group data and a list of backlog items. + +* **`async def postprocess_histories(self, trajectories: Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]`**: This method is called after `collect_trajectories` and before the data is sent to the training server. It receives the collected data from the parallel runs (or your custom `collect_trajectories` implementation). Use this to perform final processing, scoring, or formatting you may require before sending to the server. You usually won't need this. + +* **`async def wandb_log(self, wandb_metrics: Optional[Dict] = None)`**: Called periodically to log metrics to Weights & Biases. If you override this to add custom metrics, **ensure you call `super().wandb_log(wandb_metrics)`** at the end of your implementation. This ensures that the base class's performance metrics and rollout tables are also logged. + ```python + async def wandb_log(self, wandb_metrics: Optional[Dict] = None): + if wandb_metrics is None: + wandb_metrics = {} + # Add your custom metrics + wandb_metrics['my_custom_metric'] = calculate_my_metric() + # ... add more metrics + + # Call the parent method to log base metrics + await super().wandb_log(wandb_metrics) + ``` + +* **`save_checkpoint(self, step, data=None)`**: The base class calls this method automatically at checkpoint intervals determined by the server. It saves the provided `data` dictionary (which you might populate with environment-specific state) to a JSON file. You can override this to customize *what* data is saved or *how* it's saved (e.g., using a different format or location), but the triggering mechanism remains automatic. + +* **`@classmethod config_init(cls) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[OpenaiConfig]]]`**: This class method is used by the default `get_cli_serve_config_cls` implementation to get the initial environment configuration (`BaseEnvConfig` subclass) and server configurations (`ServerBaseline` or `List[OpenaiConfig]`) when setting up the `serve` command. The default implementation returns `cls.env_config_cls(), ServerBaseline()`. You might override this if your environment requires different default configurations or specific server setups (like multiple `OpenaiConfig` instances) when run via the CLI `serve` command. + +* **`async def cleanup(self)`**: Called after each call to `handle_env`. You can implement this for any cleanup needed after processing a single item, though it's often not required. + +## Provided Functionality + +`BaseEnv` provides several helpful features: + +* **Parallel Trajectory Collection (`collect_trajectories`)**: The base implementation runs your `collect_trajectory` method multiple times in parallel (based on `group_size`) and gathers the results. You can override `collect_trajectories` directly for custom group generation logic (see Optional Methods). +* **Server Interaction**: Handles registration with the rollout server, fetching configuration (like `batch_size`), sending scored data (`handle_send_to_api` with retries), and status updates. +* **WandB Integration**: Sets up WandB logging (if enabled) based on server information and provides the `wandb_log` hook for custom metrics (remember to call `super().wandb_log()`). It uses helper methods `add_rollouts_for_wandb` (to temporarily store rollout data) and `create_rollout_table` (to format the data into a `wandb.Table`). You can override either of these helpers for custom logging behavior (e.g., changing what data is stored or how the final table is structured). +* **Checkpointing**: + * The environment automatically triggers checkpoint saves based on the `checkpoint_interval` received from the server, calling the `save_checkpoint` method (see Optional Methods). + * `load_checkpoint(self)`: Loads data from the checkpoint file corresponding to the environment's `curr_step`. It attempts to restore attributes of the environment object based on the keys in the loaded JSON data. This is called automatically if `curr_step > 0` during registration. +* **Worker Management**: Manages asynchronous worker tasks for collecting trajectories (`add_train_workers`, `handle_env`). +* **Performance Monitoring**: Tracks and logs various performance statistics (task durations, worker counts, etc.). +* **CLI Integration**: Provides a `cli()` class method using `pydantic-cli` to easily create command-line interfaces for your environment (e.g., `python your_env_module.py serve --port 8001 ...`). See `get_cli_serve_config_cls` and `get_cli_process_config_cls`. + +By implementing the required methods and optionally overriding others, you can create diverse environments that leverage the distributed training infrastructure provided by the `Atropos` framework. diff --git a/atroposlib/envs/__init__.py b/atroposlib/envs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py new file mode 100644 index 00000000..30b32895 --- /dev/null +++ b/atroposlib/envs/base.py @@ -0,0 +1,890 @@ +import asyncio +import json +import logging +import os +import random +import sys +import time +import uuid +import warnings +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union + +import aiohttp +import jsonlines +import numpy as np +from pydantic import BaseModel, Field +from pydantic_cli import Cmd, FailedExecutionException, run_and_exit +from tenacity import retry, stop_after_attempt, wait_random_exponential +from transformers import AutoTokenizer + +import wandb +from atroposlib.type_definitions import UUID +from atroposlib.utils.metrics import get_std_min_max_avg + +from ..type_definitions import Item, Message +from .server_handling.server_manager import ( + OpenaiConfig, + ServerBaseline, + ServerManager, + ServerManagerConfig, +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +class ScoredDataGroup(TypedDict): + tokens: List[List[int]] + masks: List[List[int]] + scores: List[float] + advantages: Optional[List[List[float]]] + ref_logprobs: Optional[List[List[float]]] + messages: Optional[List[List[Message]]] + group_overrides: Optional[Dict] + overrides: Optional[List[Dict]] + + +class EvalHandlingEnum(Enum): + """ + Enum for handling evals. + """ + + STOP_TRAIN = "STOP_TRAIN" + LIMIT_TRAIN = "LIMIT_TRAIN" + NONE = "NONE" + + +class BaseEnvConfig(BaseModel): + """ + Basic env configuration. + """ + + group_size: int = Field( + default=4, description="How many responses are grouped together for scoring" + ) + max_num_workers: int = Field( + default=-1, + description="Maximum number of workers to use, -1 calculates from max_num_workers_per_node", + ) + max_eval_workers: int = Field( + default=16, description="Maximum number of workers to use for evaluation" + ) + max_num_workers_per_node: int = Field( + default=8, description="Maximum number of workers to use per node" + ) + steps_per_eval: int = Field( + default=100, description="Number of steps to take before evaluating" + ) + max_token_length: int = Field( + default=2048, description="Maximum token length used in generations" + ) + eval_handling: EvalHandlingEnum = Field( + default=EvalHandlingEnum.STOP_TRAIN, description="How to handle evaluations" + ) + eval_limit_ratio: float = Field( + default=0.5, description="Ratio of training workers to limit during evals" + ) + inference_weight: float = Field( + default=1.0, + description="Inference weight, set to -1 to ignore it if you're doing something special here.", + ) + batch_size: int = Field( + default=-1, + description="Batch size for training, will be set by the trainer and passed in via the fastapi interface, if applicable", # noqa: E501 + ) + max_batches_offpolicy: int = Field( + default=3, description="Maximum number of batches to have in queue." + ) + tokenizer_name: str = Field( + default="NousResearch/DeepHermes-3-Llama-3-1B-Preview", + description="Hugging Face tokenzer to use.", + ) + use_wandb: bool = Field(default=True, description="Whether to use wandb") + rollout_server_url: str = Field( + default="http://localhost:8000", description="URL of the rollout server" + ) + total_steps: int = Field(default=1000, description="Total number of steps to run") + wandb_name: str | None = Field( + default=None, + description="Name to be grouped by in wandb", + ) + num_rollouts_to_keep: int = Field( + default=32, description="Number of rollouts to display on wandb" + ) + num_rollouts_per_group_for_logging: int = Field( + default=1, + description="Number of rollouts per group to keep for logging. If -1, keep all rollouts", + ) + ensure_scores_are_not_same: bool = Field( + default=True, + description="Ensure that the scores are not the same, should usually be True", + ) + data_path_to_save_groups: Optional[str] = Field( + default=None, + description="Path to save the groups, if set, will write groups to this jsonl", + ) + min_items_sent_before_logging: int = Field( + default=2, + description="Minimum number of items sent before logging, if 0 or less, logs every time", + ) + + +class BaseEnv(ABC): + + name = None + env_config_cls = BaseEnvConfig + + def __init__( + self, + config: BaseEnvConfig, + server_configs: Union[ServerBaseline, List[OpenaiConfig]], + slurm=True, + testing=False, + ): + self.items_sent_this_step = 0 + self.eval_runner = None # type: Optional[asyncio.Task] + self.workers_added_list = list() + self.succeeded_task_duration = list() + self.failed_task_duration = list() + self.task_duration = list() + self.mainloop_timings = list() + self.task_successful = list() + self.last_loop_time = None + self.last_completed_item = None + self.config = config + self.server = ServerManager(server_configs, slurm=slurm, testing=testing) + self.workers = set() + self.eval_workers = set() + self.backlog = [] + self.rollouts_for_wandb = [] + self.running_items: dict[UUID, Item] = dict() + self.wandb_project = None + self.wandb_group = None + self.curr_step = 0 + self.max_token_len = -1 + self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) + self.completion_lengths = [] + self.max_num_workers = config.max_num_workers + if self.max_num_workers == -1: + self.max_num_workers = config.max_num_workers_per_node * len( + self.server.servers + ) + self.wandb_prepend = None + self.checkpoint_dir = "" + self.checkpoint_interval = -1 + if self.config.data_path_to_save_groups is not None: + if os.path.exists(self.config.data_path_to_save_groups): + raise FileExistsError( + "Data path already exists! Please remove it or change it." + ) + self.jsonl_writer = jsonlines.open( + self.config.data_path_to_save_groups, "w" + ) # type: jsonlines.Writer + else: + self.jsonl_writer = None + + @classmethod + def config_init( + cls, + ) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[OpenaiConfig]]]: + """ + Initialize the config + """ + return cls.env_config_cls(), ServerBaseline() + + async def collect_trajectory(self, item: Item) -> Tuple[Any | None, List[Item]]: + raise NotImplementedError( + "Handle env single method must be implemented in subclass " + ) + + async def collect_trajectories(self, item: Item) -> Tuple[ + Union[ + Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None] + ], + List[Item], + ]: + """ + + :param item: + :return: + """ + tasks = [] + for _ in range(self.config.group_size): + tasks.append(self.collect_trajectory(item)) + results = await asyncio.gather(*tasks) + backlog = [] + to_postprocess = [] + for result in results: + if result[0] is not None: + to_postprocess.append(result[0]) + backlog.extend(result[1]) + random.shuffle(backlog) + return to_postprocess, backlog + + async def postprocess_histories( + self, + trajectories: Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]], + ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]: + """ + Postprocess the histories, this is called after the collect_trajectories method + + If you don't need to do anything to the trajectories, you may safely ignore this. + + :param trajectories: + :return: + """ + return trajectories + + @abstractmethod + async def get_next_item(self) -> Item: + """ + Get the next items to be rolled out + """ + raise NotImplementedError( + "Get_next_items method must be implemented in subclass " + ) + + @abstractmethod + async def evaluate(self, *args, **kwargs): + """ + Evaluate the environment, this is called every steps_per_eval steps + + Included here is an example on how to use eval workers to run a task. + + You may however do whatever you want in this method. + + :param args: + :param kwargs: + :return: None. + """ + for data in ["my", "eval", "data"]: + while len(self.eval_workers) >= self.config.max_eval_workers: + await asyncio.sleep(0.1) + worker = asyncio.create_task(asyncio.sleep(0.1)) + self.eval_workers.add(worker) + worker.add_done_callback(self.eval_workers.discard) + raise NotImplementedError("Evaluate method must be implemented in subclass ") + + def load_checkpoint(self): + # check if file exists... + ckpt_path = os.path.join( + self.checkpoint_dir, + "env_checkpoints", + self.wandb_prepend, + f"step-{self.curr_step}.json", + ) + if os.path.exists(ckpt_path): + with open(ckpt_path, "r") as f: + data = json.load(f) + # now load the data + for key in data: + setattr(self, key, data[key]) + + def save_checkpoint(self, step, data=None): + print(f"Saving checkpoint at step {step} with data {data}") + if data is None: + # Don't have anything to save, abort + return + # check if file exists... + ckpt_dir = os.path.join( + self.checkpoint_dir, "env_checkpoints", self.wandb_prepend + ) + # create directory if necessary + os.makedirs(ckpt_dir, exist_ok=True) + ckpt_path = os.path.join( + self.checkpoint_dir, + "env_checkpoints", + self.wandb_prepend, + f"step-{step}.json", + ) + os.makedirs(os.path.dirname(ckpt_path), exist_ok=True) + with open(ckpt_path, "w") as f: + json.dump(data, f) + + async def setup(self): + """Setup the environment""" + raise NotImplementedError("Setup method must be implemented in subclass") + + async def setup_wandb(self): + if self.config.use_wandb: + # Setup wandb getting the group and project via the server + while self.wandb_project is None: + async with aiohttp.ClientSession() as session: + async with session.get( + f"{self.config.rollout_server_url}/wandb_info" + ) as resp: + data = await resp.json() + self.wandb_group = data["group"] + self.wandb_project = data["project"] + if self.wandb_project is None: + await asyncio.sleep(1) + else: + wandb.init( + project=self.wandb_project, + group=self.wandb_group, + config=self.config.model_dump(), + ) + break + + async def register_env(self): + # Now register the env... + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.config.rollout_server_url}/register-env", + json={ + "max_token_length": self.config.max_token_length, + "desired_name": self.config.wandb_name, + "weight": self.config.inference_weight, + }, + ) as resp: + data = await resp.json() + self.env_id = data["env_id"] + self.wandb_prepend = data["wandb_name"] + self.curr_step = data["starting_step"] + self.checkpoint_dir = data["checkpoint_dir"] + self.checkpoint_interval = data["checkpoint_interval"] + if self.config.total_steps == -1: + self.config.total_steps = data["num_steps"] + if self.config.total_steps == -1: + raise ValueError("Total steps not set in config or server!") + print( + f"Initialized env with id {self.env_id}: " + f"curr_step: {self.curr_step}, " + f"checkpoint_dir: {self.checkpoint_dir}, " + f"checkpoint_interval: {self.checkpoint_interval}" + ) + if self.curr_step > 0: + self.load_checkpoint() + + async def get_server_info(self): + """ + Get the server info + """ + async with aiohttp.ClientSession() as session: + async with session.get(f"{self.config.rollout_server_url}/info") as resp: + data = await resp.json() + if data["batch_size"] != -1: + # update the batch size + self.config.batch_size = data["batch_size"] + if data["max_token_len"] != -1: + self.max_token_len = data["max_token_len"] + if self.config.batch_size == -1: + logging.warning("Batch size not set by config or server!") + if self.config.group_size > self.config.batch_size: + raise ValueError( + f"group_size ({self.config.group_size}) " + f"must be less than batch_size ({self.config.batch_size})" + ) + + def perf_stats(self, metrics_dict): + """ + returns wandb metrics for performance + """ + if len(self.task_duration) > 1: + get_std_min_max_avg( + "train_perf/task_duration", self.task_duration, metrics_dict + ) + self.task_duration = list() + if len(self.succeeded_task_duration) > 1: + get_std_min_max_avg( + "train_perf/succeeded_task_duration", + self.succeeded_task_duration, + metrics_dict, + ) + metrics_dict["train/items_sent_to_api"] = len(self.succeeded_task_duration) + self.succeeded_task_duration = list() + if len(self.failed_task_duration) > 1: + get_std_min_max_avg( + "train_perf/failed_task_duration", + self.failed_task_duration, + metrics_dict, + ) + metrics_dict["train/items_rejected"] = len(self.failed_task_duration) + self.failed_task_duration = list() + if len(self.mainloop_timings) > 1: + get_std_min_max_avg( + "train_perf/mainloop_timings", + self.mainloop_timings, + metrics_dict, + ) + self.mainloop_timings = list() + if len(self.workers_added_list) > 1: + get_std_min_max_avg( + "train_perf/workers_added_per_attempt", + self.workers_added_list, + metrics_dict, + ) + self.workers_added_list = list() + return metrics_dict + + async def create_rollout_table(self, wandb_metrics): + if len(self.rollouts_for_wandb) > 0: + table = wandb.Table(columns=["text", "score"]) + for group in self.rollouts_for_wandb: + for item in group: + table.add_data(item[0], item[1]) + wandb_metrics["train/rollouts"] = table + return wandb_metrics + + async def add_rollouts_for_wandb( + self, + scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]], + item: Item = None, + ): + # Save rollout to trajectory + num_keep = self.config.num_rollouts_per_group_for_logging + if num_keep == -1: + num_keep = self.config.group_size + self.rollouts_for_wandb.append( + [ + ( + self.tokenizer.decode(scored_data["tokens"][i]), + scored_data["scores"][i], + ) + for i in range(num_keep) + ] + ) + if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep: + self.rollouts_for_wandb.pop(0) + + async def wandb_log(self, wandb_metrics: Optional[Dict] = None): + """ + Log to wandb. + + To use this in your subclass, please ensure this is called after you do your metrics + e.g. + def wandb_log(self, wandb_metrics: Optional[Dict] = None): + wandb_metrics = {} + wandb_metrics['my_metric'] = 0.5 + super().wandb_log(wandb_metrics) + """ + if wandb_metrics is None: + wandb_metrics = dict() + for i, server in enumerate(self.server.servers): + server_wandb_metrics = await server.wandb_metrics({}, f"server_{i}") + if len(self.completion_lengths) > 0: + wandb_metrics["train/completion_lengths"] = sum( + self.completion_lengths + ) / len(self.completion_lengths) + wandb_metrics["train/completion_lengths_std"] = np.std( + self.completion_lengths + ) + wandb_metrics["train/completion_lengths_max"] = np.max( + self.completion_lengths + ) + wandb_metrics["train/completion_lengths_min"] = np.min( + self.completion_lengths + ) + wandb_metrics["train/completion_lengths_p95"] = ( + np.array(self.completion_lengths) > (0.95 * self.max_token_len) + ).mean() + wandb_metrics = await self.create_rollout_table(wandb_metrics) + wandb_metrics = self.perf_stats(wandb_metrics) + self.rollouts_for_wandb = [] + self.completion_lengths = [] + if self.config.use_wandb: + if self.wandb_prepend is not None: + wandb_metrics = { + f"{self.wandb_prepend}_{k}": v for k, v in wandb_metrics.items() + } + # add server metrics to wandb without prepend to collate them all + wandb_metrics.update(server_wandb_metrics) + wandb.log(wandb_metrics, step=self.curr_step) + + @retry( + stop=stop_after_attempt(3), + wait=wait_random_exponential(multiplier=1, max=10), + ) + async def _send_scored_data_to_api(self, scored_data): + """ + Send scored data to the API with retry logic for timeouts and server errors. + """ + url = ( + f"{self.config.rollout_server_url}/scored_data_list" + if isinstance(scored_data, list) + else f"{self.config.rollout_server_url}/scored_data" + ) + async with aiohttp.ClientSession() as session: + async with session.post( + url, + json=scored_data, + ) as resp: + if resp.status >= 500: + # Server errors (5xx) should trigger a retry + logging.debug(f"Server error: {resp.status}, retrying...") + raise Exception(f"Server error: {resp.status}") + elif resp.status >= 400: + # Client errors (4xx) are logged but not retried + logging.error(f"Client error: {resp.status}, not retrying") + return + # Success case: print response text + print(await resp.text()) + + async def handle_send_to_api( + self, + scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]], + item: Item = None, + ): + """ + Send the chats to the API with robust error handling and support for multiple ScoredDataGroups. + + Args: + scored_data: List of scored items to send + item: Optional item for context + """ + group_size = scored_data.get("group_overrides", {}).get( + "group_size", self.config.group_size + ) + if ( + (scored_data is not None) + and (None not in scored_data) + and (len(scored_data["tokens"]) == group_size) + ): + if self.config.ensure_scores_are_not_same: + if len(set(scored_data["scores"])) == 1: + # Scores are the same, don't send to API + return + await self.add_rollouts_for_wandb(scored_data, item) + # Check for ref_logprobs + if "ref_logprobs" not in scored_data: + # Strongly typed dict, so we need to add it + scored_data["ref_logprobs"] = None + if "overrides" not in scored_data: + scored_data["overrides"] = None + if "group_overrides" not in scored_data: + scored_data["group_overrides"] = None + + # Track completion lengths + for mask in scored_data["masks"]: + self.completion_lengths.append(len(mask)) + # Add the scores to the queue + if any([len(x) >= self.max_token_len for x in scored_data["tokens"]]): + # Don't send to API if the token length is too long + return + # Save data, if applicable: + if self.jsonl_writer is not None: + self.jsonl_writer.write(scored_data) + # Send data with retries and error handling + try: + self.items_sent_this_step += 1 + await self._send_scored_data_to_api(scored_data) + except (Exception, TimeoutError) as e: + print(f"Failed to send scored data after retries: {e}") + + async def handle_env( + self, item_uuid: str + ) -> Optional[Union[ScoredDataGroup, List[ScoredDataGroup]]]: + """ + Handle the rollout of an item + """ + item = self.running_items.get(item_uuid) + if item is None: + print(f"item {item_uuid} not found... returning") + return None + start_time = time.time() + logger.debug(f"handle_env: Starting with item: {item}") + # do a rollout with item + try: + to_postprocess, to_backlog = await self.collect_trajectories(item) + except Exception: + to_postprocess = None + to_backlog = [] + # add the items to the queue + if len(to_backlog) > 0: + self.backlog.extend(to_backlog) + try: + if (to_postprocess is None) or (len(to_postprocess) == 0): + pass + else: + to_postprocess = await self.postprocess_histories(to_postprocess) + except Exception as e: + logger.error(f"Error in scoring: {item}") + print(e) + to_postprocess = None + self.running_items.pop(item_uuid, None) + duration = max(0.0, time.time() - start_time) + self.task_duration.append(duration) + if to_postprocess is not None: + self.task_successful.append(1) + self.succeeded_task_duration.append(duration) + logger.debug(f"handle_env: Collected {len(to_postprocess)} trajectories") + await self.handle_send_to_api(to_postprocess, item) + else: + self.task_successful.append(0) + self.failed_task_duration.append(duration) + logger.debug("handle_env: No trajectories collected") + # Finally pop it + await self.cleanup() + return to_postprocess + + async def cleanup(self): + """ + Optional: Cleanup the environment + """ + pass + + @retry( + stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) + ) + async def get_status(self): + async with aiohttp.ClientSession() as session: + async with session.get( + f"{self.config.rollout_server_url}/status-env", + json={"env_id": self.env_id}, + ) as resp: + self.status_dict = await resp.json() + new_weight = self.status_dict["env_weight"] + max_num_workers = self.config.max_num_workers + if max_num_workers == -1: + max_num_workers = self.config.max_num_workers_per_node * len( + self.server.servers + ) + self.max_num_workers = max_num_workers + await self.server.update_weight(new_weight) + + async def env_step_checks(self): + # Check if we need to run an eval or log... + if self.curr_step != self.status_dict["current_step"]: + if self.config.steps_per_eval > 0: + if (self.curr_step % self.config.steps_per_eval) > ( + self.status_dict["current_step"] % self.config.steps_per_eval + ): + if (self.eval_runner is None) or (self.eval_runner.done()): + eval_task = asyncio.create_task(self.evaluate()) + self.eval_runner = eval_task + if self.config.eval_handling == EvalHandlingEnum.STOP_TRAIN: + # Stop training if eval is running + self.backlog.extend(self.running_items.values()) + for worker in self.workers: + worker.cancel() + self.workers = set() + self.running_items: dict[UUID, Item] = dict() + else: + warnings.warn( + "Eval is not finished in this iteration of the loop, skipping this eval step..." + ) + if self.checkpoint_interval > 0: + if (self.curr_step % self.checkpoint_interval) > ( + self.status_dict["current_step"] % self.checkpoint_interval + ): + checkpoint_step = ( + self.status_dict["current_step"] // self.checkpoint_interval + ) * self.checkpoint_interval + self.save_checkpoint(checkpoint_step) + self.curr_step = self.status_dict["current_step"] + if self.items_sent_this_step >= self.config.min_items_sent_before_logging: + self.items_sent_this_step = 0 + await self.wandb_log({}) + + async def add_train_workers(self): + if (self.eval_runner is not None) and (not self.eval_runner.done()): + if self.config.eval_handling == EvalHandlingEnum.STOP_TRAIN: + return + elif self.config.eval_handling == EvalHandlingEnum.LIMIT_TRAIN: + max_num_workers = int( + self.max_num_workers * self.config.eval_limit_ratio + ) + else: + max_num_workers = self.max_num_workers + else: + max_num_workers = self.max_num_workers + # set max_num_workers to whatever is max off policy and num workers + max_num_workers = min( + max_num_workers, + ( + self.config.max_batches_offpolicy + * self.config.batch_size + // self.config.group_size + ) + - (self.status_dict["queue_size"]), + ) + if (self.curr_step == 0) and (len(self.workers) == 0): + # We are starting up, so we should just skip the append to the list + pass + else: + self.workers_added_list.append(max_num_workers - len(self.workers)) + while len(self.workers) < max_num_workers: + # Generate a UUID for tracking this item + item_uuid = str(uuid.uuid4()) + if len(self.backlog) > 0: + item = self.backlog.pop() + else: + item = await self.get_next_item() + if item is None: + break + self.running_items[item_uuid] = item + worker = asyncio.create_task(self.handle_env(item_uuid)) + self.workers.add(worker) + worker.add_done_callback( + lambda fut, i=item: ( + ( + self.workers.discard(fut), + ( + setattr(self, "last_completed_item", i) + if fut.result() + else None + ), + )[1] + if fut.done() and not fut.cancelled() + else None + ) + ) + + async def env_manager(self): + """ + Rollout manager + """ + await self.setup() + await self.setup_wandb() + await self.register_env() + await self.get_server_info() + # Wait for other instances to get setup :) + await asyncio.sleep(5) + while True: + if self.last_loop_time is not None: + self.mainloop_timings.append( + max(0.0, time.time() - self.last_loop_time) + ) + # get status from server + self.last_loop_time = time.time() + await self.get_status() + await self.env_step_checks() + logger.info(f"env_manager: Status dict: {self.status_dict}") + if ( + self.status_dict["current_step"] + + ( + self.status_dict["queue_size"] + * self.config.group_size + // self.config.batch_size + ) + ) > self.config.total_steps: + for worker in self.workers: + worker.cancel() + break + if ( + ( + self.status_dict["queue_size"] * self.config.group_size + >= self.config.max_batches_offpolicy * self.config.batch_size + ) + and (self.config.max_batches_offpolicy > 0) + ) or (self.config.batch_size == -1): + # We have too many, lets cleanup the tasks and wait a bit + self.backlog.extend(self.running_items.values()) + for worker in self.workers: + worker.cancel() + self.running_items = dict() + self.workers = set() + elif len(self.workers) >= self.max_num_workers: + pass + else: + await self.add_train_workers() + await asyncio.sleep(0.1) + + @classmethod + def cli(cls): + """ + Command-line interface entry point for the environment. + This method handles the CLI commands for serve and process. + """ + + # Create subcommands dictionary + subcommands = { + "serve": cls.get_cli_serve_config_cls(), + "process": cls.get_cli_process_config_cls(), + } + + # Custom exception handler for cleaner error output + def custom_error_handler(ex: Exception) -> int: + """Handles exceptions with clean output for known error types.""" + if isinstance(ex, FailedExecutionException): + # Handle argparse errors (already printed by argparse) + print() + print(ex.message.split("error: ")[-1]) + return 2 + else: + # For any other exception + print(f"Error: {str(ex)}", file=sys.stderr) + return 1 + + run_and_exit( + subcommands, + description=f"CLI for {cls.__name__}", + exception_handler=custom_error_handler, + ) + + @classmethod + def get_cli_serve_config_cls(cls) -> type: + """ + Returns the CLI configuration class for serving commands. + + Returns: + type: The CliServeConfig class for serving commands. + """ + + env_config, server_configs = cls.config_init() + + class CliServeConfig( + cls.env_config_cls, OpenaiConfig, ServerManagerConfig, Cmd + ): + """ + Configuration for the serve command. + This combines BaseEnvConfig and OpenaiConfig into a single command. + """ + + def run(self) -> None: + """The logic to execute for the 'serve' command.""" + # Convert this config into the formats needed by BaseEnv + if self.wandb_name is None and cls.name is not None: + self.wandb_name = cls.name + model_dumped = self.model_dump(exclude_unset=True) + server_manager_config = ServerManagerConfig(**model_dumped) + # Create the environment instance + env = cls( + config=env_config, + server_configs=server_configs, + slurm=server_manager_config.slurm, + testing=server_manager_config.testing, + ) + + # Run the environment + asyncio.run(env.env_manager()) + + return CliServeConfig + + @classmethod + def get_cli_process_config_cls(cls) -> type: + """ + Returns the CLI configuration class for processing commands. + + Returns: + type: The CliProcessConfig class for processing commands. + """ + + class CliProcessConfig(Cmd): + """ + Configuration for the process command. + This is a placeholder for future implementation. + """ + + # Add process-specific fields here + group_size: int = Field( + default=4, description="Number of responses per prompt" + ) + n_groups: int = Field(default=1, description="Number of groups to process") + output_file: str = Field( + ..., description="Path to jsonl file to write results" + ) + + def run(self) -> None: + """The logic to execute for the 'process' command.""" + print( + f"Processing {self.n_groups} groups of " + f"{self.group_size} responses and " + f"writing to {self.output_file}" + ) + print("This is a placeholder implementation for the process command.") + # Actual implementation would go here + + return CliProcessConfig diff --git a/atroposlib/envs/reward_fns/__init__.py b/atroposlib/envs/reward_fns/__init__.py new file mode 100644 index 00000000..411d6149 --- /dev/null +++ b/atroposlib/envs/reward_fns/__init__.py @@ -0,0 +1,30 @@ +""" +Reward functions for evaluating model outputs in various environments. + +This module provides a framework for creating, composing, and applying reward functions +to evaluate model outputs. Reward functions can be used for both dataset environments +and online/gymnasium environments. + +Key components: +- RewardFunction: Abstract base class for all reward functions +- RewardRegistry: Registry for registering and loading reward functions +- CombinedReward: Meta reward function that combines multiple reward functions + +Usage: + # Define a reward function + @registry.register + class MyReward(RewardFunction): + def compute(self, completions, **kwargs): + # Implementation + return [score for completion in completions] + + # Create and use a reward function + reward_fn = registry.create("my_reward", weight=1.5) + scores = reward_fn(completions, **kwargs) +""" + +from .combined_reward import CombinedReward +from .registry import registry +from .reward_function import RewardFunction + +__all__ = ["RewardFunction", "registry", "CombinedReward"] diff --git a/atroposlib/envs/reward_fns/accuracy_reward.py b/atroposlib/envs/reward_fns/accuracy_reward.py new file mode 100644 index 00000000..4c53f168 --- /dev/null +++ b/atroposlib/envs/reward_fns/accuracy_reward.py @@ -0,0 +1,296 @@ +"""Reward function for checking if completions match ground truth answers.""" + +import logging +import re +from typing import Any, List, Optional, Union + +from latex2sympy2_extended import NormalizationConfig +from math_verify import LatexExtractionConfig, parse, verify + +from .registry import registry +from .reward_function import RewardFunction + +logger = logging.getLogger(__name__) + + +def _normalize_numerical_value(value_str: str) -> float: + """Convert a string representation of a number to float, handling formatting.""" + return float(value_str.replace(",", "").strip()) + + +def _extract_final_answer(text: str) -> str: + """ + Extract the final answer from text that might include a full solution. + + Handles formats like: + - "#### 42" (GSM8K style) + - "The answer is 42" + - "\\boxed{42}" + + Returns the extracted answer or the original text if no pattern is found. + """ + # Check for GSM8K style answers (#### 42) + if "####" in text: + match = re.search(r"####\s*(.*?)(?:\s*$|\n)", text) + if match: + return match.group(1).strip() + + # Check for boxed answers + if "\\boxed{" in text: + match = re.search(r"\\boxed\{([^}]+)\}", text) + if match: + return match.group(1).strip() + + # If no special format is found, return the original text + return text + + +def _verify_answer( + content: str, gold_answer: Union[float, int, str], tolerance: float = 1e-6 +) -> bool: + """ + Verifies if the provided content contains an answer matching the gold answer. + Uses a robust approach with multiple fallback strategies. + + Args: + content: The model's response content to evaluate + gold_answer: The correct answer to compare against + tolerance: Tolerance for floating point comparisons + + Returns: + Boolean indicating whether the answer is correct + """ + # Extract the final answer from the gold answer if it has a special format + if isinstance(gold_answer, str): + # Check for GSM8K style answers (#### number) + if "####" in gold_answer: + gold_answer = _extract_final_answer(gold_answer) + logger.warning(f"Extracted gold answer: {gold_answer}") + + # Convert gold_answer to numerical if it's not already and if possible + gold_value = None + if isinstance(gold_answer, (int, float)): + gold_value = gold_answer + elif isinstance(gold_answer, str): + # Try to extract numerical value if it's in boxed format + if "\\boxed{" in gold_answer: + try: + gold_value = _normalize_numerical_value( + gold_answer.replace("\\boxed{", "").replace("}", "") + ) + except ValueError: + # Not a numerical value, keep as string for LaTeX parsing + pass + else: + # Try to convert to float if possible + try: + gold_value = _normalize_numerical_value(gold_answer) + except ValueError: + # Not a numerical value, keep as string for LaTeX parsing + pass + + # First attempt: Try to parse with math_verify + try: + answer_parsed = parse( + content, + extraction_config=[ + LatexExtractionConfig( + normalization_config=NormalizationConfig( + nits=False, + malformed_operators=False, + basic_latex=True, + equations=True, + boxed="all", + units=True, + ), + # Ensures that boxed is tried first + boxed_match_priority=0, + try_extract_without_anchor=False, + ) + ], + extraction_mode="first_match", + ) + + logger.warning(f"Answer parsed result: {answer_parsed}") + + # If we got a valid parse, verify it against the gold answer + if answer_parsed: + # Format gold answer for verification + gold_str = ( + f"\\boxed{{{gold_answer}}}" + if not isinstance(gold_answer, str) or "\\boxed" not in gold_answer + else gold_answer + ) + + gold_parsed = parse( + gold_str, + extraction_mode="first_match", + extraction_config=[LatexExtractionConfig()], + ) + logger.warning(f"Gold parsed result: {gold_parsed}") + + if gold_parsed: + return verify(answer_parsed, gold_parsed) + except Exception as e: + logger.warning(f"Exception in primary parsing: {e}") + + # Fallback: Use regex to extract boxed content for numerical comparison + if gold_value is not None: # Only try numerical comparison if gold is a number + try: + # Try to extract a boxed answer first + boxed_matches = re.findall(r"\\boxed\{([^}]+)\}", content) + if boxed_matches: + logger.warning(f"Regex boxed matches: {boxed_matches}") + # Try to extract a numerical value from the boxed content + try: + extracted_value = _normalize_numerical_value(boxed_matches[0]) + logger.warning( + f"Extracted value: {extracted_value}, Gold value: {gold_value}" + ) + # Allow for small floating point differences + return abs(extracted_value - gold_value) < tolerance + except ValueError: + logger.warning(f"Could not convert '{boxed_matches[0]}' to float") + + # If no boxed answer, check for a final answer after #### + if "####" in content: + match = re.search(r"####\s*([\d\.]+)", content) + if match: + extracted_value = _normalize_numerical_value(match.group(1)) + logger.warning( + f"Extracted value from ####: {extracted_value}, Gold value: {gold_value}" + ) + return abs(extracted_value - gold_value) < tolerance + except Exception as e: + logger.warning(f"Exception in regex parsing: {e}") + + return False + + +@registry.register +class AccuracyReward(RewardFunction): + """ + Reward function that checks if completions match ground truth answers. + + Works with boxed LaTeX answers, GSM8K-style answers, and other formats. + Uses a robust approach with multiple fallback strategies for parsing and verification. + """ + + def __init__( + self, + tolerance: float = 1e-6, + split_on_think_tag: bool = True, + max_boxed_threshold: int = 6, + weight: float = 1.0, + **kwargs, + ): + """ + Initialize the accuracy reward function. + + Args: + tolerance: Tolerance for floating point comparisons + split_on_think_tag: Whether to use only the text after tag + max_boxed_threshold: Maximum number of boxed expressions before marking as incorrect + weight: Weight for this reward + **kwargs: Additional configuration + """ + super().__init__(weight=weight, **kwargs) + self.tolerance = tolerance + self.split_on_think_tag = split_on_think_tag + self.max_boxed_threshold = max_boxed_threshold + + def compute( + self, + completions: List[Any], + solution: Optional[Union[str, List[str]]] = None, + ground_truth: Optional[Union[str, List[str]]] = None, + **kwargs, + ) -> List[float]: + """ + Check if completions match ground truth answers. + + Args: + completions: List of model completions to evaluate + solution: Ground truth solution(s) - can be a single value or list of values + ground_truth: Optional canonical ground truth answers (used instead of solution if provided) + **kwargs: Additional context + + Returns: + List of reward values (1.0 for correct, 0.0 for incorrect) + """ + rewards = [] + + # Check if we have a solution or ground truth + if solution is None and ground_truth is None: + logger.warning("No solution or ground_truth provided to accuracy_reward") + return [0.0] * len(completions) + + # Use ground_truth instead of solution if available + gold_answers = ground_truth if ground_truth is not None else solution + + if isinstance(gold_answers, list): + answers = gold_answers + else: + answers = [gold_answers] * len(completions) + + for completion, ans in zip(completions, answers): + try: + content = self.get_content(completion) + + if ( + self.split_on_think_tag + and "" in content + and content.split("")[-1].count("\\boxed") + > self.max_boxed_threshold + ): + logger.warning( + "Too many \\boxed commands in response, marking as incorrect" + ) + reward = 0.0 + else: + if self.split_on_think_tag and "" in content: + answer_part = content.split("")[-1] + else: + answer_part = content + + reward = float(_verify_answer(answer_part, ans, self.tolerance)) + + except Exception as e: + logger.warning(f"Error in accuracy_reward: {e}") + logger.exception(e) + reward = 0.0 + + rewards.append(reward) + + # Calculate statistics + if rewards: + logger.info( + f"Accuracy: {sum(rewards)}/{len(rewards)} ({sum(rewards)/len(rewards):.2f})" + ) + + return rewards + + +# Legacy function for backward compatibility +def accuracy_reward( + completions: List[Any], + solution: Union[str, List[str]] = None, + ground_truth: Union[str, List[str]] = None, + **kwargs, +) -> List[float]: + """ + Legacy function wrapper for AccuracyReward. + + Args: + completions: List of model completions to evaluate + solution: Ground truth solution(s) - can be a single value or list of values + ground_truth: Optional canonical ground truth answers (used instead of solution if provided) + **kwargs: Additional parameters + + Returns: + List of reward values (1.0 for correct, 0.0 for incorrect) + """ + reward_fn = AccuracyReward() + return reward_fn.compute( + completions, solution=solution, ground_truth=ground_truth, **kwargs + ) diff --git a/atroposlib/envs/reward_fns/cascading_r1_math_reward.py b/atroposlib/envs/reward_fns/cascading_r1_math_reward.py new file mode 100644 index 00000000..584b7785 --- /dev/null +++ b/atroposlib/envs/reward_fns/cascading_r1_math_reward.py @@ -0,0 +1,200 @@ +import logging +import re +from typing import Any, List, Union + +from latex2sympy2_extended import NormalizationConfig +from math_verify import LatexExtractionConfig, parse, verify + +logger = logging.getLogger(__name__) + + +def get_completion_content(completion) -> str: + """Extract content from completion in various formats.""" + if isinstance(completion, str): + return completion + elif isinstance(completion, dict): + if "content" in completion: + return completion["content"] + elif isinstance(completion.get("message", {}), dict): + return completion["message"].get("content", "") + elif isinstance(completion, list) and len(completion) > 0: + if isinstance(completion[0], dict) and "content" in completion[0]: + return completion[0]["content"] + + logger.warning(f"Could not extract content from completion: {completion}") + return str(completion) + + +def _normalize_numerical_value(value_str: str) -> float: + """Convert a string representation of a number to float, handling formatting.""" + return float(value_str.replace(",", "").strip()) + + +def _extract_final_answer(text: str) -> str: + """Extract the final answer from text with various formats (GSM8K, boxed, etc).""" + if "####" in text: + match = re.search(r"####\s*(.*?)(?:\s*$|\n)", text) + if match: + return match.group(1).strip() + + if "\\boxed{" in text: + match = re.search(r"\\boxed\{([^}]+)\}", text) + if match: + return match.group(1).strip() + + return text + + +def _verify_answer(content: str, gold_answer: Union[float, int, str]) -> bool: + """Verifies if the content matches the gold answer using multiple strategies.""" + if isinstance(gold_answer, str): + if "####" in gold_answer: + gold_answer = _extract_final_answer(gold_answer) + logger.warning(f"Extracted gold answer: {gold_answer}") + + gold_value = None + if isinstance(gold_answer, (int, float)): + gold_value = gold_answer + elif isinstance(gold_answer, str): + if "\\boxed{" in gold_answer: + try: + gold_value = _normalize_numerical_value( + gold_answer.replace("\\boxed{", "").replace("}", "") + ) + except ValueError: + pass + else: + try: + gold_value = _normalize_numerical_value(gold_answer) + except ValueError: + pass + + # Try math_verify parsing first + try: + answer_parsed = parse( + content, + extraction_config=[ + LatexExtractionConfig( + normalization_config=NormalizationConfig( + nits=False, + malformed_operators=False, + basic_latex=True, + equations=True, + boxed="all", + units=True, + ), + boxed_match_priority=0, + try_extract_without_anchor=False, + ) + ], + extraction_mode="first_match", + ) + + if answer_parsed: + gold_str = ( + f"\\boxed{{{gold_answer}}}" + if not isinstance(gold_answer, str) or "\\boxed" not in gold_answer + else gold_answer + ) + gold_parsed = parse( + gold_str, + extraction_mode="first_match", + extraction_config=[LatexExtractionConfig()], + ) + if gold_parsed: + return verify(answer_parsed, gold_parsed) + except Exception as e: + logger.warning(f"Exception in primary parsing: {e}") + + # Fallback to numerical comparison + if gold_value is not None: + try: + boxed_matches = re.findall(r"\\boxed\{([^}]+)\}", content) + if boxed_matches: + try: + extracted_value = _normalize_numerical_value(boxed_matches[0]) + return abs(extracted_value - gold_value) < 1e-6 + except ValueError: + pass + + if "####" in content: + match = re.search(r"####\s*([\d\.]+)", content) + if match: + extracted_value = _normalize_numerical_value(match.group(1)) + return abs(extracted_value - gold_value) < 1e-6 + except Exception as e: + logger.warning(f"Exception in regex parsing: {e}") + + return False + + +def format_reward(completions, reward_value=0.5, **kwargs): + """Checks if completion has proper think tag formatting.""" + pattern = r"^[^<]*[^<]*$" + + try: + completion_contents = [get_completion_content(c) for c in completions] + matches = [ + re.match(pattern, content, re.DOTALL) for content in completion_contents + ] + return [reward_value if match else 0.0 for match in matches] + except Exception as e: + logger.error(f"Error in format reward calculation: {e}") + return [0.0] * len(completions) + + +def accuracy_reward( + completions: List[Any], solution: Union[str, List[str]], **kwargs +) -> List[float]: + """Checks answer accuracy using sophisticated verification.""" + rewards = [] + + if not isinstance(solution, list): + solution = [solution] * len(completions) + + for completion, sol in zip(completions, solution): + try: + content = get_completion_content(completion) + + if ( + "" in content + and content.split("")[-1].count("\\boxed") > 6 + ): + logger.warning( + "Too many \\boxed commands in response, marking as incorrect" + ) + reward = 0.0 + else: + answer_part = ( + content.split("")[-1] if "" in content else content + ) + reward = float(_verify_answer(answer_part, sol)) + except Exception as e: + logger.warning(f"Error in accuracy reward: {e}") + reward = 0.0 + + rewards.append(reward) + + return rewards + + +def cascading_r1_math_reward(completions, solution, **kwargs) -> list[float]: + """Combines sophisticated accuracy checking with format verification.""" + try: + accuracy_rewards = accuracy_reward(completions, solution, **kwargs) + format_rewards = format_reward(completions) + + combined_rewards = [] + for accuracy_score, format_score in zip(accuracy_rewards, format_rewards): + # Only add format bonus if answer is correct + format_bonus = format_score if accuracy_score > 0 else 0.0 + total_reward = accuracy_score + format_bonus + combined_rewards.append(total_reward) + + logger.info( + f"Teknium rewards: accuracy={accuracy_rewards}, format={format_rewards}, combined={combined_rewards}" + ) + return combined_rewards + except Exception as e: + logger.error(f"Error in teknium_reward: {e}") + return [0.0] * len(completions) diff --git a/atroposlib/envs/reward_fns/combined_reward.py b/atroposlib/envs/reward_fns/combined_reward.py new file mode 100644 index 00000000..c5aa8d61 --- /dev/null +++ b/atroposlib/envs/reward_fns/combined_reward.py @@ -0,0 +1,91 @@ +"""Combined reward function that combines multiple reward functions.""" + +import logging +from typing import Any, Dict, List, Union + +from .registry import registry +from .reward_function import RewardFunction + +logger = logging.getLogger(__name__) + + +@registry.register +class CombinedReward(RewardFunction): + """Meta reward function that combines multiple reward functions""" + + def __init__( + self, + rewards: List[Union[str, Dict]], + normalization: str = "none", + weight: float = 1.0, + **kwargs, + ): + """ + Initialize with a list of reward functions to combine. + + Args: + rewards: List of reward functions (names or config dicts) + normalization: How to normalize rewards, one of: + - "none": No normalization + - "sum": Divide by sum of weights + - "minmax": Scale to range [0,1] based on min/max values + weight: Weight for this combined reward + **kwargs: Additional parameters + """ + super().__init__(weight=weight, **kwargs) + self.normalization = normalization + self.reward_functions = [] + + # Initialize all sub-reward functions + for reward_config in rewards: + self.reward_functions.append(registry.create(reward_config)) + + @property + def name(self) -> str: + """Get a descriptive name for this combined reward""" + return f"combined({','.join(r.name for r in self.reward_functions)})" + + def set_wandb_logger(self, logger): + """Propagate the WandB logger to all sub-rewards""" + super().set_wandb_logger(logger) + for reward_fn in self.reward_functions: + reward_fn.set_wandb_logger(logger) + + def compute(self, completions: List[Any], **kwargs) -> List[float]: + """Compute combined rewards by calling all sub-rewards""" + if not completions: + return [] + + # Initialize with zeros + combined_rewards = [0.0] * len(completions) + + # Collect all sub-reward values + all_rewards = [] + for reward_fn in self.reward_functions: + try: + rewards = reward_fn.compute(completions, **kwargs) + all_rewards.append(rewards) + + # Add to combined total (pre-normalization) + for i, r in enumerate(rewards): + combined_rewards[i] += r + except Exception as e: + logger.error(f"Error computing reward for {reward_fn.name}: {e}") + logger.exception(e) + + # Apply normalization if needed + if self.normalization == "sum": + total_weight = sum(r.weight for r in self.reward_functions) + if total_weight > 0: + combined_rewards = [r / total_weight for r in combined_rewards] + elif self.normalization == "minmax": + # Avoid division by zero + reward_min = min(combined_rewards) if combined_rewards else 0 + reward_max = max(combined_rewards) if combined_rewards else 0 + if reward_max > reward_min: + combined_rewards = [ + (r - reward_min) / (reward_max - reward_min) + for r in combined_rewards + ] + + return combined_rewards diff --git a/atroposlib/envs/reward_fns/cosine_scaled_reward.py b/atroposlib/envs/reward_fns/cosine_scaled_reward.py new file mode 100644 index 00000000..0b620abe --- /dev/null +++ b/atroposlib/envs/reward_fns/cosine_scaled_reward.py @@ -0,0 +1,201 @@ +"""Reward function for evaluating semantic similarity between completions and solutions.""" + +import logging +from typing import Any, List, Optional, Union + +import scipy +import torch +from transformers import AutoModel, AutoTokenizer + +from .registry import registry +from .reward_function import RewardFunction + +logger = logging.getLogger(__name__) + + +@registry.register +class CosineScaledReward(RewardFunction): + """ + Reward function that measures semantic similarity between completions and solutions. + + Uses sentence embeddings to compute cosine similarity, providing higher rewards + for completions that are semantically similar to the reference solution. + """ + + # Class-level variables for model caching + _model = None + _tokenizer = None + _model_name = "sentence-transformers/all-MiniLM-L6-v2" + + def __init__( + self, + model_name: Optional[str] = None, + scale_factor: float = 1.0, + min_reward: float = -1.0, + max_reward: float = 1.0, + default_reward: float = 0.0, + weight: float = 1.0, + **kwargs, + ): + """ + Initialize the cosine similarity reward function. + + Args: + model_name: Name of embedding model to use (default: "sentence-transformers/all-MiniLM-L6-v2") + scale_factor: Factor to scale similarity by (default: 1.0) + min_reward: Minimum reward value (default: -1.0) + max_reward: Maximum reward value (default: 1.0) + default_reward: Default reward when similarity can't be calculated (default: 0.0) + weight: Weight for this reward + **kwargs: Additional configuration + """ + super().__init__(weight=weight, **kwargs) + self.model_name = model_name or self._model_name + self.scale_factor = scale_factor + self.min_reward = min_reward + self.max_reward = max_reward + self.default_reward = default_reward + + # Initialize model and tokenizer if needed + self._ensure_model_loaded() + + def _ensure_model_loaded(self): + """Ensure the model and tokenizer are loaded, loading them if needed.""" + # Check if we need to load a different model than what's cached + if self.model_name != self._model_name or CosineScaledReward._model is None: + try: + CosineScaledReward._tokenizer = AutoTokenizer.from_pretrained( + self.model_name + ) + CosineScaledReward._model = AutoModel.from_pretrained(self.model_name) + CosineScaledReward._model_name = self.model_name + logger.info( + f"Loaded model and tokenizer for cosine similarity: {self.model_name}" + ) + except Exception as e: + logger.error(f"Error loading model for cosine similarity: {e}") + logger.exception(e) + CosineScaledReward._tokenizer = None + CosineScaledReward._model = None + + def _mean_pooling(self, model_output, attention_mask): + """Mean Pooling - Take attention mask into account for correct averaging""" + token_embeddings = model_output[0] + input_mask_expanded = ( + attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + ) + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( + input_mask_expanded.sum(1), min=1e-9 + ) + + def _get_embeddings(self, text): + """Get embeddings for text using the model""" + if CosineScaledReward._model is None or CosineScaledReward._tokenizer is None: + logger.error("Model or tokenizer not available for embeddings") + return None + + try: + # Tokenize and prepare for the model + encoded_input = CosineScaledReward._tokenizer( + text, padding=True, truncation=True, return_tensors="pt" + ) + + # Get model output + with torch.no_grad(): + model_output = CosineScaledReward._model(**encoded_input) + + # Perform mean pooling + sentence_embeddings = self._mean_pooling( + model_output, encoded_input["attention_mask"] + ) + + return sentence_embeddings.numpy() + except Exception as e: + logger.error(f"Error getting embeddings: {e}") + logger.exception(e) + return None + + def compute( + self, + completions: List[Any], + solution: Optional[Union[str, List[str]]] = None, + **kwargs, + ) -> List[float]: + """ + Calculate reward based on cosine similarity between completion and solution. + + Args: + completions: List of completions to evaluate + solution: The reference solution to compare against + **kwargs: Additional context + + Returns: + List of rewards based on cosine similarity, scaled to [min_reward, max_reward] + """ + # Extract content from different possible formats + completion_contents = [ + self.get_content(completion) for completion in completions + ] + + # If no solution provided, can't calculate similarity + if not solution: + logger.warning("No solution provided for cosine similarity calculation") + return [self.default_reward] * len(completion_contents) + + solution_text = ( + solution if isinstance(solution, str) else self.get_content(solution) + ) + + rewards = [] + for content in completion_contents: + try: + # Get embeddings + solution_embedding = self._get_embeddings(solution_text) + completion_embedding = self._get_embeddings(content) + + if solution_embedding is None or completion_embedding is None: + logger.warning("Could not get embeddings for cosine similarity") + rewards.append(self.default_reward) + continue + + # Calculate cosine similarity + similarity = scipy.spatial.distance.cosine( + solution_embedding.flatten(), completion_embedding.flatten() + ) + + # Scale similarity to a reward between min_reward and max_reward + # Cosine distance ranges from 0 (similar) to 2 (dissimilar) + # We want to map 0 → max_reward (good) and 2 → min_reward (bad) + normalized_similarity = 1.0 - similarity * self.scale_factor + reward = min( + self.max_reward, max(self.min_reward, normalized_similarity) + ) + + logger.info(f"Cosine similarity: {similarity}, scaled reward: {reward}") + rewards.append(reward) + + except Exception as e: + logger.error(f"Error in cosine similarity calculation: {e}") + logger.exception(e) + rewards.append(self.default_reward) + + return rewards + + +# Legacy function for backward compatibility +def cosine_scaled_reward( + completions: List[Any], solution=None, **kwargs +) -> List[float]: + """ + Legacy function wrapper for CosineScaledReward. + + Args: + completions: List of completions to evaluate + solution: The reference solution to compare against + **kwargs: Additional parameters + + Returns: + List of rewards based on cosine similarity, scaled to [-1, 1] + """ + reward_fn = CosineScaledReward() + return reward_fn.compute(completions, solution=solution, **kwargs) diff --git a/atroposlib/envs/reward_fns/crossword_format_reward.py b/atroposlib/envs/reward_fns/crossword_format_reward.py new file mode 100644 index 00000000..03ad09eb --- /dev/null +++ b/atroposlib/envs/reward_fns/crossword_format_reward.py @@ -0,0 +1,121 @@ +"""Reward function for evaluating crossword puzzle answer formatting.""" + +import logging +import re +from typing import Any, List, Optional, Pattern + +from .registry import registry +from .reward_function import RewardFunction + +logger = logging.getLogger(__name__) + + +@registry.register +class CrosswordFormatReward(RewardFunction): + """ + Reward function for crossword puzzle game answers. + + Checks if completions follow the expected formatting for crossword puzzle answers: + - Contains answer patterns like "1-Across: WORD" + - Uses only valid characters (letters, no numbers or special chars in answers) + - Follows specified formatting patterns + """ + + def __init__( + self, + format_patterns: Optional[List[Pattern]] = None, + reward_value: float = 1.0, + penalize_invalid_chars: bool = True, + valid_chars: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZ", + weight: float = 1.0, + **kwargs, + ): + """ + Initialize the crossword format reward function. + + Args: + format_patterns: List of regex patterns to match (optional) + reward_value: Value to award for correct formatting + penalize_invalid_chars: Whether to penalize invalid characters + valid_chars: String of valid characters for answers + weight: Weight for this reward + **kwargs: Additional configuration + """ + super().__init__(weight=weight, **kwargs) + self.reward_value = reward_value + self.penalize_invalid_chars = penalize_invalid_chars + self.valid_chars = valid_chars.upper() + + # Default patterns if none provided + self.format_patterns = format_patterns or [ + re.compile( + r"\d+-(?:Across|Down):\s+[A-Z\s]+", re.IGNORECASE + ), # Basic format pattern + re.compile( + r"^(?:\d+-(?:Across|Down):\s+[A-Z\s]+[\s,]*)+$", re.IGNORECASE + ), # Full response format + ] + + def compute(self, completions: List[Any], **kwargs) -> List[float]: + """ + Check if completions follow crossword answer formatting. + + Args: + completions: List of completions to evaluate + **kwargs: Additional context + + Returns: + List of rewards (reward_value for correct format, 0.0 otherwise) + """ + # Extract content from different possible formats + completion_contents = [ + self.get_content(completion) for completion in completions + ] + + rewards = [] + for content in completion_contents: + try: + # Check for format patterns + format_match = any( + pattern.search(content) for pattern in self.format_patterns + ) + + # Look for answers and check for invalid characters + valid_chars = True + if self.penalize_invalid_chars: + # Extract answers (text after "Across:" or "Down:") + answers = re.findall( + r"(?:Across|Down):\s+([A-Za-z]+)", content, re.IGNORECASE + ) + for answer in answers: + # Check if answer contains only valid characters + if not all(c.upper() in self.valid_chars for c in answer): + valid_chars = False + break + + # Both format and valid chars must be correct for full reward + correct_format = format_match and valid_chars + rewards.append(self.reward_value if correct_format else 0.0) + + except Exception as e: + logger.error(f"Error in crossword format reward calculation: {e}") + logger.exception(e) + rewards.append(0.0) + + return rewards + + +# Legacy function for backward compatibility +def crossword_format_reward(completions: List[Any], **kwargs) -> List[float]: + """ + Legacy function wrapper for CrosswordFormatReward. + + Args: + completions: List of completions to evaluate + **kwargs: Additional parameters + + Returns: + List of rewards for crossword format quality + """ + reward_fn = CrosswordFormatReward() + return reward_fn.compute(completions, **kwargs) diff --git a/atroposlib/envs/reward_fns/format_reward.py b/atroposlib/envs/reward_fns/format_reward.py new file mode 100644 index 00000000..21a59a51 --- /dev/null +++ b/atroposlib/envs/reward_fns/format_reward.py @@ -0,0 +1,105 @@ +"""Reward function for checking if completions have specific XML-style tags.""" + +import logging +import re +from typing import Any, List, Optional + +from .registry import registry +from .reward_function import RewardFunction + +logger = logging.getLogger(__name__) + + +@registry.register +class FormatReward(RewardFunction): + """Reward function that checks if completions have XML-style tags.""" + + def __init__( + self, + preferred_tags: Optional[List[str]] = None, + require_all_tags: bool = False, + case_sensitive: bool = False, + weight: float = 1.0, + **kwargs, + ): + """ + Initialize the format reward function. + + Args: + preferred_tags: List of tag names to search for (defaults to ['think', 'answer']) + require_all_tags: If True, require all tags to be present for a reward + case_sensitive: If True, perform case-sensitive tag matching + weight: Weight for this reward + **kwargs: Additional configuration + """ + super().__init__(weight=weight, **kwargs) + self.preferred_tags = preferred_tags or ["think", "answer"] + self.require_all_tags = require_all_tags + self.case_sensitive = case_sensitive + + def compute(self, completions: List[Any], **kwargs) -> List[float]: + """ + Check if completions have the expected XML-style tags. + + Args: + completions: List of completions to evaluate + **kwargs: Additional context + + Returns: + List of rewards for each completion (1.0 for good format, 0.0 otherwise) + """ + # Extract content from different possible formats + completion_contents = [ + self.get_content(completion) for completion in completions + ] + + # For each completion, check for the preferred tags + rewards = [] + + flags = 0 if self.case_sensitive else re.IGNORECASE + flags |= re.DOTALL # Allow . to match newlines + + for content in completion_contents: + if self.require_all_tags: + # All tags must be present + all_tags_present = True + for tag in self.preferred_tags: + pattern = f"<{tag}>.*?" + if not re.search(pattern, content, flags): + all_tags_present = False + break + rewards.append(1.0 if all_tags_present else 0.0) + else: + # Any tag can be present + has_tags = False + for tag in self.preferred_tags: + pattern = f"<{tag}>.*?" + if re.search(pattern, content, flags): + has_tags = True + break + rewards.append(1.0 if has_tags else 0.0) + + # Log the results + logger.info( + f"Format reward results: {sum(rewards)}/{len(rewards)} completions match format" + ) + return rewards + + +# Legacy function for backward compatibility +def format_reward( + completions: List[Any], preferred_tags: Optional[List[str]] = None, **kwargs +) -> List[float]: + """ + Legacy function wrapper for FormatReward. + + Args: + completions: List of completions to evaluate + preferred_tags: List of tag names to search for (defaults to ['think', 'answer']) + **kwargs: Additional keyword arguments + + Returns: + List of rewards for each completion (1.0 for good format, 0.0 otherwise) + """ + reward_fn = FormatReward(preferred_tags=preferred_tags) + return reward_fn.compute(completions, **kwargs) diff --git a/atroposlib/envs/reward_fns/r1_reward.py b/atroposlib/envs/reward_fns/r1_reward.py new file mode 100644 index 00000000..8b9f858d --- /dev/null +++ b/atroposlib/envs/reward_fns/r1_reward.py @@ -0,0 +1,363 @@ +"""Reward function that combines reasoning format and accuracy rewards.""" + +import logging +import re +from typing import Any, Dict, List, Optional, Union + +from .registry import registry +from .reward_function import RewardFunction + +logger = logging.getLogger(__name__) + + +def parse_reasoning_response(text: str) -> Dict[str, Any]: + """ + Parse text to extract thinking section and response section. + + Args: + text: Text to parse for thinking and response sections + + Returns: + Dictionary with thinking_content, response, and multiple_thinking flag + """ + # Check if text is actually a string + if not isinstance(text, str): + logger.warning(f"Expected string but got {type(text)}: {text}") + return { + "thinking_content": "", + "response": str(text), + "multiple_thinking": False, + } + + # Find all thinking blocks + thinking_blocks = re.findall(r".*?", text, re.DOTALL) + + # If there's more than one thinking block, fail + if len(thinking_blocks) > 1: + return {"thinking_content": "", "response": text, "multiple_thinking": True} + + # Match the single thinking block if it exists + pattern = r"\s*(.*?)\s*\s*(.*)" + match = re.search(pattern, text, re.DOTALL) + if not match: + return {"thinking_content": "", "response": text, "multiple_thinking": False} + + return { + "thinking_content": match.group(1).strip(), + "response": match.group(2).strip(), + "multiple_thinking": False, + } + + +@registry.register +class FormatReasoningReward(RewardFunction): + """ + Reward function that checks for proper reasoning format. + + Checks if completions have: + 1. A thinking section in tags + 2. A response section after the thinking tags + 3. No multiple thinking sections (only one block) + """ + + def __init__( + self, + reward_value: float = 0.5, + require_thinking: bool = True, + require_response: bool = True, + allow_multiple_thinking: bool = False, + weight: float = 1.0, + **kwargs, + ): + """ + Initialize the format reasoning reward function. + + Args: + reward_value: Value to award for correct formatting + require_thinking: Whether to require thinking content + require_response: Whether to require response content + allow_multiple_thinking: Whether to allow multiple thinking sections + weight: Weight for this reward + **kwargs: Additional configuration + """ + super().__init__(weight=weight, **kwargs) + self.reward_value = reward_value + self.require_thinking = require_thinking + self.require_response = require_response + self.allow_multiple_thinking = allow_multiple_thinking + + def compute(self, completions: List[Any], **kwargs) -> List[float]: + """ + Check if completions have proper reasoning format. + + Args: + completions: List of completions to evaluate + **kwargs: Additional context + + Returns: + List of rewards (reward_value for correct format, 0.0 otherwise) + """ + parsed = [] + for completion in completions: + try: + content = self.get_content(completion) + parsed.append(parse_reasoning_response(content)) + except Exception as e: + logger.error(f"Error parsing response: {e}") + logger.exception(e) + parsed.append( + {"thinking_content": "", "response": "", "multiple_thinking": False} + ) + + rewards = [] + for p in parsed: + try: + # Check if response meets format requirements + valid_format = True + + if self.require_thinking and not p["thinking_content"]: + valid_format = False + + if self.require_response and not p["response"]: + valid_format = False + + if not self.allow_multiple_thinking and p["multiple_thinking"]: + valid_format = False + + rewards.append(self.reward_value if valid_format else 0.0) + except Exception as e: + logger.error(f"Error in format reward calculation: {e}") + logger.exception(e) + rewards.append(0.0) + + return rewards + + +@registry.register +class AccuracyXReward(RewardFunction): + """ + Reward function that checks if completion responses contain the solution. + + First parses completions to extract the response part after thinking tags, + then checks if the solution string is contained in the response. + """ + + def __init__( + self, + exact_match: bool = False, + case_sensitive: bool = False, + reward_value: float = 1.0, + weight: float = 1.0, + **kwargs, + ): + """ + Initialize the accuracy reward function. + + Args: + exact_match: Whether to require exact match vs. contained + case_sensitive: Whether to do case-sensitive matching + reward_value: Value to award for correct answer + weight: Weight for this reward + **kwargs: Additional configuration + """ + super().__init__(weight=weight, **kwargs) + self.exact_match = exact_match + self.case_sensitive = case_sensitive + self.reward_value = reward_value + + def compute( + self, + completions: List[Any], + solution: Optional[Union[str, List[str]]] = None, + **kwargs, + ) -> List[float]: + """ + Check if completion responses contain the solution. + + Args: + completions: List of completions to evaluate + solution: The solution to check for + **kwargs: Additional context + + Returns: + List of rewards (reward_value for correct, 0.0 otherwise) + """ + if solution is None: + logger.warning("No solution provided for accuracy reward") + return [0.0] * len(completions) + + parsed_responses = [] + for completion in completions: + try: + content = self.get_content(completion) + parsed_responses.append(parse_reasoning_response(content)) + except Exception as e: + logger.error(f"Error parsing response: {e}") + logger.exception(e) + parsed_responses.append( + {"thinking_content": "", "response": "", "multiple_thinking": False} + ) + + rewards = [] + + # Ensure solution is in the right format + if not isinstance(solution, list): + solution = [solution] * len(parsed_responses) + + for resp, sol in zip(parsed_responses, solution): + try: + # Extract solution content if needed + sol_content = self.get_content(sol) if not isinstance(sol, str) else sol + resp_content = resp["response"] + + # Do the matching based on settings + if not self.case_sensitive: + sol_content = sol_content.lower() + resp_content = resp_content.lower() + + if self.exact_match: + match = resp_content == sol_content + else: + match = sol_content in resp_content + + rewards.append(self.reward_value if match else 0.0) + except Exception as e: + logger.error(f"Error in accuracy reward calculation: {e}") + logger.exception(e) + rewards.append(0.0) + + return rewards + + +@registry.register +class R1Reward(RewardFunction): + """ + Combined reward function that rewards both reasoning format and accuracy. + + This reward function combines: + 1. FormatReasoningReward - rewards for proper and response formatting + 2. AccuracyXReward - rewards for having the solution in the response + """ + + def __init__( + self, + format_weight: float = 0.5, + accuracy_weight: float = 1.0, + weight: float = 1.0, + **kwargs, + ): + """ + Initialize the R1 reward function. + + Args: + format_weight: Weight for the format component + accuracy_weight: Weight for the accuracy component + weight: Weight for the overall reward + **kwargs: Additional configuration + """ + super().__init__(weight=weight, **kwargs) + self.format_weight = format_weight + self.accuracy_weight = accuracy_weight + + # Create component reward functions + self.format_reward_fn = FormatReasoningReward( + reward_value=1.0, # Use 1.0 here, we'll apply weight in compute + weight=1.0, # This will be overridden + ) + + self.accuracy_reward_fn = AccuracyXReward( + reward_value=1.0, # Use 1.0 here, we'll apply weight in compute + weight=1.0, # This will be overridden + ) + + def compute( + self, + completions: List[Any], + solution: Optional[Union[str, List[str]]] = None, + **kwargs, + ) -> List[float]: + """ + Calculate combined format and accuracy rewards. + + Args: + completions: List of completions to evaluate + solution: The solution to check for accuracy + **kwargs: Additional context + + Returns: + List of combined rewards + """ + try: + # Calculate component rewards + format_rewards = self.format_reward_fn.compute(completions, **kwargs) + accuracy_rewards = self.accuracy_reward_fn.compute( + completions, solution=solution, **kwargs + ) + + # Apply component weights and combine + rewards = [ + (f * self.format_weight) + (a * self.accuracy_weight) + for f, a in zip(format_rewards, accuracy_rewards) + ] + + logger.info( + f"R1 rewards: accuracy={accuracy_rewards}, format={format_rewards}, combined={rewards}" + ) + return rewards + except Exception as e: + logger.error(f"Error in r1_reward: {e}") + logger.exception(e) + # Return zero rewards for all completions + return [0.0] * len(completions) + + +# Legacy function for backward compatibility +def format_reasoning_reward(completions: List[Any], **kwargs) -> List[float]: + """ + Legacy function wrapper for FormatReasoningReward. + + Args: + completions: List of completions to evaluate + **kwargs: Additional parameters + + Returns: + List of rewards for format quality + """ + reward_fn = FormatReasoningReward(reward_value=0.5) + return reward_fn.compute(completions, **kwargs) + + +def accuracy_reward( + completions: List[Any], solution: Union[str, List[str]], **kwargs +) -> List[float]: + """ + Legacy function wrapper for AccuracyXReward. + + Args: + completions: List of completions to evaluate + solution: The solution to check for + **kwargs: Additional parameters + + Returns: + List of rewards for accuracy + """ + reward_fn = AccuracyXReward() + return reward_fn.compute(completions, solution=solution, **kwargs) + + +def r1_reward( + completions: List[Any], solution: Union[str, List[str]], **kwargs +) -> List[float]: + """ + Legacy function wrapper for R1Reward. + + Args: + completions: List of completions to evaluate + solution: The solution to check for + **kwargs: Additional parameters + + Returns: + List of combined rewards + """ + reward_fn = R1Reward() + return reward_fn.compute(completions, solution=solution, **kwargs) diff --git a/atroposlib/envs/reward_fns/reasoning_steps_reward.py b/atroposlib/envs/reward_fns/reasoning_steps_reward.py new file mode 100644 index 00000000..529e2ae0 --- /dev/null +++ b/atroposlib/envs/reward_fns/reasoning_steps_reward.py @@ -0,0 +1,138 @@ +"""Reward function for evaluating step-by-step reasoning in completions.""" + +import logging +import re +from typing import Any, Dict, List, Optional + +from .registry import registry +from .reward_function import RewardFunction + +logger = logging.getLogger(__name__) + + +@registry.register +class ReasoningStepsReward(RewardFunction): + r""" + Reward function that evaluates step-by-step reasoning in completions. + + Looks for several types of step-by-step reasoning indicators: + 1. Numbered step patterns like "Step 1:", "Step 2:" + 2. Numbered lists like "1.", "2." at start of line + 3. Bullet points with hyphens or asterisks + 4. Sequential transition words (First, Second, Next, Finally, etc.) + """ + + def __init__( + self, + min_words: int = 10, + min_steps: int = 3, + base_score: float = 0.1, + pattern_weights: Optional[Dict[str, float]] = None, + weight: float = 1.0, + **kwargs, + ): + """ + Initialize the reasoning steps reward function. + + Args: + min_words: Minimum number of words to consider for base score + min_steps: Number of steps needed for full points in each category + base_score: Base score for having content longer than min_words + pattern_weights: Custom weights for each pattern type (optional) + weight: Weight for this reward + **kwargs: Additional configuration + """ + super().__init__(weight=weight, **kwargs) + self.min_words = min_words + self.min_steps = min_steps + self.base_score = base_score + + # Default pattern weights + self.pattern_weights = { + "numbered_steps": 0.5, # Strong indicators + "list_numbers": 0.5, # Strong indicators + "bullet_points": 0.4, # Medium indicators + "transition_words": 0.3, # Weaker indicators + } + + # Override with custom weights if provided + if pattern_weights: + self.pattern_weights.update(pattern_weights) + + # Patterns for different types of step indicators + self.patterns = { + # Step 1: style numbered steps + "numbered_steps": r"Step\s+\d+[\s:]+", + # Numbered lists (1., 2., etc.) + "list_numbers": r"(?:^|\n)\s*\d+\.\s+", + # Bullet points + "bullet_points": r"(?:^|\n)\s*[\-\*•]\s+", + # Sequential transition words - expanded to include more phrases + "transition_words": r"\b(?:First|Second|Third|Fourth|Fifth|Next|Then|Finally|" + r"Subsequently|Afterward|Lastly|Initially|To begin|Let\'s begin|" + r"I\'ll first|After that|In conclusion|Eventually|Subsequently|" + r"To solve|begin by|understand|analyze|apply|compute)\b", + } + + def compute(self, completions: List[Any], **kwargs) -> List[float]: + """ + Calculate reasoning quality scores based on pattern matching. + + Args: + completions: List of completions to evaluate + **kwargs: Additional context + + Returns: + List of reward scores between 0.0 and 1.0 + """ + # Extract content from different possible formats + completion_contents = [ + self.get_content(completion) for completion in completions + ] + + rewards = [] + for content in completion_contents: + score = 0.0 + pattern_matches = {} + + # Check for each type of pattern + for pattern_type, pattern in self.patterns.items(): + matches = re.findall(pattern, content, re.IGNORECASE | re.MULTILINE) + pattern_matches[pattern_type] = len(matches) + + # Add score based on matches and pattern weight + weight = self.pattern_weights.get( + pattern_type, 0.3 + ) # Default weight if not specified + score += min(1.0, len(matches) / self.min_steps) * weight + + # Add a small base score for any content that has more than just an answer + # This helps differentiate minimal reasoning from no reasoning + if len(content.split()) > self.min_words: + score += self.base_score + + # Cap the total score at 1.0 + score = min(1.0, score) + rewards.append(score) + + logger.info( + f"Reasoning steps reward for completion: {pattern_matches}, score: {score}" + ) + + return rewards + + +# Legacy function for backward compatibility +def reasoning_steps_reward(completions: List[Any], **kwargs) -> List[float]: + """ + Legacy function wrapper for ReasoningStepsReward. + + Args: + completions: List of completions to evaluate + **kwargs: Additional parameters + + Returns: + List of reward scores between 0.0 and 1.0 + """ + reward_fn = ReasoningStepsReward() + return reward_fn.compute(completions, **kwargs) diff --git a/atroposlib/envs/reward_fns/registry.py b/atroposlib/envs/reward_fns/registry.py new file mode 100644 index 00000000..59b9075b --- /dev/null +++ b/atroposlib/envs/reward_fns/registry.py @@ -0,0 +1,279 @@ +import importlib +import inspect +import logging +from pathlib import Path +from typing import Any, Callable, Dict, List, Set, Type, Union + +from .reward_function import RewardFunction + +logger = logging.getLogger(__name__) + + +class RewardRegistry: + """Registry for reward functions with factory pattern""" + + def __init__(self): + self._registry: Dict[str, Type[RewardFunction]] = {} + self._reward_fns_dir = Path(__file__).parent + + def register(self, cls=None, name=None): + """ + Register a reward function class. + + Can be used as a decorator: + + @registry.register + class MyReward(RewardFunction): + ... + + or with a custom name: + + @registry.register(name="custom_name") + class MyReward(RewardFunction): + ... + + Args: + cls: The reward function class to register + name: Optional custom name to register the class under + + Returns: + The registered class (for decorator use) + """ + + def _register(cls): + # Validate that it's a subclass of RewardFunction + if not inspect.isclass(cls) or not issubclass(cls, RewardFunction): + raise TypeError( + f"Class {cls.__name__} is not a subclass of RewardFunction" + ) + + registered_name = name or cls.__name__.lower() + if registered_name.endswith("reward"): + # Convert ClassNameReward to class_name + registered_name = registered_name[:-6].lower() + + self._registry[registered_name] = cls + logger.debug(f"Registered reward function: {registered_name}") + return cls + + if cls is None: + return _register + return _register(cls) + + def register_function(self, name, function): + """ + Register a legacy function-based reward function. + This is temporary for backward compatibility. + + Args: + name: Name to register the function under + function: The reward function to register + """ + self._registry[name] = function + + def create(self, name_or_config: Union[str, Dict], **kwargs) -> RewardFunction: + """ + Create a reward function from name or config dict. + + Args: + name_or_config: Either a string name of a registered reward function, + or a dict with 'type' key and optional parameters + **kwargs: Default parameters that can be overridden by config + + Returns: + Instantiated RewardFunction object + """ + if isinstance(name_or_config, str): + # Simple case: just a name + reward_type = name_or_config + reward_params = kwargs + else: + # Dict case with config + reward_config = name_or_config.copy() + reward_type = reward_config.pop("type") + + # Handle params dictionary if present + if "params" in reward_config: + params = reward_config.pop("params") + reward_config.update(params) + + # Start with kwargs as defaults, override with config + reward_params = {**kwargs} + reward_params.update(reward_config) + + # Make sure the reward function is loaded + if reward_type not in self._registry: + self._load_reward_function(reward_type) + + reward_class = self._registry[reward_type] + + # Handle legacy function-based reward functions + if not inspect.isclass(reward_class): + # This is a function not a class - handle legacy case + return LegacyFunctionWrapper(reward_class, **reward_params) + + # Create instance of the reward function class + return reward_class(**reward_params) + + def get(self, name: str) -> Union[Type[RewardFunction], Callable]: + """ + Get a reward function class by name. + + This is for backward compatibility with the old registry interface. + New code should use create() instead. + + Args: + name: The name of the reward function to get + + Returns: + The reward function class or function + """ + if name not in self._registry: + self._load_reward_function(name) + return self._registry[name] + + def _load_reward_function(self, name: str) -> None: + """ + Load a reward function from a file. + + This supports both new class-based and legacy function-based reward functions. + Files can be named either "name.py" or "name_reward.py". + + Args: + name: The name of the reward function to load + + Raises: + ImportError: If the reward function file is not found or can't be loaded + """ + try: + # Try different file name patterns + base_name = name + if name.endswith("_reward"): + base_name = name[:-7] # Remove "_reward" suffix + + module_paths = [ + self._reward_fns_dir / f"{base_name}.py", + self._reward_fns_dir / f"{base_name}_reward.py", + self._reward_fns_dir / f"{name}.py", + ] + + module_path = None + for path in module_paths: + if path.exists(): + module_path = path + break + + if module_path is None: + raise ImportError( + f"No reward function file found for {name} (tried {', '.join(str(p) for p in module_paths)})" + ) + + # Generate a unique module name to avoid import conflicts + module_name = f"atroposlib.envs.reward_fns.{module_path.stem}" + spec = importlib.util.spec_from_file_location(module_name, str(module_path)) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # First try to find a class that inherits from RewardFunction + for obj_name, obj in inspect.getmembers(module): + if ( + inspect.isclass(obj) + and issubclass(obj, RewardFunction) + and obj is not RewardFunction + ): + # Register the class with the requested name + # This ensures it's accessible by the name the test expects + self.register_function(name, obj) + return + + # If no class found, look for functions with matching name patterns + func_patterns = [f"{base_name}", f"{base_name}_reward", "format_reward"] + for func_name in func_patterns: + if hasattr(module, func_name): + self.register_function(name, getattr(module, func_name)) + return + + raise AttributeError(f"No reward function found in {module_path}") + + except Exception as e: + raise ImportError(f"Failed to load reward function {name}: {str(e)}") + + def list_registered(self) -> List[str]: + """Return list of all registered reward function names""" + return list(self._registry.keys()) + + def load_required_functions(self, config) -> Set[str]: + """ + Load all reward functions required by a config. + + This is for backward compatibility with the old registry interface. + + Args: + config: The config object to load reward functions from + + Returns: + A set of all loaded reward function names + """ + required_funcs = set() + + if hasattr(config, "datasets"): + for dataset in config.datasets: + for field in ["dataset_reward_funcs", "reward_funcs"]: + if hasattr(dataset, field) and getattr(dataset, field): + required_funcs.update(getattr(dataset, field)) + + if hasattr(dataset, "types") and dataset.types: + for type_config in dataset.types: + if "reward_funcs" in type_config: + required_funcs.update(type_config["reward_funcs"]) + + # Also check for reward_functions and reward_funcs at the top level + for field in ["reward_functions", "reward_funcs"]: + if hasattr(config, field) and getattr(config, field): + field_value = getattr(config, field) + for item in field_value: + if isinstance(item, str): + required_funcs.add(item) + elif isinstance(item, dict) and "type" in item: + required_funcs.add(item["type"]) + + for func_name in required_funcs: + self.get(func_name) + + return required_funcs + + +class LegacyFunctionWrapper(RewardFunction): + """Wrapper for legacy function-based reward functions to fit the new class-based interface""" + + def __init__(self, func: Callable, weight: float = 1.0, **kwargs): + """ + Initialize with a legacy reward function. + + Args: + func: The legacy reward function to wrap + weight: The weight for this reward function + **kwargs: Additional configuration parameters + """ + super().__init__(weight=weight, **kwargs) + self.func = func + self._func_name = func.__name__ + + @property + def name(self) -> str: + """Get the name of the wrapped function""" + return self._func_name + + def compute(self, completions: List[Any], **kwargs) -> List[float]: + """Call the wrapped function""" + result = self.func(completions, **kwargs) + + # Convert to list if it's a single value + if not isinstance(result, list): + result = [result] * len(completions) + + return result + + +# Global registry instance +registry = RewardRegistry() diff --git a/atroposlib/envs/reward_fns/repetition_penalty_reward.py b/atroposlib/envs/reward_fns/repetition_penalty_reward.py new file mode 100644 index 00000000..40298243 --- /dev/null +++ b/atroposlib/envs/reward_fns/repetition_penalty_reward.py @@ -0,0 +1,286 @@ +"""Reward function for penalizing repetitive content in completions.""" + +import logging +import re +from collections import Counter +from typing import Any, Dict, List, Optional, Set + +from .registry import registry +from .reward_function import RewardFunction + +logger = logging.getLogger(__name__) + + +@registry.register +class RepetitionPenaltyReward(RewardFunction): + """ + Reward function that penalizes repetitive content in completions. + + Analyzes various types of repetition: + 1. Repeated sentences/paragraphs + 2. Repeated words beyond natural frequency + 3. Repeated phrases (n-grams) + 4. Consecutive word repetition ("stuttering") + 5. Repeated sentence beginnings + """ + + def __init__( + self, + threshold: float = 0.05, + min_words: int = 10, + min_sentences: int = 2, + short_text_penalty: float = -0.1, + paragraph_repetition_base_penalty: float = -0.6, + component_weights: Optional[Dict[str, float]] = None, + stopwords: Optional[Set[str]] = None, + weight: float = 1.0, + **kwargs, + ): + """ + Initialize repetition penalty reward function. + + Args: + threshold: Maximum acceptable repetition rate + min_words: Minimum words required for full analysis + min_sentences: Minimum sentences required for full analysis + short_text_penalty: Penalty to apply for very short texts + paragraph_repetition_base_penalty: Base penalty for repeated paragraphs + component_weights: Custom weights for each repetition component + stopwords: Set of common words to ignore in word repetition check + weight: Weight for this reward + **kwargs: Additional configuration + """ + super().__init__(weight=weight, **kwargs) + self.threshold = threshold + self.min_words = min_words + self.min_sentences = min_sentences + self.short_text_penalty = short_text_penalty + self.paragraph_repetition_base_penalty = paragraph_repetition_base_penalty + + # Default component weights + self.component_weights = { + "word_repetition": 0.3, + "phrase_repetition": 0.4, + "consecutive_repetition": 0.6, + "beginning_repetition": 0.5, + } + + # Override with custom weights if provided + if component_weights: + self.component_weights.update(component_weights) + + # Default stopwords (common words to ignore) + self.stopwords = stopwords or { + "this", + "that", + "with", + "from", + "what", + "when", + "where", + "which", + "who", + "whom", + "whose", + "will", + "shall", + "should", + "would", + "could", + "have", + "has", + "had", + "been", + "being", + "than", + "then", + "there", + "these", + "those", + "their", + } + + def compute(self, completions: List[Any], **kwargs) -> List[float]: + """ + Calculate penalties for repetitive content. + + Args: + completions: List of completions to evaluate + **kwargs: Additional context + + Returns: + List of penalty scores between -1.0 and 0.0 (0.0 = no repetition) + """ + # Extract content from different possible formats + completion_contents = [ + self.get_content(completion) for completion in completions + ] + + rewards = [] + for content in completion_contents: + try: + # Split text into words and sentences + words = re.findall(r"\b\w+\b", content.lower()) + sentences = re.split(r"[.!?]", content) + sentences = [s.strip() for s in sentences if s.strip()] + + # For very short responses, apply a small penalty to be safe + if len(words) < self.min_words or len(sentences) < self.min_sentences: + logger.info("Text too short for detailed repetition analysis") + rewards.append(self.short_text_penalty) + continue + + # Check for identical sentences (paragraph repetition) + sentence_counts = Counter(sentences) + repeated_sentences = sum( + count - 1 + for sentence, count in sentence_counts.items() + if count > 1 + and len(sentence.split()) > 3 # Only count substantial sentences + ) + + # Severe penalty for repeated paragraphs or sentences + if repeated_sentences > 0: + paragraph_repetition_score = min( + 1.0, repeated_sentences / len(sentences) + ) + # Apply a strong penalty for paragraph repetition + penalty = self.paragraph_repetition_base_penalty - ( + paragraph_repetition_score * 0.4 + ) + logger.info( + f"Paragraph repetition detected: {repeated_sentences} instances" + ) + rewards.append(penalty) + continue + + # Check for word repetition + word_counts = Counter(words) + content_words = [ + w + for w in word_counts.keys() + if len(w) > 3 and w not in self.stopwords + ] + repeated_words = sum( + count - 1 # Count repetitions beyond the first occurrence + for word, count in word_counts.items() + if count > 2 + and word in content_words # Only count meaningful words + ) + word_repetition_rate = repeated_words / len(words) if words else 0 + + # Check for phrase repetition (n-grams) + phrase_repetition = 0 + if len(words) >= 5: + for n in range(3, 6): # Check 3, 4, and 5-grams + if len(words) >= n: + ngrams = [ + " ".join(words[i : i + n]) + for i in range(len(words) - n + 1) + ] + ngram_counts = Counter(ngrams) + repeated_ngrams = sum( + count - 1 + for phrase, count in ngram_counts.items() + if count > 1 + ) + phrase_repetition += repeated_ngrams * ( + n / 3 + ) # Weight by n-gram size + + phrase_repetition_rate = phrase_repetition / len(words) if words else 0 + + # Check for consecutive word repetition (stuttering) + consecutive_repeats = 0 + for i in range(1, len(words)): + if ( + words[i] == words[i - 1] and len(words[i]) > 2 + ): # Ignore short words + consecutive_repeats += 1 + consecutive_repetition_rate = ( + consecutive_repeats / len(words) if words else 0 + ) + + # Check for beginning of sentence repetition + sentence_beginnings = [] + for sentence in sentences: + words = sentence.split() + if len(words) >= 3: + sentence_beginnings.append(" ".join(words[:3])) + + beginning_counts = Counter(sentence_beginnings) + repeated_beginnings = sum( + count - 1 + for beginning, count in beginning_counts.items() + if count > 1 + ) + beginning_repetition_rate = ( + repeated_beginnings / len(sentences) if sentences else 0 + ) + + # Calculate overall repetition score with weights + repetition_score = ( + ( + word_repetition_rate + * self.component_weights.get("word_repetition", 0.3) + ) + + ( + phrase_repetition_rate + * self.component_weights.get("phrase_repetition", 0.4) + ) + + ( + consecutive_repetition_rate + * self.component_weights.get("consecutive_repetition", 0.6) + ) + + ( + beginning_repetition_rate + * self.component_weights.get("beginning_repetition", 0.5) + ) + ) + + # Calculate penalty: 0.0 for no repetition, up to -1.0 for high repetition + if repetition_score <= self.threshold: + penalty = 0.0 + else: + # Scale penalty from 0 to -1 based on how much repetition exceeds threshold + penalty = -min( + 1.0, (repetition_score - self.threshold) / (1 - self.threshold) + ) + + # Make sure we have at least some penalty for any repetition + penalty = min(penalty, -0.1) + + logger.info( + f"Word rep: {word_repetition_rate:.3f}, Phrase rep: {phrase_repetition_rate:.3f}, " + f"Consecutive rep: {consecutive_repetition_rate:.3f}, ", + f"Sentence rep: {beginning_repetition_rate:.3f}, " + f"Overall score: {repetition_score:.3f}, penalty: {penalty:.3f}", + ) + rewards.append(penalty) + + except Exception as e: + logger.error(f"Error in repetition penalty calculation: {e}") + logger.exception(e) + rewards.append(-0.2) # Apply small penalty on error + + return rewards + + +# Legacy function for backward compatibility +def repetition_penalty_reward( + completions: List[Any], threshold: float = 0.05, **kwargs +) -> List[float]: + """ + Legacy function wrapper for RepetitionPenaltyReward. + + Args: + completions: List of completions to evaluate + threshold: Maximum acceptable repetition rate (default 0.05) + Lower threshold means stricter penalties for repetition + **kwargs: Additional parameters + + Returns: + List of rewards between -1.0 and 0.0, where 0.0 means no repetition + """ + reward_fn = RepetitionPenaltyReward(threshold=threshold) + return reward_fn.compute(completions, **kwargs) diff --git a/atroposlib/envs/reward_fns/reward_function.py b/atroposlib/envs/reward_fns/reward_function.py new file mode 100644 index 00000000..223502e0 --- /dev/null +++ b/atroposlib/envs/reward_fns/reward_function.py @@ -0,0 +1,125 @@ +import logging +from abc import ABC, abstractmethod +from typing import Any, List, Optional + +logger = logging.getLogger(__name__) + + +class RewardFunction(ABC): + """Abstract base class for all reward functions""" + + def __init__(self, weight: float = 1.0, name: Optional[str] = None, **kwargs): + """ + Initialize reward function with a weight and optional configuration. + + Args: + weight: Importance factor when combining with other rewards + name: Optional custom name for this reward function instance + **kwargs: Additional configuration parameters specific to the reward function + """ + self.weight = weight + self._name = name + self.config = kwargs + self.wandb_logger = None + + @property + def name(self) -> str: + """Unique identifier for this reward function""" + return self._name or self.__class__.__name__.lower() + + @abstractmethod + def compute(self, completions: List[Any], **kwargs) -> List[float]: + """ + Compute reward scores for the given completions. + + Args: + completions: List of completions to evaluate + **kwargs: Additional context like solution, ground_truth, etc. + + Returns: + List of reward scores, one for each completion + """ + pass + + def __call__(self, completions: List[Any], **kwargs) -> List[float]: + """Wrapper that applies weight to the computed rewards""" + try: + rewards = self.compute(completions, **kwargs) + # Apply weight + weighted_rewards = [r * self.weight for r in rewards] + + # Log to wandb if available + if self.wandb_logger: + self.log_metrics(rewards, weighted_rewards) + + return weighted_rewards + except Exception as e: + logger.error(f"Error in reward function {self.name}: {e}") + logger.exception(e) + return [0.0] * len(completions) + + def set_wandb_logger(self, logger): + """Set the WandB logger for this reward function""" + self.wandb_logger = logger + + def log_metrics(self, raw_rewards: List[float], weighted_rewards: List[float]): + """Log reward metrics to WandB""" + if not self.wandb_logger or not raw_rewards: + return + + metrics = { + f"reward/{self.name}/mean_raw": sum(raw_rewards) / len(raw_rewards), + f"reward/{self.name}/mean_weighted": sum(weighted_rewards) + / len(weighted_rewards), + f"reward/{self.name}/min": min(raw_rewards), + f"reward/{self.name}/max": max(raw_rewards), + } + + self.wandb_logger.log(metrics) + + @staticmethod + def get_content(completion: Any) -> str: + """ + Extract content from different completion formats. + + Supports: + - String completions + - Dict with {"role": "assistant", "content": "text"} + - Dict with {"message": {"role": "assistant", "content": "text"}} + - List of messages where one has role "assistant" + + Args: + completion: The completion in any supported format + + Returns: + The extracted content as a string + """ + if isinstance(completion, str): + return completion + elif isinstance(completion, dict): + if ( + "role" in completion + and completion["role"] == "assistant" + and "content" in completion + ): + return completion["content"] + if "message" in completion and isinstance(completion["message"], dict): + if ( + "role" in completion["message"] + and completion["message"]["role"] == "assistant" + and "content" in completion["message"] + ): + return completion["message"]["content"] + elif isinstance(completion, list) and len(completion) > 0: + # Look for assistant messages + for msg in completion: + if ( + isinstance(msg, dict) + and "role" in msg + and msg["role"] == "assistant" + and "content" in msg + ): + return msg["content"] + + # If no assistant content found, return empty string + return "" diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py new file mode 100644 index 00000000..4a3562cb --- /dev/null +++ b/atroposlib/envs/server_handling/openai_server.py @@ -0,0 +1,296 @@ +import asyncio +import collections +import time +from asyncio import exceptions +from typing import Optional + +import aiohttp +import numpy as np +import openai +from openai.types.chat.chat_completion import ChatCompletion +from openai.types.completion import Completion +from pydantic import BaseModel, Field +from tenacity import retry, stop_after_attempt, wait_random_exponential + + +class OpenaiConfig(BaseModel): + """ + Configuration for the server manager. + """ + + api_key: Optional[str] = Field( + default=None, description="API key for OpenAI API. Use 'x' for local servers." + ) + base_url: Optional[str] = Field( + default=None, + description="URL of the API endpoint. None if using official OpenAI API, otherwise local server URL.", + ) + timeout: int = Field( + default=1200, description="Timeout for the request in seconds." + ) + num_max_requests_at_once: int = Field( + default=512, + description="Maximum number of concurrent requests. Note: You should divide this by the n kwarg.", + ) + num_requests_for_eval: int = Field( + default=64, description="Maximum number of concurrent requests for evaluation." + ) + model_name: str = Field( + default="default", + description="The model name to use. Required for both OpenAI and local models.", + ) + rolling_buffer_length: int = Field( + default=1000, description="Length of the rolling buffer to store metrics." + ) + + +class AsyncSemWithAdaptiveWeight(asyncio.Semaphore): + def __init__(self, value: int): + super().__init__(value=value) + self.max_val = value + self.weight = 1.0 + + def update_weight(self, weight: float) -> None: + self.weight = weight + + def min_val(self): + return self.max_val * (1.0 - self.weight) + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + + If weight is set, it'll only wake up next if the value is greater than the max_val * weight + """ + self._value += 1 + if self._value > self.min_val(): + self._wake_up_next() + + def locked(self): + """Returns True if semaphore cannot be acquired immediately.""" + return self._value <= self.min_val() or ( + any(not w.cancelled() for w in (self._waiters or ())) + ) + + async def acquire(self): + """Acquire a semaphore. + + If the internal counter is larger than zero on entry, + decrement it by one and return True immediately. If it is + zero on entry, block, waiting until some other coroutine has + called release() to make it larger than 0, and then return + True. + """ + if not self.locked(): + self._value -= 1 + return True + + if self._waiters is None: + self._waiters = collections.deque() + fut = self._get_loop().create_future() + self._waiters.append(fut) + + # Finally block should be called before the CancelledError + # handling as we don't want CancelledError to call + # _wake_up_first() and attempt to wake up itself. + try: + try: + await fut + finally: + self._waiters.remove(fut) + except exceptions.CancelledError: + if not fut.cancelled(): + self._value += 1 + self._wake_up_next() + raise + + if self._value > self.min_val(): + self._wake_up_next() + return True + + +class OpenAIServer: + def __init__(self, config: OpenaiConfig): + self.config = config + self.openai = openai.AsyncClient( + api_key=config.api_key, + base_url=config.base_url, + timeout=config.timeout, + ) + self.sem = AsyncSemWithAdaptiveWeight(config.num_max_requests_at_once) + self.eval_sem = AsyncSemWithAdaptiveWeight(config.num_requests_for_eval) + self.server_healthy = True + self.attempts_list = [] + self.request_timings = [] + # in case eval is much different, we should keep different buffers + self.eval_attempts_list = [] + self.eval_request_timings = [] + self.check_task = None + self.initialized = False + + async def update_weight(self, weight: float) -> None: + # need to update sems + self.sem.update_weight(weight) + self.eval_sem.update_weight(weight) + + async def check_server_status_task(self): + while True: + try: + await self.openai.completions.create( + model=self.config.model_name, + prompt="hi", + max_tokens=1, + ) + self.server_healthy = True + except ( + aiohttp.ClientError, + openai.OpenAIError, + openai.APITimeoutError, + Exception, + ): + self.server_healthy = False + await asyncio.sleep(1) + + async def wandb_metrics( + self, metrics_dict: Optional[dict], server_name: Optional[str] + ): + if server_name is None: + server_name = "server" + if len(self.request_timings) > 0: + metrics_dict[f"server/{server_name}_request_time_avg"] = np.mean( + self.request_timings + ) + metrics_dict[f"server/{server_name}_request_time_std"] = np.std( + self.request_timings + ) + metrics_dict[f"server/{server_name}_request_time_99p"] = np.percentile( + self.request_timings, 99 + ) + if len(self.eval_request_timings) > 0: + metrics_dict[f"server/{server_name}_eval_request_time_avg"] = np.mean( + self.eval_request_timings + ) + metrics_dict[f"server/{server_name}_eval_request_time_std"] = np.std( + self.eval_request_timings + ) + metrics_dict[f"server/{server_name}_eval_request_time_99p"] = np.percentile( + self.eval_request_timings, 99 + ) + if len(self.attempts_list) > 0: + metrics_dict[f"server/{server_name}_average_num_attempts"] = np.mean( + self.attempts_list + ) + if len(self.eval_attempts_list) > 0: + metrics_dict[f"server/{server_name}_eval_retry_rate"] = np.mean( + self.eval_attempts_list + ) + return metrics_dict + + @retry( + stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) + ) + async def _chat_comp(self, stat_dict, **kwargs) -> ChatCompletion: + while not self.server_healthy: + await asyncio.sleep(1) + async with self.sem: + if stat_dict.get("start", None) is None: + stat_dict["start"] = time.time() + stat_dict["attempts"] += 1 + completions = await self.openai.chat.completions.create(**kwargs) + stat_dict["end"] = time.time() + return completions + + @retry( + stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) + ) + async def _chat_eval(self, stat_dict, **kwargs) -> ChatCompletion: + while not self.server_healthy: + await asyncio.sleep(1) + async with self.eval_sem: + if stat_dict.get("start", None) is None: + stat_dict["start"] = time.time() + stat_dict["attempts"] += 1 + completions = await self.openai.chat.completions.create(**kwargs) + stat_dict["end"] = time.time() + return completions + + @retry( + stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) + ) + async def chat_completion(self, **kwargs) -> ChatCompletion: + if not self.initialized: + if ( + self.config.base_url is not None + ): # skip health check if using OpenAI API + self.check_task = asyncio.create_task(self.check_server_status_task()) + else: + self.server_healthy = True + self.initialized = True + kwargs["model"] = self.config.model_name + split = kwargs.pop("split", "train") + stat_dict = {} + stat_dict["attempts"] = 0 + if split == "train": + ret_data = await self._chat_comp(stat_dict, **kwargs) + self.request_timings.append(stat_dict["end"] - stat_dict["start"]) + self.attempts_list.append(stat_dict["attempts"]) + else: + # Give separate eval workers, if desired, gotta go fast for those evals + ret_data = await self._chat_eval(stat_dict, **kwargs) + self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"]) + self.eval_attempts_list.append(stat_dict["attempts"]) + return ret_data + + @retry( + stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) + ) + async def _comp(self, stat_dict, **kwargs) -> Completion: + while not self.server_healthy: + await asyncio.sleep(1) + async with self.sem: + if stat_dict.get("start", None) is None: + stat_dict["start"] = time.time() + stat_dict["attempts"] += 1 + completions = await self.openai.completions.create(**kwargs) + stat_dict["end"] = time.time() + return completions + + @retry( + stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) + ) + async def _comp_eval(self, stat_dict, **kwargs) -> Completion: + while not self.server_healthy: + await asyncio.sleep(1) + async with self.eval_sem: + if stat_dict.get("start", None) is None: + stat_dict["start"] = time.time() + stat_dict["attempts"] += 1 + completions = await self.openai.completions.create(**kwargs) + stat_dict["end"] = time.time() + return completions + + async def completion(self, **kwargs) -> Completion: + if not self.initialized: + if ( + self.config.base_url is not None + ): # skip health check if using OpenAI API + self.check_task = asyncio.create_task(self.check_server_status_task()) + else: + self.server_healthy = True + self.initialized = True + kwargs["model"] = self.config.model_name + split = kwargs.pop("split", "train") + stat_dict = {} + stat_dict["attempts"] = 0 + if split == "train": + ret_data = await self._comp(stat_dict, **kwargs) + self.request_timings.append(stat_dict["end"] - stat_dict["start"]) + self.attempts_list.append(stat_dict["attempts"]) + else: + # Give separate eval workers, if desired, gotta go fast for those evals + ret_data = await self._comp_eval(stat_dict, **kwargs) + self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"]) + self.eval_attempts_list.append(stat_dict["attempts"]) + return ret_data diff --git a/atroposlib/envs/server_handling/server_harness.py b/atroposlib/envs/server_handling/server_harness.py new file mode 100644 index 00000000..b0807d98 --- /dev/null +++ b/atroposlib/envs/server_handling/server_harness.py @@ -0,0 +1,146 @@ +import asyncio +from typing import Dict, List, Literal, Union + +from openai.types.chat.chat_completion import ( + ChatCompletion, + ChatCompletionMessage, + Choice, +) +from openai.types.completion import Completion, CompletionChoice + + +def create_chat_completion( + resp: Union[str, List[str]], + n: int = 1, + finish_reason: Union[ + Literal["stop", "length", "tool_calls", "content_filter", "function_call"], + List[ + Literal["stop", "length", "tool_calls", "content_filter", "function_call"] + ], + ] = "stop", +) -> ChatCompletion: + """ + Simple helper for creating a ChatCompletion object, if you need it + :param resp: + :param n: + :param finish_reason: + :return: + """ + choices = [ + Choice( + finish_reason=( + finish_reason if isinstance(finish_reason, str) else finish_reason[i] + ), + index=i, + message=ChatCompletionMessage( + content=resp if isinstance(resp, str) else resp[i], + role="assistant", + ), + ) + for i in range(n) + ] + return ChatCompletion( + id="test_id", + created=0, + model="test_model", + object="chat.completion", + choices=choices, + ) + + +def create_completion( + resp: Union[str, List[str]], + n: int = 1, + finish_reason: Union[ + Literal["stop", "length", "tool_calls", "content_filter", "function_call"], + List[ + Literal["stop", "length", "tool_calls", "content_filter", "function_call"] + ], + ] = "stop", +) -> Completion: + """ + Simple helper for creating a Completion object, if you need it + :param resp: + :param n: + :param finish_reason: + :return: + """ + choices = [ + CompletionChoice( + finish_reason=( + finish_reason if isinstance(finish_reason, str) else finish_reason[i] + ), + index=i, + text=resp if isinstance(resp, str) else resp[i], + ) + for i in range(n) + ] + return Completion( + id="test_id", + created=0, + model="test_model", + object="text_completion", + choices=choices, + ) + + +class ServerHarness: + def __init__(self): + self.response_map = dict() + self.sem = asyncio.Semaphore(1) + self.eval_sem = asyncio.Semaphore(1) + pass + + def conv_to_dictkey(self, input_message: List[Dict[str, str]]) -> str: + dictkey = list() + for item in input_message: + dictkey.append(f"role:{item['role']}") + dictkey.append(f"content:{item['content']}") + return "\n".join(dictkey) + + async def update_weight(self, weight): + pass + + def set_desired_response( + self, input_message: List[Dict[str, str]], desired_response: ChatCompletion + ): + dictkey = self.conv_to_dictkey(input_message) + self.response_map[dictkey] = desired_response + + def set_desired_completion(self, input_message: str, completion: Completion): + self.response_map[input_message] = completion + + async def chat_completion(self, *args, **kwargs) -> ChatCompletion: + messages = kwargs.get("messages") + dictkey = self.conv_to_dictkey(messages) + try: + return self.response_map.get(dictkey) + except KeyError as e: + raise KeyError(f"KeyError: {e} for key:\n{dictkey}") + + async def completion(self, *args, **kwargs) -> Completion: + prompt = kwargs.get("prompt") + try: + return self.response_map.get(prompt) + except KeyError as e: + raise KeyError(f"KeyError: {e} for key:\n{prompt}") + + +if __name__ == "__main__": + + async def main(): + test_compl = create_chat_completion("hello") + harness = ServerHarness() + harness.set_desired_response([{"role": "user", "content": "hi"}], test_compl) + print(harness.response_map) + print(harness.conv_to_dictkey([{"role": "user", "content": "hi"}])) + print( + await harness.chat_completion(messages=[{"role": "user", "content": "hi"}]) + ) + # now, let's test the completion + test_completion = create_completion("\nhello") + harness.set_desired_completion("hi", test_completion) + print(harness.response_map) + print(await harness.completion(prompt="hi")) + + asyncio.run(main()) diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py new file mode 100644 index 00000000..8b1db618 --- /dev/null +++ b/atroposlib/envs/server_handling/server_manager.py @@ -0,0 +1,208 @@ +import asyncio +import os +from contextlib import asynccontextmanager +from typing import AsyncGenerator, List, Union + +from openai.types.chat.chat_completion import ChatCompletion +from openai.types.completion import Completion +from pydantic import BaseModel, Field + +from atroposlib.envs.server_handling.openai_server import OpenaiConfig, OpenAIServer +from atroposlib.envs.server_handling.server_harness import ServerHarness + + +class ServerManagerConfig(BaseModel): + slurm: bool = Field( + default=True, description="Whether environment is running on slurm or not." + ) + testing: bool = Field( + default=False, description="If set to True, environment uses mock OpenAI data." + ) + + +class ServerBaseline(BaseModel): + """ + Baseline configuration for server information. If local, uses ports 9004-9007 for the servers, + assuming a 1:1 split of GPUs. + """ + + timeout: int = Field( + default=1200, description="Timeout for the request in seconds." + ) + num_max_requests_at_once: int = Field( + default=512, + description="Maximum number of concurrent requests. You should divide this by the n kwarg.", + ) + num_requests_for_eval: int = Field( + default=64, description="Maximum number of concurrent requests for evaluation." + ) + model_name: str = Field( + default="default", + description="The model name to use. Only works with sglang, please provide the model name.", + ) + rolling_buffer_length: int = Field( + default=1000, description="Length of the rolling buffer to store metrics." + ) + + +class ServerManager: + def __init__( + self, + configs: Union[ServerBaseline, List[OpenaiConfig]], + slurm=False, + testing=False, + ): + if testing: + # testing :) + self.servers = [ServerHarness()] + return + if isinstance(configs, ServerBaseline): + urls = [] + if os.environ.get("SLURM_JOB_NODELIST", None) is not None: + nodelist = ( + os.popen( + f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}' + ) + .read() + .split("\n") + ) + nodelist = [node for node in nodelist if node != ""] + if len(nodelist) < 2: + # localhost! + for i in range(4): + urls.append(f"http://localhost:{9000 + i + 4}/v1") + else: + num_training_nodes = int(os.environ.get("NUM_TRAINING_NODES")) + for node in nodelist[num_training_nodes:]: + for i in range(8 // os.environ.get("INFER_TP", 1)): + urls.append(f"http://{node}:{9000 + i}/v1") + openai_configs = [] + else: + # localhost! + for i in range(4): + urls.append(f"http://localhost:{9000 + i + 4}/v1") + openai_configs = [] + for url in urls: + openai_configs.append( + OpenaiConfig( + base_url=url, + timeout=configs.timeout, + num_max_requests_at_once=configs.num_max_requests_at_once, + num_requests_for_eval=configs.num_requests_for_eval, + model_name=configs.model_name, + rolling_buffer_length=configs.rolling_buffer_length, + api_key="x", + ) + ) + self.servers = [OpenAIServer(config) for config in openai_configs] + if not slurm: + self.servers = [OpenAIServer(config) for config in configs] + else: + nodelist = ( + os.popen(f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}') + .read() + .split("\n") + ) + nodelist = [node for node in nodelist if node != ""] + if len(nodelist) < 2: + print( + "Not enough nodes to distribute to, assuming single node" + " and you've setup your sglang appropriately." + ) + self.servers = [OpenAIServer(config) for config in configs] + return + urls = [] + num_training_nodes = int(os.environ.get("NUM_TRAINING_NODES")) + for node in nodelist[num_training_nodes:]: + if node == "": + continue + for i in range(8 // os.environ.get("INFER_TP", 1)): + urls.append(f"http://{node}:{9000 + i}/v1") + # assume at least one good config is passed in + new_configs = [] + for i in range(len(urls)): + new_conf = configs[0].model_copy(deep=True) + new_conf.base_url = urls[i] + new_configs.append(new_conf) + self.servers = [OpenAIServer(config) for config in new_configs] + + async def update_weight(self, weight: float): + for server in self.servers: + await server.update_weight(weight) + + async def wait_for_sem(self, is_training): + """ + Wait for a server to be available. This is used to prevent the client from + overwhelming the server with requests. + """ + if is_training: + eval_vals = [ + ( + max(0, server.eval_sem._value - server.eval_sem.min_val()) + if server.eval_sem._value != server.eval_sem.max_val + else 0 + ) + for server in self.servers + ] + sem_vals = [ + max(0, (server.sem._value - server.sem.min_val()) - eval_val) + for server, eval_val in zip(self.servers, eval_vals) + ] + else: + sem_vals = [ + max(0, server.eval_sem._value - server.eval_sem.min_val()) + for server in self.servers + ] + while all([sem_val <= 0 for sem_val in sem_vals]): + # None available... wait + await asyncio.sleep(1) + + async def chat_completion(self, **kwargs) -> ChatCompletion: + is_train = kwargs.get("split", "train") == "train" + most_available_server = 0 + most_available_server_num_slots = -1 + await self.wait_for_sem(is_train) + for i, server in enumerate(self.servers): + if not server.server_healthy: + continue + if ( + server.sem._value if is_train else server.eval_sem._value + ) > most_available_server_num_slots: + most_available_server = i + most_available_server_num_slots = ( + server.sem._value if is_train else server.eval_sem._value + ) + return await self.servers[most_available_server].chat_completion(**kwargs) + + async def completion(self, **kwargs) -> Completion: + is_train = kwargs.get("split", "train") == "train" + most_available_server = 0 + most_available_server_num_slots = -1 + await self.wait_for_sem(is_train) + for i, server in enumerate(self.servers): + if not server.server_healthy: + continue + if ( + server.sem._value if is_train else server.eval_sem._value + ) > most_available_server_num_slots: + most_available_server = i + most_available_server_num_slots = ( + server.sem._value if is_train else server.eval_sem._value + ) + return await self.servers[most_available_server].completion(**kwargs) + + @asynccontextmanager + async def dedicated_server(self) -> AsyncGenerator[OpenAIServer, None]: + most_available_server = 0 + most_available_server_num_slots = -1 + for i, server in enumerate(self.servers): + if not server.server_healthy: + continue + if server.sem._value > most_available_server_num_slots: + most_available_server = i + most_available_server_num_slots = server.sem._value + async with self.servers[most_available_server].sem: + try: + yield self.servers[most_available_server] + finally: + pass diff --git a/atroposlib/tests/test_advantages.py b/atroposlib/tests/test_advantages.py new file mode 100644 index 00000000..151ebd2b --- /dev/null +++ b/atroposlib/tests/test_advantages.py @@ -0,0 +1,169 @@ +import math + +import pytest +import torch + +# Adjust the import below if your functions are in a different module. +from atroposlib.utils.advantages import ( + allclose_to_first, + compute_discounted_returns, + compute_grpo_process_supervision_advantages, + compute_stats, +) + + +def test_allclose_to_first_all_close(): + """Test that identical values return True.""" + values = [1.0, 1.0, 1.0] + result = allclose_to_first(values) + assert result is True + + +def test_allclose_to_first_vector(): + """Test that return_vector=True returns a tensor of booleans.""" + values = [1.0, 1.000000001, 1.000000002] + result = allclose_to_first(values, return_vector=True) + assert isinstance(result, torch.Tensor) + # All comparisons should be True. + assert torch.all(result) + + +def test_allclose_to_first_not_close(): + """Test that values which are not close yield False.""" + values = [1.0, 1.0, 1.1] + result = allclose_to_first(values) + assert result is False + + +def test_allclose_to_first_nan(): + """Test handling of NaN values with equal_nan parameter.""" + values = [float("nan"), float("nan")] + # With equal_nan False, the result should be False. + result = allclose_to_first(values, equal_nan=False) + assert result is False + # With equal_nan True, NaNs are treated as equal. + result = allclose_to_first(values, equal_nan=True) + assert result is True + + +def test_compute_stats(): + """Test compute_stats with a nested list of numbers.""" + data = [1, 2, 3, [4, 5]] + stats = compute_stats(data) + # mean = (1+2+3+4+5)/5 = 3.0 + assert math.isclose(stats["mean"], 3.0, rel_tol=1e-5) + # variance = (11 - 9) = 2.0, since average of squares = 55/5 = 11 and mean^2 = 9. + assert math.isclose(stats["var"], 2.0, rel_tol=1e-5) + + +def test_compute_stats_empty(): + """Test that an empty list raises a ValueError.""" + with pytest.raises(ValueError): + compute_stats([]) + + +def test_compute_stats_jagged(): + """Test compute_stats with a deeper, jagged nested list.""" + data = [[1, 2], 3, [4, [5, 6]]] + stats = compute_stats(data) + expected_mean = (1 + 2 + 3 + 4 + 5 + 6) / 6 # 21/6 = 3.5 + expected_var = ((1**2 + 2**2 + 3**2 + 4**2 + 5**2 + 6**2) / 6) - expected_mean**2 + assert math.isclose(stats["mean"], expected_mean, rel_tol=1e-5) + assert math.isclose(stats["var"], expected_var, rel_tol=1e-5) + + +def test_compute_discounted_returns(): + """Test compute_discounted_returns with a tensor input.""" + rewards = torch.tensor([1.0, 1.0, 1.0]) + gamma = 0.9 + returns = compute_discounted_returns(rewards, gamma) + # For a 3-element vector: + # t=2: 1.0 + # t=1: 1.0 + 0.9*1.0 = 1.9 + # t=0: 1.0 + 0.9*1.9 = 2.71 + expected = torch.tensor([2.71, 1.9, 1.0]) + assert torch.allclose(returns, expected, rtol=1e-5, atol=1e-8) + + +def test_compute_discounted_returns_list_input(): + """Test compute_discounted_returns when the input is a list.""" + rewards = [1, 1, 1] + gamma = 0.0 # With gamma=0, the returns should equal the rewards. + returns = compute_discounted_returns(rewards, gamma) + expected = torch.tensor([1.0, 1.0, 1.0]) + assert torch.allclose(returns, expected, rtol=1e-5, atol=1e-8) + + +def test_compute_grpo_process_supervision_advantages_cumsum(): + """ + Test compute_grpo_process_supervision_advantages with gamma=None, + which should now compute a reversed cumulative sum on normalized rewards. + For each trajectory, the expected advantage at index i is the sum of normalized rewards from i to the end. + """ + rewards = [[1, 2, 3], [4, 5]] + advantages = compute_grpo_process_supervision_advantages(rewards, gamma=None) + # Compute normalized rewards using flattened stats (mean=3, var=2 so std=sqrt(2)) + sqrt2 = math.sqrt(2) + # For trajectory 1, normalized rewards: + # Reversed cumulative sum: + # index 0: sum(traj1) = (-2/sqrt2) + (-1/sqrt2) + 0 = -3/sqrt2 + # index 1: sum(traj1[1:]) = (-1/sqrt2) + 0 = -1/sqrt2 + # index 2: sum(traj1[2:]) = 0 + expected_traj1 = [-3 / sqrt2, -1 / sqrt2, 0] + # For trajectory 2, normalized rewards: + # Reversed cumulative sum: + # index 0: (1/sqrt2) + (2/sqrt2) = 3/sqrt2 + # index 1: (2/sqrt2) + expected_traj2 = [3 / sqrt2, 2 / sqrt2] + + adv1 = advantages[0].tolist() + adv2 = advantages[1].tolist() + + for computed, expected in zip(adv1, expected_traj1): + assert math.isclose( + computed, expected, rel_tol=1e-5 + ), f"Computed {computed} vs expected {expected} in trajectory 1" + for computed, expected in zip(adv2, expected_traj2): + assert math.isclose( + computed, expected, rel_tol=1e-5 + ), f"Computed {computed} vs expected {expected} in trajectory 2" + + +def test_compute_grpo_process_supervision_advantages_discounted(): + """ + Test compute_grpo_process_supervision_advantages with a provided gamma, + which should compute discounted returns on normalized rewards. + """ + rewards = [[1, 2, 3], [4, 5]] + gamma = 0.9 + advantages = compute_grpo_process_supervision_advantages(rewards, gamma=gamma) + sqrt2 = math.sqrt(2) + # Normalized first trajectory: + a1 = (1 - 3) / sqrt2 # -2/sqrt2 + a2 = (2 - 3) / sqrt2 # -1/sqrt2 + a3 = (3 - 3) / sqrt2 # 0 + # Discounted returns for trajectory 1: + # t=2: a3 + # t=1: a2 + gamma * a3 = a2 + # t=0: a1 + gamma * (a2 + gamma * a3) = a1 + gamma * a2 + expected_traj1 = [a1 + gamma * a2, a2, a3] + # Normalized second trajectory: + b1 = (4 - 3) / sqrt2 # 1/sqrt2 + b2 = (5 - 3) / sqrt2 # 2/sqrt2 + # Discounted returns for trajectory 2: + # t=1: b2 + # t=0: b1 + gamma * b2 + expected_traj2 = [b1 + gamma * b2, b2] + adv1 = advantages[0].tolist() + adv2 = advantages[1].tolist() + for computed, expected in zip(adv1, expected_traj1): + assert math.isclose(computed, expected, rel_tol=1e-5) + for computed, expected in zip(adv2, expected_traj2): + assert math.isclose(computed, expected, rel_tol=1e-5) + + +def test_compute_grpo_process_supervision_advantages_std_tol(): + """Test that a constant reward trajectory raises ValueError due to low std.""" + rewards = [[1, 1, 1]] + with pytest.raises(ValueError): + compute_grpo_process_supervision_advantages(rewards) diff --git a/atroposlib/tests/test_utils/__init__.py b/atroposlib/tests/test_utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/atroposlib/tests/test_utils/test_heterogeneous_batching.py b/atroposlib/tests/test_utils/test_heterogeneous_batching.py new file mode 100644 index 00000000..f38f7831 --- /dev/null +++ b/atroposlib/tests/test_utils/test_heterogeneous_batching.py @@ -0,0 +1,28 @@ +import random + +from atroposlib.api.utils import grab_exact_from_heterogeneous_queue + + +def test_grab_exact_from_heterogeneous_queue(): + "randomly samples from the space of potential inputs to grab_exact_from_heterogeneous_queue" + for random_bs in range(10000): + bs = 64 * random.randint(1, 20) + queue = [] + for i in range(random.randint(1, 100)): + # queue.append( + # { + # "tokens": [[2 * i] for _ in range(2)], + # } + # ) + queue.append( + { + "tokens": [[2 * i + 1] for _ in range(8)], + } + ) + batch, queue = grab_exact_from_heterogeneous_queue(queue, bs) + if random_bs == 0: + print(batch) + if batch is not None: + assert ( + sum(len(item["tokens"]) for item in batch) == bs + ), f"expected batch size {bs}, got {len(batch)}" diff --git a/atroposlib/type_definitions.py b/atroposlib/type_definitions.py new file mode 100644 index 00000000..444af95d --- /dev/null +++ b/atroposlib/type_definitions.py @@ -0,0 +1,70 @@ +from typing import Any, Dict, List, Literal, Optional, TypedDict + +from openai.types.chat import ChatCompletionContentPartParam + +Content = str | list[ChatCompletionContentPartParam] +Item = Any +number = int | float +UUID = str + + +class Message(TypedDict): + role: Literal["system", "user", "assistant", "tool"] + content: Content + reward: Optional[float] + + +class AgentStep(TypedDict, total=False): + """Represents a single step in an agent's history. + + Attributes: + step: The step number. + messages: A list of messages exchanged during the step. + reward: The reward received at this step. + """ + + step: int + messages: List[Message] + reward: float + + +# AgentHistory maps agent ids (e.g. "Player 1", "Player 2") to their respective list of steps. +AgentHistory = Dict[str, List[AgentStep]] + + +class Observation(TypedDict): + """Represents an observation in a game history. + + Attributes: + raw: The raw observation data (as a dictionary). + rendered: The rendered string of the observation suitable for input into an LLM. + """ + + raw: Dict[str, Any] + rendered: Content + + +class GameStep(TypedDict): + """Represents a single step in a game history. Essentially an (s,a,r) triple with metadata. + + Attributes: + step: The step number. + agent: The agent who took the action (optional for final steps). + observation: The observation at this step. + action: The action taken by the agent (if any). + reward: The reward received; can be a float or a dictionary mapping agent names to rewards. + done: A flag indicating whether the game has ended after this step. + info: Additional information related to the step. + """ + + step: int + agent_id: str + observation: Observation + action: str + reward: float | Dict[str, float] + done: bool + info: Dict[str, Any] + + +# GameHistory is represented as a list of game steps. +GameHistory = List[GameStep] diff --git a/atroposlib/utils/__init__.py b/atroposlib/utils/__init__.py new file mode 100644 index 00000000..98fd052e --- /dev/null +++ b/atroposlib/utils/__init__.py @@ -0,0 +1,7 @@ +""" +Utility functions and classes for the atroposlib package. +""" + +from .config_handler import ConfigHandler + +__all__ = ["ConfigHandler"] diff --git a/atroposlib/utils/advantages.py b/atroposlib/utils/advantages.py new file mode 100644 index 00000000..dcb31b60 --- /dev/null +++ b/atroposlib/utils/advantages.py @@ -0,0 +1,173 @@ +from typing import Sequence + +import torch + +from atroposlib.type_definitions import number + +TensorLike = torch.Tensor | Sequence[torch.Tensor] | Sequence[Sequence] +# Type alias for vector of bools +BoolVector = torch.Tensor + + +def allclose_to_first( + values: TensorLike, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, + return_vector: bool = False, +) -> BoolVector | bool: + """ + Check if all tensors in `values` are close to the first tensor `values[0]` using a vectorized approach. + + If `return_vector` is False (default), returns a single boolean indicating whether + every tensor is close to the first tensor. If `return_vector` is True, returns a list + of booleans where each element corresponds to whether the respective tensor in + `values` is close to the first tensor. The first element is always True. + + Args: + values (torch.Tensor | Sequence[torch.Tensor] | Sequence[Sequence]): + Nested list of values to compare. Must be rectangular, but not necessarily 2D. + rtol (float, optional): Relative tolerance. Defaults to 1e-05. + atol (float, optional): Absolute tolerance. Defaults to 1e-08. + equal_nan (bool, optional): Whether to consider NaNs as equal. Defaults to False. + return_vector (bool, optional): If True, returns a list of booleans for each comparison. + Defaults to False. + + Returns: + bool or BoolVector: + - If `return_vector` is False, returns True if all tensors are close to the first tensor; + otherwise, returns False. + - If `return_vector` is True, returns a 1D tensor of bools where the first element is True + (as the reference tensor is trivially close to itself), and each subsequent element indicates + whether the corresponding tensor is close to the first tensor. + """ + if not isinstance(values, torch.Tensor): + values = torch.tensor(values) + + reference = values[0] + is_close = torch.isclose( + values, reference, rtol=rtol, atol=atol, equal_nan=equal_nan + ) + + # flatten dimensions after first + result_vector = torch.all(is_close.view(is_close.size(0), -1), dim=1) + + return result_vector if return_vector else bool(torch.all(result_vector)) + + +def compute_stats(data: Sequence[number | Sequence]) -> dict[str, float]: + """Compute mean and standard deviation from a possibly jagged nested list. + + This function recursively traverses a nested list of numbers (ints or floats) + and computes the overall mean and standard deviation of the flattened list. + + Args: + data (Sequence[number | Sequence]): A possibly jagged nested list of numbers. + + Returns: + dict: A dictionary with two keys: + - "mean": The mean of all numerical elements. + - "var": The variance of all numerical elements. + """ + + def accumulate(x): + """Recursively accumulate the sum, sum of squares, and count of numerical values. + + Args: + x (int | float | list): A number or a nested list of numbers. + + Returns: + tuple: A tuple of three elements: + - total (float): Sum of all numbers. + - total_sq (float): Sum of squares of all numbers. + - count (int): Count of numerical elements. + """ + if isinstance(x, (int, float)): + return x, x * x, 1 + elif isinstance(x, list): + total, total_sq, count = 0.0, 0.0, 0 + for item in x: + s, ss, c = accumulate(item) + total += s + total_sq += ss + count += c + return total, total_sq, count + else: + raise ValueError(f"Invalid element type encountered: {type(x)}") + + total, total_sq, count = accumulate(data) + if count == 0: + raise ValueError("No numerical elements found in the input data.") + + mean = total / count + variance = total_sq / count - mean * mean + return {"mean": mean, "var": variance} + + +def compute_discounted_returns(rewards: torch.Tensor, gamma: float) -> torch.Tensor: + """Compute discounted returns from a 1D vector of rewards. + + Given a list or torch tensor of rewards and a discount factor, this function computes + the discounted return at each timestep. The discounted return at time t is defined as: + G_t = rewards[t] + gamma * rewards[t+1] + gamma^2 * rewards[t+2] + ... + + Args: + rewards (list[float] or torch.Tensor): A 1D list or tensor of rewards. + gamma (float): The discount factor (should be between 0 and 1). + + Returns: + list[float]: A list containing the discounted returns for each timestep. + """ + if not isinstance(rewards, torch.Tensor): + rewards = torch.tensor(rewards, dtype=torch.float) + discounted_returns = torch.empty_like(rewards) + running_return = 0.0 + + for t in reversed(range(len(rewards))): + running_return = rewards[t] + gamma * running_return + discounted_returns[t] = running_return + + return discounted_returns + + +def compute_grpo_process_supervision_advantages( + rewards: Sequence[Sequence[number]], gamma: float = None, std_tol: float = 1e-8 +) -> list[torch.Tensor]: + """ + Given a (possibly jagged) list of list of rewards, compute advantages for GRPO. + + Args: + rewards (Sequence[Sequence[number]]): A list of list of rewards. Each inner list + contains a reward for each "step" in a trajectory, where a "step" is an + abstract unit of time left up to the modeler to define. + gamma (float): The discount factor. + std_tol (float): The tolerance for the standard deviation. + + Returns: + A list of tensors of advantages. + + Raises: + ValueError: If the standard deviation of the flattened rewards is smaller than the tolerance. + """ + stats = compute_stats(rewards) + mean = stats["mean"] + std = stats["var"] ** 0.5 + if std < std_tol: + raise ValueError(f"`std` is smaller than tolerance of {std_tol}.") + + normalized_rewards = [ + (torch.tensor(trajectory) - mean) / std for trajectory in rewards + ] + + if gamma is None: + advantages = [ + trajectory.flip(dims=[0]).cumsum(dim=0).flip(dims=[0]) + for trajectory in normalized_rewards + ] + else: + advantages = [ + compute_discounted_returns(trajectory, gamma) + for trajectory in normalized_rewards + ] + + return advantages diff --git a/atroposlib/utils/config_handler.py b/atroposlib/utils/config_handler.py new file mode 100644 index 00000000..3cfd1f7b --- /dev/null +++ b/atroposlib/utils/config_handler.py @@ -0,0 +1,184 @@ +import argparse +import os +from typing import Any, Dict, Optional + +import torch +import yaml + + +class ConfigHandler: + """Handles loading and merging of configuration files with CLI overrides""" + + def __init__(self, config_dir: Optional[str] = None): + self.config_dir = config_dir or os.path.join( + os.path.dirname(__file__), "../../configs" + ) + self.parser = self._setup_argument_parser() + + def _setup_argument_parser(self) -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Training configuration") + + # Config files + parser.add_argument( + "--env", + type=str, + default="crosswords", + help="Environment config file name (without .yaml)", + ) + parser.add_argument( + "--agent", + type=str, + default="nous_hermes", + help="Agent config file name (without .yaml)", + ) + parser.add_argument( + "--config", + type=str, + help="Configuration file name (without .yaml)", + ) + + # CLI overrides + parser.add_argument("--group-size", type=int, help="Override group size") + parser.add_argument("--total-steps", type=int, help="Override total steps") + parser.add_argument("--batch-size", type=int, help="Override batch size") + parser.add_argument("--seed", type=int, help="Override random seed") + parser.add_argument("--device", type=str, help="Override device (cuda/cpu/mps)") + parser.add_argument("--server-url", type=str, help="Override server URL") + + # Dataset-specific overrides + parser.add_argument("--dataset-name", type=str, help="Override dataset name") + parser.add_argument("--dataset-split", type=str, help="Override dataset split") + parser.add_argument( + "--prompt-field", type=str, help="Override prompt field name" + ) + parser.add_argument( + "--answer-field", type=str, help="Override answer field name" + ) + parser.add_argument("--system-prompt", type=str, help="Override system prompt") + parser.add_argument( + "--max-generations", type=int, help="Override max generations per prompt" + ) + parser.add_argument( + "--reward-funcs", + type=str, + nargs="+", + help="Override reward functions to use", + ) + + return parser + + def _load_yaml(self, path: str) -> Dict[str, Any]: + """Load a YAML configuration file""" + with open(path, "r") as f: + return yaml.safe_load(f) + + def _determine_device(self, config: Dict[str, Any]) -> str: + if config.get("device") == "auto": + if torch.backends.mps.is_available(): + return "mps" + elif torch.cuda.is_available(): + return "cuda" + return "cpu" + return config.get("device", "cpu") + + def load_config(self, args: Optional[argparse.Namespace] = None) -> Dict[str, Any]: + """Load and merge configurations with CLI overrides""" + if args is None: + args = self.parser.parse_args() + + # environment config + config = self._load_yaml(os.path.join(self.config_dir, f"envs/{args.env}.yaml")) + + # agent/model config + agent_config = self._load_yaml( + os.path.join(self.config_dir, f"agents/{args.agent}.yaml") + ) + config["agent"] = agent_config + + # CLI overrides + if args.group_size: + config["group_size"] = args.group_size + if args.total_steps: + config["total_steps"] = args.total_steps + if args.batch_size: + config["batch_size"] = args.batch_size + if args.seed: + config["initial_seed"] = args.seed + if args.device: + config["agent"]["device"] = args.device + if args.server_url: + config["rollout_server_url"] = args.server_url + + # Ensure player_names is populated based on group_size + if "env_kwargs" in config and "player_names" in config["env_kwargs"]: + config["env_kwargs"]["player_names"] = { + i: f"Player_{i}" for i in range(config["group_size"]) + } + + config["agent"]["device"] = self._determine_device(config["agent"]) + + return config + + def load_dataset_config( + self, args: Optional[argparse.Namespace] = None + ) -> Dict[str, Any]: + """Load and merge dataset environment configurations with CLI overrides""" + if args is None: + args = self.parser.parse_args() + + # Start with base environment config + config = self._load_yaml(os.path.join(self.config_dir, f"envs/{args.env}.yaml")) + + # Load agent config + agent_config = self._load_yaml( + os.path.join(self.config_dir, f"agents/{args.agent}.yaml") + ) + config["agent"] = agent_config + + # Load dataset config if specified + if args.config: + dataset_config = self._load_yaml( + os.path.join(self.config_dir, f"datasets/{args.config}.yaml") + ) + # Merge dataset config with main config instead of nesting + for key, value in dataset_config.items(): + config[key] = value + + # Apply CLI overrides for common parameters + if args.group_size: + config["group_size"] = args.group_size + if args.total_steps: + config["total_steps"] = args.total_steps + if args.batch_size: + config["batch_size"] = args.batch_size + if args.seed: + config["initial_seed"] = args.seed + if args.device: + config["agent"]["device"] = args.device + if args.server_url: + config["rollout_server_url"] = args.server_url + + # Apply dataset-specific overrides + if "dataset" in config: + if args.dataset_name: + config["dataset"]["dataset_name"] = args.dataset_name + if args.dataset_split: + config["dataset"]["split"] = args.dataset_split + if args.prompt_field: + config["dataset"]["prompt_field"] = args.prompt_field + if args.answer_field: + config["dataset"]["answer_field"] = args.answer_field + if args.system_prompt: + config["dataset"]["system_prompt"] = args.system_prompt + if args.max_generations: + config["dataset"]["max_generations_per_prompt"] = args.max_generations + if args.reward_funcs: + config["dataset"]["reward_funcs"] = args.reward_funcs + + # Set device + config["agent"]["device"] = self._determine_device(config["agent"]) + + # Add slurm flag to config if running in a Slurm environment + config["use_slurm"] = "SLURM_JOB_ID" in os.environ + + return config \ No newline at end of file diff --git a/atroposlib/utils/force_diverse_samples.py b/atroposlib/utils/force_diverse_samples.py new file mode 100644 index 00000000..951d8c54 --- /dev/null +++ b/atroposlib/utils/force_diverse_samples.py @@ -0,0 +1,112 @@ +import math +import random + + +# TODO: move this to the server manager +async def generate_with_diverse_first_tokens( + self, messages, prefill="", n=8, max_tokens=4096, temperature=1.0 +): + """ + Generate diverse completions by sampling different first tokens. + + Parameters: + - messages: List of message dictionaries for chat completion + - prefill: Prefix text to add to assistant's message + - n: Number of diverse completions to generate + - max_tokens: Maximum tokens per completion + - temperature: Sampling temperature + + Returns: + - List of completion strings + """ + # Step 1: First get the logprobs for just the first token + first_token_messages = messages + [{"role": "assistant", "content": prefill}] + + first_token_completion = await self.server.chat_completion( + messages=first_token_messages, + n=1, + max_tokens=1, + temperature=0.0, # Use 0 temperature to get raw logprobs + logprobs=True, + top_logprobs=20, # Get top 20 logprobs for the first token + ) + + # Extract logprobs from the completion + try: + # Get the logprobs for the first token + logprobs_dict = ( + first_token_completion.choices[0].logprobs.content[0].top_logprobs + ) + + # Convert to list of (token, logprob) tuples + logprobs_list = [(item.token, item.logprob) for item in logprobs_dict] + + # Convert logprobs to probabilities with temperature + logprobs_array = [lp for _, lp in logprobs_list] + probs = [math.exp(lp / temperature) for lp in logprobs_array] + total = sum(probs) + probs = [p / total for p in probs] + + # Sample n unique tokens + sampled_indices = random.choices( + range(len(logprobs_list)), weights=probs, k=min(n, len(logprobs_list)) + ) + + # Ensure unique indices + sampled_indices = list(set(sampled_indices)) + + # If we have fewer than n tokens, sample again to fill + while len(sampled_indices) < n and len(sampled_indices) < len(logprobs_list): + remaining = min( + n - len(sampled_indices), len(logprobs_list) - len(sampled_indices) + ) + available_indices = [ + i for i in range(len(logprobs_list)) if i not in sampled_indices + ] + available_probs = [probs[i] for i in available_indices] + total = sum(available_probs) + if total > 0: + available_probs = [p / total for p in available_probs] + additional_indices = random.choices( + available_indices, weights=available_probs, k=remaining + ) + sampled_indices.extend(additional_indices) + else: + # If all remaining probs are 0, just pick randomly + additional_indices = random.sample(available_indices, k=remaining) + sampled_indices.extend(additional_indices) + + # Get the selected first tokens + first_tokens = [logprobs_list[i][0] for i in sampled_indices] + + except (AttributeError, IndexError, KeyError) as e: + # Fallback if we can't extract logprobs properly + print(f"Error extracting logprobs: {e}") + return await self.fallback_generate( + messages, prefill, n, max_tokens, temperature + ) + + # Step 2: Generate completions with each selected first token + completions = [] + for token in first_tokens: + # Create a prompt with the first token already included + prompt_with_token = messages + [ + {"role": "assistant", "content": prefill + token} + ] + + # Generate the rest of the completion + completion = await self.server.chat_completion( + messages=prompt_with_token, + n=1, + max_tokens=max_tokens - 1, # Subtract 1 for the token we already used + temperature=temperature, + top_p=0.3, + extra_body={ + "min_p": 0.5, + "repetition_penalty": 1.05, + }, + ) + + # Extract the completion content and remove the prefill+token + full_content = completion.choices[0].message.content + completions.append(token + full_content) diff --git a/atroposlib/utils/metrics.py b/atroposlib/utils/metrics.py new file mode 100644 index 00000000..0a1c75d4 --- /dev/null +++ b/atroposlib/utils/metrics.py @@ -0,0 +1,19 @@ +import numpy as np + + +def get_std_min_max_avg(name: str, data: list, metrics_dict: dict) -> dict: + """ + Calculate the standard deviation, minimum, maximum, and average of a list of numbers. + Adds it to the wandb dict for logging. + + Args: + data (list): A list of numbers. + + Returns: + dict: A dictionary containing the standard deviation, minimum, maximum, and average. + """ + metrics_dict[f"{name}_mean"] = np.mean(data) + metrics_dict[f"{name}_std"] = np.std(data) + metrics_dict[f"{name}_max"] = np.max(data) + metrics_dict[f"{name}_min"] = np.min(data) + return metrics_dict diff --git a/atroposlib/utils/tokenize_for_trainer.py b/atroposlib/utils/tokenize_for_trainer.py new file mode 100644 index 00000000..8d9ea3dc --- /dev/null +++ b/atroposlib/utils/tokenize_for_trainer.py @@ -0,0 +1,192 @@ +import torch +from transformers import PreTrainedTokenizer + +from atroposlib.type_definitions import Message + +# Roles that should be masked in the loss calculation (not used for training) +UNMASKED_ROLES = ["assistant"] + + +def tokenize_for_trainer( + tokenizer: PreTrainedTokenizer, + chat: list[Message], + include_messages: bool = False, + train_on_all_assistant_turns: bool = False, + finish_reason: str = "", +) -> dict: + """ + Tokenize a list of chat messages for the trainer. + + Args: + tokenizer (PreTrainedTokenizer): The tokenizer to use. + chat (list): A list of chat messages. + include_messages (bool): Whether to include the messages in the output. + train_on_all_assistant_turns (bool): If True, mask out system/user/tool roles. + If False, use the original prefix masking. + Returns: + dict: A dictionary containing the tokenized chat messages. + """ + + tokens = tokenizer.apply_chat_template(chat) + + if not train_on_all_assistant_turns: + prefix_len = len( + tokenizer.apply_chat_template(chat[:-1], add_generation_prompt=True) + ) + masks = [-100] * prefix_len + tokens[prefix_len:] + else: + # NOTE: This implementation will break if the default system prompt is used and depends on world state + # (e.g. current date). e.g. consider a system prompt that depends on the current date and a run that crosses + # midnight from 3/9 to 3/10 under a tokenizer that tokenizes 3/9 and 3/10 with a different number of tokens. + + masks = torch.ones(len(tokens), dtype=torch.long) * -100 + + for i, msg in enumerate(chat): + if msg["role"] in UNMASKED_ROLES: + prefix_tokens = tokenizer.apply_chat_template( + chat[:i], tokenize=True, add_generation_prompt=True + ) + unmasked_tokens = tokenizer.apply_chat_template( + chat[: i + 1], tokenize=True + ) + start_idx = len(prefix_tokens) + end_idx = len(unmasked_tokens) + masks[start_idx:end_idx] = torch.tensor(unmasked_tokens[start_idx:]) + + masks = masks.tolist() + if finish_reason == "length": + if tokens[-1] == tokenizer.eos_token_id: + print("bad token\n") + # truncate the last token + tokens = tokens[:-1] + masks = masks[:-1] + + return { + "tokens": tokens, + "masks": masks, + } | ({"messages": chat} if include_messages else {}) + + +if __name__ == "__main__": + + # Inspired by `preprocess --debug`` of https://github.com/axolotl-ai-cloud/axolotl + def decode_token_ids( + token_ids: list, mask, tokenizer, use_rich: bool = False + ) -> str: + """Convert a list of token IDs to a formatted string using tokenizer.decode, + with an option to highlight masked tokens in red using rich markup. + + Each token is represented as decoded(tokenid, mask). If decoding a token returns an empty string + and the token is a known special token, it is replaced with a descriptive placeholder. + When use_rich is True, any token whose corresponding mask is -100 is wrapped with red highlighting. + + Args: + token_ids (list[int]): A list of integer token IDs, + e.g. [50256, 329]. + mask (list[int]): A list of masks corresponding to token_ids. + A mask value of -100 indicates the token is masked. + tokenizer: The Hugging Face tokenizer. + use_rich (bool): If True, wrap tokens with a mask of -100 in red highlighting. + Defaults to False. + + Returns: + str: A space-separated string where each token is represented as decoded(tokenid, mask). + + Raises: + ValueError: If any element in token_ids is not an integer. + + Example: + >>> decode_token_ids([50256, 329], mask=[-100, 329], tokenizer=tokenizer, use_rich=True) + '[red]<|eos|>(50256, -100)[/red] ' + 'tokenX(329, 329)' # (actual output will vary based on the model's tokenizer) + """ + # Validate that all token_ids are integers. + if not all(isinstance(t, int) for t in token_ids): + raise ValueError("All token IDs must be integers.") + + tokens_str_list = [] + for tid, mid in zip(token_ids, mask): + # Use decode with flags to include special tokens. + decoded = tokenizer.decode( + [tid], skip_special_tokens=False, clean_up_tokenization_spaces=False + ).strip() + # If the decoded string is empty and it's a special token, replace with a placeholder. + if not decoded and tid in tokenizer.all_special_ids: + if tid == tokenizer.eos_token_id: + decoded = "<|eos|>" + else: + decoded = f"" + + # Highlight token in red if use_rich is True and the token is masked (mid == -100) + if use_rich: + if mid == -100: + token_str = f"[pink3][bold]{decoded}[/bold][/pink3][steel_blue]({tid}, {mid})[/steel_blue]" + else: + token_str = ( + f"[pale_green3][bold]{decoded}[/bold][/pale_green3]" + f"[steel_blue]({tid}, {mid})[/steel_blue]" + ) + else: + token_str = f"{decoded}({tid}, {mid})" + + tokens_str_list.append(token_str) + + return " ".join(tokens_str_list) + + messages = [ + { + "role": "system", + "content": "You are a helpful AI assistant that provides accurate information.", + }, + {"role": "user", "content": "What's the capital of France?"}, + {"role": "assistant", "content": "The capital of France is Paris."}, + {"role": "user", "content": "Can you tell me more about Paris?"}, + { + "role": "assistant", + "content": "{'tool_name': 'web_search', 'args': {'query': 'Paris'}}", + }, + { + "role": "tool", + "content": ( + "Paris is the capital and most populous city of France. " + "It has an estimated population of 2,165,423 residents in 2019 " + "in an area of more than 105 km²." + ), + }, + { + "role": "assistant", + "content": ( + "Paris is indeed the capital of France and its most populous city with over 2 million residents. " + "It's known for its iconic landmarks like the Eiffel Tower, Louvre Museum, and Notre-Dame Cathedral. " + "The city is a global center for art, fashion, gastronomy, and culture." + ), + }, + ] + + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct") + + last_turn_only = tokenize_for_trainer( + tokenizer, messages, train_on_all_assistant_turns=False + ) + last_turn_only["repr"] = decode_token_ids( + last_turn_only["tokens"], last_turn_only["masks"], tokenizer, use_rich=True + ) + all_assistant_turns = tokenize_for_trainer( + tokenizer, messages, train_on_all_assistant_turns=True + ) + all_assistant_turns["repr"] = decode_token_ids( + all_assistant_turns["tokens"], + all_assistant_turns["masks"], + tokenizer, + use_rich=True, + ) + + from rich import print + + print("[bold cyan]last turn only[/]") + print(last_turn_only["repr"]) + print() + print("[bold cyan]all assistant turns[/]") + print(all_assistant_turns["repr"]) diff --git a/environments/README.md b/environments/README.md new file mode 100644 index 00000000..3d8ddf23 --- /dev/null +++ b/environments/README.md @@ -0,0 +1,129 @@ +# Environments + +This directory contains various environments for training and evaluating language models on different tasks. Each environment implements a specific task with its own input format, reward function, and evaluation metrics. + +## Available Environments + +--- + +### MCQA Thinking Environment (`mcqa_thinking_env.py`) + +Multiple Choice Question Answering environment that requires models to think through problems systematically. + +**Input Format:** +- Questions from the MMLU (Massive Multitask Language Understanding) dataset +- Each item contains: + - `prompt`: The question text + - `answer`: Index of correct answer + - `ground_truth`: Letter (A, B, C, D) of correct answer + - `options`: List of possible answers + +**System Prompt:** +``` +You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside tags, and then provide your solution or response to the problem. +``` + +**Reward Function:** +- Score of 1.0 if the model's answer matches the ground truth letter +- Score of 0.0 if incorrect or invalid response (multiple think tags, malformed thinking sections) +- Length penalty applied if all responses are correct: + - No penalty for responses under 50% of max token length + - Linear penalty scaling from 1.0 down to 0.0 for responses between 50% and 100% of max length + - Returns None if all scores are identical (no learning signal) + +--- + +### GSM8K Environment (`gsm8k_server.py`) + +Mathematical reasoning environment using the GSM8K dataset. + +**Input Format:** +- Questions from GSM8K dataset +- Each item contains: + - `question`: The math problem + - `answer`: The numerical answer + +**System Prompt:** +``` +You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside tags, and then provide your solution or response to the problem. + +You are allocated a maximum of 2048 tokens, please strive to use less. + +You will then provide your answer like this: \boxed{your answer here} +It is important that you provide your answer in the correct format. +If you do not, you will not receive credit for your answer. +So please end your answer with \boxed{your answer here} +``` + +**Reward Function:** +- Score of 1.0 if the model's answer matches the ground truth (using LaTeX verification) +- Score of 0.0 if incorrect or if ground truth is not parseable +- Length penalty applied if all responses are correct: + - No penalty for responses under 50% of max token length + - Linear penalty scaling from 1.0 down to 0.0 for responses between 50% and 100% of max length + - Returns None if all scores are identical (no learning signal) + +--- + +### Tool Calling Environment (`tool_calling_server.py`) + +Environment for training models to make function calls in a structured format. + +**Input Format:** +- Conversations from ShareGPT-Hermes function call dataset +- Each item contains: + - `conversations`: List of messages with roles (system, human, gpt) + - Expected tool calls in JSON format + +**System Prompt:** +``` +You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside tags, and then provide your solution or response to the problem. +``` + +**Reward Function:** +- Score of 1.0 if all expected tool calls are present and match exactly (including nested JSON fields) +- Score of 0.0 if any tool calls are missing, incorrect, or malformed +- Length penalty applied if all responses are correct: + - No penalty for responses under 50% of max token length + - Linear penalty scaling from 1.0 down to 0.0 for responses between 50% and 100% of max length + - Returns None if all scores are identical (no learning signal) + +## Common Features + +All environments share these common features: + +1. **Training/Test Split:** + - 98% training, 2% test split + - Random shuffling with fixed seed (42) + +2. **Metrics Tracking:** + - Percent correct buffer + - Completion lengths + - Wandb integration for visualization + - Rollout tracking + +3. **Token Management:** + - Maximum token length limits + - Token length statistics tracking + - Length penalty for excessive responses + +4. **Evaluation:** + - Separate evaluation on test set + - Comprehensive metrics logging + - Support for multiple model completions per prompt + +## Usage + +Each environment can be initialized with: +- `config`: BaseEnvConfig object +- `server_configs`: List of OpenAI API configurations +- `slurm`: Boolean for distributed training +- `testing`: Boolean for testing mode + +The environments follow a common interface with methods for: +- `setup()`: Loading and preparing datasets +- `get_next_item()`: Retrieving next training item +- `collect_trajectories()`: Generating model responses +- `score()`: Computing rewards +- `evaluate()`: Running evaluation on test set +- `wandb_log()`: Logging metrics to Weights & Biases diff --git a/environments/dataset_environment/LOCAL_TESTING.md b/environments/dataset_environment/LOCAL_TESTING.md new file mode 100644 index 00000000..a5c8eb87 --- /dev/null +++ b/environments/dataset_environment/LOCAL_TESTING.md @@ -0,0 +1,155 @@ +# Dataset Environment Local Testing Guide + +This document explains how to run the Dataset Environment locally for testing purposes. + +## Prerequisites + +1. Make sure you have the repository cloned and dependencies installed +2. Ensure you have a compatible model available (local or API) + +## Option 1: Single Script End-to-End Execution + +The easiest way to test the Dataset Environment is to use the unified launcher script: + +```bash +python -m environments.dataset_environment.launch_local_dataset_run +``` + +This script: +1. Starts the Trajectory Handler API server via uvicorn +2. Launches the Dataset Environment in serve mode (connected to the API) +3. Runs the example GRPO trainer directly + +The script has environment defaults configured for: +- Using a small LLM (Qwen2.5-1.5B) running on localhost:9001 +- A test subset of a public HF dataset +- Basic length-based rewards + +## Option 2: Step-by-step Manual Testing + +### 1. Start the API Server + +```bash +uvicorn atroposlib.api.server:app --host 127.0.0.1 --port 8000 +``` + +### 2. Launch the Environment + +```bash +python -m environments.dataset_environment.dataset_env serve \ + --group_size 4 \ + --max_num_workers 2 \ + --rollout_server_url http://127.0.0.1:8000 \ + --tokenizer_name Qwen/Qwen2.5-1.5B-Instruct \ + --use_wandb --wandb_name dataset_env_local_test \ + --max_token_length 512 \ + --ensure_scores_are_not_same \ + --dataset_name HuggingFaceH4/testing_self_instruct_process_essays \ + --split train[:100] \ + --prompt_field prompt --answer_field answer \ + --reward_functions length \ + --max_tokens 128 --temperature 0.7 \ + --model_name Qwen/Qwen2.5-1.5B-Instruct \ + --base_url http://127.0.0.1:9001 \ + --slurm --testing +``` + +### 3. Launch the Trainer + +In a separate terminal: + +```bash +python -m example_trainer.grpo.train \ + --model_name Qwen/Qwen2.5-1.5B-Instruct \ + --training_steps 20 \ + --batch_size 2 \ + --gradient_accumulation_steps 2 \ + --seq_len 512 +``` + +## Option N: Use the Dataset Local Server + +For easier configuration via YAML files, you can use the local server script: + +```bash +# This command will look for environments/dataset_environment/configs/gsm8k.yaml +python environments/dataset_environment/dataset_local_server.py --config gsm8k + +# You can also provide a full path: +# python environments/dataset_environment/dataset_local_server.py --config /path/to/your/custom_config.yaml +``` + +This will load the specified config and run the environment accordingly. + +## Debugging + +To check if requests are properly sent to and received by the API server, you can inspect the logs from both the environment and the API server. Look for: + +- API logs showing incoming requests +- Environment logs showing completions being generated and scored + +For model-specific issues, check: +- Ensure your model server is running at the specified URL +- Check model server logs for any errors related to generation + +## Configuration Structure + +Configuration files placed in `environments/dataset_environment/configs/` typically contain: + +```yaml +# Example: environments/dataset_environment/configs/my_config.yaml + +# Base environment parameters (can be overridden by dataset specifics) +tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-8B-Preview" +group_size: 1 +use_wandb: false +# ... other base parameters + +# Dataset specific configuration +dataset: + # Dataset parameters + dataset_name: "databricks/databricks-dolly-15k" + prompt_field: "instruction" + # ... other dataset parameters + reward_functions: + - type: "accuracy" + weight: 1.0 + - type: "repetition_penalty" + weight: 0.2 + +# Optional Server configuration (if not using CLI flags in dataset_env) +server_configs: + - model_name: "gpt-4.1-nano" + api_key: ${OPENAI_API_KEY} + timeout: 600 +``` + +### Important Configuration Parameters + +#### Base Parameters + +- `tokenizer_name`: The tokenizer to use for encoding/decoding text +- `group_size`: Number of responses to collect per prompt +- `max_token_length`: Maximum token length for generation +- `steps_per_eval`: How often to run evaluations + +#### Dataset Specific Parameters (`dataset:` section) + +- `dataset_name`: HuggingFace dataset name (required) +- `dataset_config`: Dataset configuration name (optional) +- `prompt_field`: Field in dataset to use as prompt (required) +- `answer_field`: Field in dataset to use as answer (optional) +- `system_prompt`: System prompt to use (optional) +- `reward_functions`: List of reward functions to apply (optional) + +#### Server Configuration (`server_configs:` section, optional in local server) + +- `model_name`: LLM model to use +- `api_key`: API key for the model (can use environment variables with ${VAR_NAME} syntax) +- `timeout`: Request timeout in seconds + +## Troubleshooting + +If you encounter issues with reward functions, make sure they are properly registered in the registry. + +For dataset-related issues, verify that the dataset exists on HuggingFace and that the specified fields exist in the dataset. \ No newline at end of file diff --git a/environments/dataset_environment/README.md b/environments/dataset_environment/README.md new file mode 100644 index 00000000..c5de21b0 --- /dev/null +++ b/environments/dataset_environment/README.md @@ -0,0 +1,355 @@ +## Quick Start + +### Option A: Unified End-to-End Launcher + +```bash +python -m environments.dataset_environment.launch_local_dataset_run +``` +This single command spins up: +1. The Trajectory Handler API server (`uvicorn atroposlib.api.server:app`) +2. The DatasetEnv in serve mode (connected to the API) +3. The example GRPO trainer (via `example_trainer.grpo.train`) + +### Option B: Manual Steps + +1. **Start the API server** + + ```bash + uvicorn atroposlib.api.server:app --host 127.0.0.1 --port 8000 + ``` + +2. **Launch the Dataset Environment** + + - **Using CLI flags**: + (These flags override any config file settings) + ```bash + python -m environments.dataset_environment.dataset_env serve \ + --group_size 4 \ + --max_num_workers 2 \ + --rollout_server_url http://127.0.0.1:8000 \ + --tokenizer_name Qwen/Qwen2.5-1.5B-Instruct \ + --use_wandb --wandb_name dataset_env_local_test \ + --max_token_length 512 \ + --ensure_scores_are_not_same \ + --dataset_name HuggingFaceH4/testing_self_instruct_process_essays \ + --split train[:100] \ + --prompt_field prompt --answer_field answer \ + --reward_functions length \ + --max_tokens 128 --temperature 0.7 \ + --model_name Qwen/Qwen2.5-1.5B-Instruct \ + --base_url http://127.0.0.1:9001 \ + --slurm --testing + ``` + + - **Using YAML config files**: + + Place a dataset config under `environments/dataset_environment/configs/.yaml`: + ```yaml + # Example: environments/dataset_environment/configs/gsm8k.yaml + dataset: + dataset_name: "gsm8k" + dataset_config: "main" + split: "train" + prompt_field: "question" + answer_field: "answer" + system_prompt: "You are a mathematical problem solver..." + + generation: + temperature: 0.7 + top_p: 0.95 + + reward_functions: + - type: "accuracy" + weight: 1.0 + ``` + + Then run the local test server: + ```bash + # Will look for environments/dataset_environment/configs/gsm8k.yaml + python environments/dataset_environment/dataset_local_server.py --config gsm8k + ``` + +3. **Launch the Trainer** + + ```bash + python -m example_trainer.grpo + ``` + +## Configuration Files Directory + +Dataset environment specific configurations now live in `environments/dataset_environment/configs/`. +Shared configurations (like agents) might still reside in the project's root `configs/` directory. + +- `environments/dataset_environment/configs/` for dataset-specific configs (used by `dataset_local_server.py`). +- You can reference any `.yaml` within this directory via the `--config` flag in the local server script. + +## Reward Function Registry & Customization + +Reward functions are managed by a centralized registry (see `atroposlib/envs/reward_fns/reward_function.py`). Built-in types include: + +- `accuracy`: exact match to ground truth (tolerance, split_on_think_tag) +- `format`: checks for specific tags (preferred_tags) +- `reasoning_steps`: quality of step-by-step reasoning +- `repetition_penalty`: penalizes repetition +- `cosine_scaled`: semantic similarity scaled from embeddings +- `crossword_format`: crossword-specific penalty +- `r1`: combined accuracy + format + +To preview all available functions: +```python +from atroposlib.envs.reward_fns import registry +print(registry.list()) +``` + +### Creating Custom Reward Functions + +1. Create a new file under `atroposlib/envs/reward_fns/my_reward.py`. +2. Subclass `RewardFunction` and register it: + + ```python + from atroposlib.envs.reward_fns import registry, RewardFunction + + @registry.register + class MyCustomReward(RewardFunction): + def __init__(self, custom_param=1.0, weight=1.0, **kwargs): + super().__init__(weight=weight, **kwargs) + self.custom_param = custom_param + + def compute(self, completions, **kwargs): + return [1.0 if "good answer" in self.get_content(c) else 0.0 for c in completions] + ``` + +3. Reference it in your YAML config: + + ```yaml + reward_functions: + - type: "my_custom" + weight: 1.0 + params: + custom_param: 2.0 + ``` + +### Dataset Environments + +Dataset environments load data from HuggingFace datasets and evaluate LLM responses against ground truth. They're ideal for academic benchmarks and datasets with clear evaluation criteria. + +Example configuration: +```yaml +dataset: + dataset_name: "gsm8k" + dataset_config: "main" + split: "train" + prompt_field: "question" + answer_field: "answer" + system_prompt: "You are a mathematical problem solver..." + reward_functions: + - type: "accuracy" + weight: 1.0 +``` + +## Reward Functions + +The system features a flexible reward function architecture for evaluating model outputs. + +### Basic Usage + +In your environment config, specify reward functions: + +```yaml +reward_functions: + - type: "accuracy" + weight: 1.0 + - type: "format" + weight: 0.5 +``` + +### Combining Reward Functions + +Combine multiple reward functions with weights: + +```yaml +reward_functions: + - type: "combined" + params: + normalization: "sum" + rewards: + - type: "accuracy" + weight: 1.5 + - type: "format" + weight: 0.5 +``` + +### Available Reward Functions + +#### `accuracy` +Evaluates if completions match ground truth answers. + +```yaml +type: "accuracy" +weight: 1.0 +params: + tolerance: 1e-6 + split_on_think_tag: true + max_boxed_threshold: 6 +``` + +#### `format` +Checks if completions include specific XML-style tags. + +```yaml +type: "format" +weight: 1.0 +params: + preferred_tags: ["think", "reasoning"] + require_all_tags: false + case_sensitive: false +``` + +#### `reasoning_steps` +Evaluates step-by-step reasoning quality. + +```yaml +type: "reasoning_steps" +weight: 1.0 +params: + min_words: 10 + min_steps: 3 + base_score: 0.1 +``` + +#### `repetition_penalty` +Penalizes repetitive content. + +```yaml +type: "repetition_penalty" +weight: 0.5 +params: + threshold: 0.05 + min_words: 10 + min_sentences: 2 +``` + +#### `cosine_scaled` +Measures semantic similarity between completions and solutions. + +```yaml +type: "cosine_scaled" +weight: 0.8 +params: + model_name: "sentence-transformers/all-MiniLM-L6-v2" + scale_factor: 1.0 + min_reward: -1.0 + max_reward: 1.0 +``` + +#### `crossword_format` +Game-specific reward for crossword puzzles. + +```yaml +type: "crossword_format" +weight: 1.0 +params: + reward_value: 1.0 + penalize_invalid_chars: true +``` + +#### `r1` +Combined reward using both reasoning format and accuracy. + +```yaml +type: "r1" +weight: 1.0 +params: + format_weight: 0.5 + accuracy_weight: 1.0 +``` + +### Creating Custom Reward Functions + +To create a custom reward function: + +1. Create a new file in `atroposlib/envs/reward_fns/my_reward.py` + +2. Define your reward function class: + +```python +from typing import Any, List +from atroposlib.envs.reward_fns import registry, RewardFunction + +@registry.register +class MyCustomReward(RewardFunction): + def __init__(self, custom_param=1.0, weight=1.0, **kwargs): + super().__init__(weight=weight, **kwargs) + self.custom_param = custom_param + + def compute(self, completions: List[Any], **kwargs) -> List[float]: + rewards = [] + for completion in completions: + content = self.get_content(completion) + # Implement your reward logic + reward = 1.0 if "good answer" in content else 0.0 + rewards.append(reward) + return rewards +``` + +3. Use it in your config: + +```yaml +reward_functions: + - type: "my_custom" + weight: 1.0 + params: + custom_param: 2.0 +``` + +### Dataset Environment Debugger + +The dataset environment debugger allows you to run a dataset environment locally with a Hugging Face model, providing enhanced visibility into reward function performance and model responses. + +```bash +# Run with default settings +python -m atroposlib.cli.dataset_env_debugger --env gsm8k_debug --agent nous_hermes_8b + +# List available environments and agents +python -m atroposlib.cli.dataset_env_debugger --list-configs + +# Interactive mode with debugging information +python -m atroposlib.cli.dataset_env_debugger --env gsm8k_debug --agent nous_hermes_8b --interactive --debug + +# Run with custom generation parameters +python -m atroposlib.cli.dataset_env_debugger --env gsm8k_debug --agent nous_hermes_8b --temperature 0.5 --top-p 0.95 + +# Run with detailed logging +python -m atroposlib.cli.dataset_env_debugger --env gsm8k_debug --agent nous_hermes_8b --verbose +``` + +## Environment Overview + +This environment demonstrates how to use a standard dataset (e.g., from Hugging Face Datasets) as a source for generating prompts and evaluating LLM responses. It allows for testing and training models on established benchmarks or custom datasets where prompts and expected answers/ground truth are available. + +**Demonstrates:** +- Loading and processing data from Hugging Face Datasets. +- Configuring system prompts, prompt/answer fields. +- Applying various reward functions (accuracy, format, semantic similarity, etc.) to evaluate generations. +- Integrating with the `atroposlib` framework for data collection and scoring. + +**Training Goal:** +- To train LLMs to follow instructions and generate responses that align with the format and content specified by the dataset and reward functions. +- To improve performance on specific tasks defined by datasets (e.g., math problem solving, code generation, question answering). + +## Local Testing + +To test this environment locally, you can run the provided local server. This server simulates the interaction flow without needing the full distributed setup. + +First, ensure you have the necessary dependencies installed. + +Then, run the local server script from the root of the repository: + +```bash +python environments/dataset_environment/dataset_local_server.py --config-path path/to/your/dataset_config.yaml +``` + +Replace `path/to/your/dataset_config.yaml` with the actual path to your environment configuration file (e.g., `configs/envs/gsm8k.yaml`). The server will load the dataset specified in the config, process items, and simulate generating responses. + + +FOR RELEASE - FIX diff --git a/environments/dataset_environment/__init__.py b/environments/dataset_environment/__init__.py new file mode 100644 index 00000000..6ee01b1f --- /dev/null +++ b/environments/dataset_environment/__init__.py @@ -0,0 +1,11 @@ +""" +Dataset Environment for training models with Hugging Face datasets. + +This environment provides a flexible way to train models using a variety of datasets +from Hugging Face or other sources, evaluating generations against reference answers +with customizable reward functions. +""" + +from environments.dataset_environment.dataset_env import DatasetEnv, DatasetEnvConfig + +__all__ = ["DatasetEnv", "DatasetEnvConfig"] diff --git a/environments/dataset_environment/configs/dataset_local.yaml b/environments/dataset_environment/configs/dataset_local.yaml new file mode 100644 index 00000000..7849de34 --- /dev/null +++ b/environments/dataset_environment/configs/dataset_local.yaml @@ -0,0 +1,52 @@ +# Dataset Environment Local Testing Configuration + +# Base environment parameters +tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-8B-Preview" +group_size: 1 +use_wandb: false +max_num_workers: 1 +rollout_server_url: "http://localhost:8000" +total_steps: 1 +batch_size: 1 +steps_per_eval: 5 +max_token_length: 4096 +wandb_name: "dataset_test_local" +ensure_scores_are_not_same: false + +# Dataset specific configuration +dataset: + # Dataset parameters + dataset_name: "gsm8k" # Example dataset + dataset_config: "main" + split: "train" + prompt_field: "question" + answer_field: "answer" + + # Generation parameters + system_prompt: "You are an expert mathematician. You need to solve the given math problem step-by-step, showing your reasoning clearly. You should enclose your thoughts and internal monologue inside tags, and then provide your final answer.\n\nFollow these steps:\n1. Understand the problem carefully\n2. Plan your approach\n3. Execute the calculations step-by-step\n4. Verify your solution\n\nYou may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution." + prefill: "\n" + shuffle_dataset: true + max_generations_per_prompt: 1 + + # Generation length parameters + max_tokens: 4096 + length_warmup_steps: 0 + min_tokens: 0 + + # Completion parameters + temperature: 0.7 + top_p: 0.9 + + # Reward functions + reward_functions: + - "accuracy" + - "format" + accuracy_reward_weight: 1.0 + format_reward_weight: 0.2 + + +# Server configuration +server_configs: + - model_name: "gpt-4.1-nano" + api_key: ${OPENAI_API_KEY} + timeout: 600 \ No newline at end of file diff --git a/environments/dataset_environment/configs/gsm8k.yaml b/environments/dataset_environment/configs/gsm8k.yaml new file mode 100644 index 00000000..c4614196 --- /dev/null +++ b/environments/dataset_environment/configs/gsm8k.yaml @@ -0,0 +1,73 @@ +tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-1B-Preview" +group_size: 8 +use_wandb: true +max_num_workers: 256 +max_eval_workers: 16 +steps_per_eval: 100 +batch_size: 1024 +max_batches_offpolicy: 3 +total_steps: 1000 +rollout_server_url: "http://localhost:8000" + +use_local_agents: true + +dataset: + dataset_name: "gsm8k" + dataset_config: "main" + split: "train" + + prompt_field: "question" + answer_field: "answer" + + system_prompt: "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside tags, and then provide your solution or response to the problem." + shuffle_dataset: true + max_generations_per_prompt: 1 + include_messages_in_scoring: false + + # New configurable reward functions + reward_functions: + - type: "r1" + weight: 1.5 + params: + format_weight: 0.5 + accuracy_weight: 1.0 + - type: "cosine_scaled" + weight: 0.8 + params: + scale_factor: 1.2 + min_reward: -1.0 + max_reward: 1.0 + - type: "accuracy" + weight: 2.0 + params: + split_on_think_tag: true + - type: "format" + weight: 0.7 + params: + preferred_tags: ["think", "reasoning"] + require_all_tags: false + - type: "reasoning_steps" + weight: 1.0 + params: + min_steps: 3 + - type: "repetition_penalty" + weight: 0.5 + params: + threshold: 0.1 + + # Legacy format still supported for backward compatibility + # reward_funcs: + # - "r1_reward" + # - "cosine_scaled_reward" + # - "accuracy_reward" + # - "format_reward" + # - "reasoning_steps_reward" + # - "repetition_penalty_reward" + + max_tokens: 16000 + length_warmup_steps: 100 + min_tokens: 2048 + + eval_dataset_name: "gsm8k" + eval_dataset_config: "main" + eval_split: "test" \ No newline at end of file diff --git a/environments/dataset_environment/configs/gsm8k_debug.yaml b/environments/dataset_environment/configs/gsm8k_debug.yaml new file mode 100644 index 00000000..f928e9e2 --- /dev/null +++ b/environments/dataset_environment/configs/gsm8k_debug.yaml @@ -0,0 +1,30 @@ +tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-8B-Preview" +group_size: 1 +use_wandb: false +max_num_workers: 1 +max_eval_workers: 0 +batch_size: 1 +total_steps: 100 +rollout_server_url: "http://localhost:8000" + +dataset: + dataset_name: "gsm8k" + dataset_config: "main" + split: "train" + + prompt_field: "question" + answer_field: "answer" + + system_prompt: "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside tags, and then provide your solution or response to the problem." + shuffle_dataset: true + max_generations_per_prompt: 1 + include_messages_in_scoring: true + + # Using multiple reward functions for testing + reward_funcs: + - "accuracy_reward" + - "format_reward" + + max_tokens: 4096 + length_warmup_steps: 0 + min_tokens: 200 \ No newline at end of file diff --git a/environments/dataset_environment/dataset_env.py b/environments/dataset_environment/dataset_env.py new file mode 100644 index 00000000..602cf812 --- /dev/null +++ b/environments/dataset_environment/dataset_env.py @@ -0,0 +1,407 @@ +import asyncio +import logging +from typing import Any, Dict, List, Optional, Tuple, Union + +from datasets import load_dataset +from pydantic import Field + +from atroposlib.envs.base import BaseEnv, BaseEnvConfig, ScoredDataGroup +from atroposlib.envs.reward_fns import registry +from atroposlib.envs.reward_fns.combined_reward import CombinedReward +from atroposlib.type_definitions import Item +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class DatasetEnvConfig(BaseEnvConfig): + dataset_name: str = Field(..., description="HuggingFace dataset name") + dataset_config: Optional[str] = Field( + None, description="Dataset configuration name" + ) + split: str = Field("train", description="Dataset split to use") + dataset_path: Optional[str] = Field( + None, description="Local path to dataset (alternative to dataset_name)" + ) + prompt_field: str = Field(..., description="Field in dataset to use as prompt") + answer_field: Optional[str] = Field( + None, description="Field in dataset to use as answer" + ) + ground_truth_field: Optional[str] = Field( + None, description="Field in dataset containing canonical correct answer" + ) + system_prompt: Optional[str] = Field(None, description="System prompt to use") + prefill: Optional[str] = Field(None, description="Text to prefill the completion with (e.g. '')") + shuffle_dataset: bool = Field(True, description="Whether to shuffle the dataset") + max_generations_per_prompt: int = Field( + 1, description="Number of generations per prompt for collection" + ) + include_messages_in_scoring: bool = Field( + False, description="Whether to include messages in scoring" + ) + reward_funcs: List[str] = Field( + default_factory=list, + description="List of reward function names to apply (legacy)", + ) + reward_functions: List[Union[str, Dict[str, Any]]] = Field( + default_factory=list, + description="List of reward functions to apply (string names or full configs)", + ) + + # Completion parameters + temperature: float = Field(0.7, description="Temperature for generation") + top_p: float = Field(0.9, description="Top-p for generation") + max_tokens: int = Field(4096, description="Maximum tokens for generation") + length_warmup_steps: int = Field(0, description="Steps for length warmup") + min_tokens: int = Field(0, description="Minimum tokens for generation") + + eval_dataset_name: Optional[str] = Field( + None, description="Evaluation dataset name" + ) + eval_dataset_config: Optional[str] = Field( + None, description="Evaluation dataset config" + ) + eval_split: Optional[str] = Field(None, description="Evaluation dataset split") + + +class DatasetEnv(BaseEnv): + def __init__( + self, config: DatasetEnvConfig, server_configs, slurm=True, testing=False + ): + super().__init__(config, server_configs, slurm, testing) + self.config = config + self.dataset = None + self.iter = 0 + self.metric_buffer = {} + + self.reward_function = self._initialize_reward_function() + + def _initialize_reward_function(self): + if hasattr(self.config, "reward_functions") and self.config.reward_functions: + if len(self.config.reward_functions) == 1: + return registry.create(self.config.reward_functions[0]) + else: + return CombinedReward( + rewards=self.config.reward_functions, normalization="sum" + ) + elif hasattr(self.config, "reward_funcs") and self.config.reward_funcs: + if len(self.config.reward_funcs) == 1: + return registry.create(self.config.reward_funcs[0]) + else: + return CombinedReward( + rewards=self.config.reward_funcs, normalization="none" + ) + + async def setup(self): + if self.config.dataset_path: + self.dataset = load_dataset( + self.config.dataset_path, split=self.config.split + ) + else: + self.dataset = load_dataset( + self.config.dataset_name, + self.config.dataset_config, + split=self.config.split, + ) + logger.info(f"Dataset features: {self.dataset.features}") + logger.info(f"Sample item keys: {list(self.dataset[0].keys())}") + logger.info(f"Sample item: {self.dataset[0]}") + + if self.config.shuffle_dataset: + self.dataset = self.dataset.shuffle() + + self.metric_buffer = {} + + async def get_next_item(self) -> Item: + if not self.dataset: + await self.setup() + + item = self.dataset[self.iter % len(self.dataset)] + self.iter += 1 + + user_msg = {"role": "user", "content": item[self.config.prompt_field]} + prompt = tuple([frozenset(user_msg.items())]) + + answer = None + if self.config.answer_field and self.config.answer_field in item: + answer = item[self.config.answer_field] + + ground_truth = None + if self.config.ground_truth_field and self.config.ground_truth_field in item: + ground_truth = item[self.config.ground_truth_field] + + return (prompt, answer, ground_truth) + + async def collect_trajectory(self, item: Item) -> Tuple[List, List]: + # Extract user prompt and answer from item + user_content = dict(item[0][0])["content"] + answer = item[1] if len(item) > 1 else None + + # Create messages list + messages = [] + if self.config.system_prompt: + messages.append({"role": "system", "content": self.config.system_prompt}) + + messages.append({"role": "user", "content": user_content}) + + # Add prefill as assistant message if configured + if self.config.prefill: + messages.append({"role": "assistant", "content": self.config.prefill}) + + # Convert messages to a prompt string using the tokenizer + prompt = self.tokenizer.apply_chat_template(messages, tokenize=False) + + # Calculate max tokens for generation (with optional warmup) + max_tokens = self.config.max_tokens + if self.config.length_warmup_steps > 0: + warmup_progress = min(1.0, self.curr_step / self.config.length_warmup_steps) + max_tokens = int( + self.config.min_tokens + + warmup_progress * (self.config.max_tokens - self.config.min_tokens) + ) + + # Generate completion using completions API + completions = await self.server.completion( + prompt=prompt, + n=self.config.max_generations_per_prompt, + max_tokens=max_tokens, + temperature=self.config.temperature, + top_p=self.config.top_p, + ) + + to_score = [] + to_backlog = [] + + # Process completions + for completion in completions.choices: + # Get the completion text + completion_text = completion.text if hasattr(completion, "text") else completion.message.content + + # Build full message sequence for scoring + full_messages = [] + if self.config.system_prompt: + full_messages.append({"role": "system", "content": self.config.system_prompt}) + + full_messages.append({"role": "user", "content": user_content}) + + # Combine prefill with completion if prefill was used + response_content = completion_text + if self.config.prefill: + response_content = self.config.prefill + completion_text + + full_messages.append({"role": "assistant", "content": response_content}) + + # Add to scoring list with answer and ground truth + to_score.append( + (full_messages, answer, item[2] if len(item) > 2 else None) + ) + + return to_score, to_backlog + + async def postprocess_histories(self, trajectories: List) -> Tuple[List, List]: + return trajectories, [] + + async def collect_trajectories(self, item: Item) -> Tuple[List, List]: + self.current_item = item + + # Extract user prompt from item + user_content = dict(item[0][0])["content"] + + # Create messages list + messages = [] + if self.config.system_prompt: + messages.append({"role": "system", "content": self.config.system_prompt}) + + messages.append({"role": "user", "content": user_content}) + + # Add prefill as assistant message if configured + if self.config.prefill: + messages.append({"role": "assistant", "content": self.config.prefill}) + + # Convert messages to a prompt string using the tokenizer + prompt = self.tokenizer.apply_chat_template(messages, tokenize=False) + + # Calculate max tokens for generation (with optional warmup) + max_tokens = self.config.max_tokens + + # Generate completions + completions = await self.server.completion( + prompt=prompt, + n=self.config.group_size, + max_tokens=max_tokens, + temperature=self.config.temperature, + top_p=self.config.top_p, + ) + + print(f"Completions: {completions}") + # Process completions + trajectories = [] + for completion in completions.choices: + # Get the completion text + completion_text = completion.text if hasattr(completion, "text") else completion.message.content + + # Build complete message sequence + full_messages = [] + if self.config.system_prompt: + full_messages.append({"role": "system", "content": self.config.system_prompt}) + + full_messages.append({"role": "user", "content": user_content}) + + # Combine prefill with completion if prefill was used + response_content = completion_text + if self.config.prefill: + response_content = self.config.prefill + completion_text + + full_messages.append({"role": "assistant", "content": response_content}) + + trajectories.append(full_messages) + + return trajectories, [] + + async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]: + logger.warning(f"Scoring {len(rollout_group_data)} rollout items") + + scores = ScoredDataGroup() + scores["tokens"] = [] + scores["masks"] = [] + scores["scores"] = [] + scores["advantages"] = None + scores["ref_logprobs"] = None + scores["messages"] = None if not self.config.include_messages_in_scoring else [] + + answer = ( + self.current_item[1] + if self.current_item and len(self.current_item) > 1 + else None + ) + logger.warning(f"Answer for current item: {answer}") + + ground_truth = ( + self.current_item[2] + if self.current_item and len(self.current_item) > 2 + else None + ) + logger.warning(f"Ground truth for current item: {ground_truth}") + + formatted_completions = [] + for trajectory in rollout_group_data: + if trajectory and isinstance(trajectory, list): + assistant_messages = [ + msg + for msg in trajectory + if isinstance(msg, dict) and msg.get("role") == "assistant" + ] + if assistant_messages: + formatted_completions.append([assistant_messages[-1]]) + + if not formatted_completions: + logger.warning("No valid completions to score") + return None + + try: + reward_kwargs = { + "solution": answer, + "ground_truth": ground_truth, + "item": self.current_item, + "config": self.config, + } + + all_rewards = self.reward_function(formatted_completions, **reward_kwargs) + + logger.info(f"Calculated rewards: {all_rewards}") + + except Exception as e: + logger.error(f"Error applying reward functions: {e}") + logger.exception(e) + all_rewards = [0.0] * len(formatted_completions) + + for i, (trajectory, reward) in enumerate(zip(rollout_group_data, all_rewards)): + try: + tokenized = tokenize_for_trainer(self.tokenizer, trajectory) + + scores["tokens"].append(tokenized["tokens"]) + scores["masks"].append(tokenized["masks"]) + scores["scores"].append(reward) + + if self.config.include_messages_in_scoring: + if "messages" not in scores: + scores["messages"] = [] + scores["messages"].append(trajectory) + logger.warning(f"Scores: {scores['scores']}") + except Exception as e: + logger.error(f"Error processing trajectory {i}: {e}") + logger.exception(e) + + if not scores["tokens"]: + logger.warning("No valid scores generated") + return None + + logger.info(f"Generated scores: {scores['scores']}") + return scores + + async def evaluate(self): + if ( + not hasattr(self.config, "eval_dataset_name") + or not self.config.eval_dataset_name + ): + return + + if not hasattr(self, "eval_dataset"): + self.eval_dataset = load_dataset( + self.config.eval_dataset_name, + self.config.eval_dataset_config, + split=self.config.eval_split, + ) + self.eval_dataset = self.eval_dataset.select( + range(min(100, len(self.eval_dataset))) + ) + + eval_metrics = {} + eval_tasks = [] + + for i in range(min(self.config.max_eval_workers, len(self.eval_dataset))): + item = self.eval_dataset[i] + user_msg = {"role": "user", "content": item[self.config.prompt_field]} + prompt = tuple([frozenset(user_msg.items())]) + + answer = None + if self.config.answer_field and self.config.answer_field in item: + answer = item[self.config.answer_field] + + eval_tasks.append(self.collect_trajectory((prompt, answer))) + + eval_results = await asyncio.gather(*eval_tasks) + + eval_scores = [] + for result in eval_results: + if result[0]: + scored_data = await self.score(result[0]) + if scored_data and "scores" in scored_data: + eval_scores.extend(scored_data["scores"]) + + if eval_scores: + eval_metrics["eval/mean_score"] = sum(eval_scores) / len(eval_scores) + eval_metrics["eval/max_score"] = max(eval_scores) + eval_metrics["eval/min_score"] = min(eval_scores) + + await self.wandb_log(eval_metrics) + + async def wandb_log(self, wandb_metrics: Optional[Dict] = None): + metrics = wandb_metrics or {} + + for key, values in self.metric_buffer.items(): + if values: + metrics[f"train/{key}"] = sum(values) / len(values) + + self.metric_buffer = {k: [] for k in self.metric_buffer} + + if hasattr(self, "reward_function") and self.wandb: + if hasattr(self.reward_function, "set_wandb_logger"): + self.reward_function.set_wandb_logger(self.wandb) + + await super().wandb_log(metrics) + +if __name__ == "__main__": + # Launch the DatasetEnv via the BaseEnv CLI (serve or process) + DatasetEnv.cli() diff --git a/environments/dataset_environment/dataset_local_server.py b/environments/dataset_environment/dataset_local_server.py new file mode 100644 index 00000000..7fdf047a --- /dev/null +++ b/environments/dataset_environment/dataset_local_server.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python3 +import argparse +import asyncio +import logging +import os + +from dotenv import load_dotenv + +from atroposlib.envs.base import OpenaiConfig +from atroposlib.envs.reward_fns import registry +from atroposlib.utils.config_handler import ConfigHandler +from environments.dataset_environment.dataset_env import DatasetEnv, DatasetEnvConfig + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="Dataset environment local server") + parser.add_argument( + "--config", + type=str, + default="dataset_local", + help="Configuration file name (without .yaml extension) relative to environments/dataset_environment/configs/, or full path to a YAML file.", + ) + return parser.parse_args() + + +async def main(): + logger.info("Starting Dataset environment local server") + + # Parse command line arguments + args = parse_arguments() + + # Initialize config handler + config_handler = ConfigHandler() + + # Determine config path + if ( + os.path.isabs(args.config) + or "/" in args.config + or args.config.endswith(".yaml") + ): + config_path = args.config + else: + # Assume it's a name relative to the new default directory + config_path = os.path.join( + os.path.dirname(__file__), "configs", f"{args.config}.yaml" + ) + + logger.info(f"Loading configuration from: {config_path}") + + try: + with open(config_path, "r") as f: + import yaml + + raw_config = yaml.safe_load(f) + logger.info("Loaded configuration successfully") + except FileNotFoundError: + logger.error(f"Configuration file not found at: {config_path}") + logger.info("Ensure the --config argument is correct or the file exists.") + return + except Exception as e: + logger.error(f"Error loading config from {config_path}: {e}") + return + + # Ensure dataset configuration exists (assuming it's top-level in these files) + if "dataset" not in raw_config: + logger.warning( + "'dataset' key not found at the top level of the config file. " + "Assuming the entire file is the dataset configuration." + ) + # Treat the whole raw_config as the 'dataset' section for compatibility + dataset_section = raw_config + else: + dataset_section = raw_config["dataset"] + + if "dataset_name" not in dataset_section: + logger.error("dataset_name not found in dataset configuration") + return + if "prompt_field" not in dataset_section: + logger.error("prompt_field not found in dataset configuration") + return + + # Configure the dataset environment + # Merging logic: Start with raw_config defaults, then dataset_section specifics + env_config_data = {**raw_config, **dataset_section} + + # Pydantic will ignore extra fields, so we just pass everything + try: + env_config = DatasetEnvConfig(**env_config_data) + except Exception as pydantic_error: + logger.error(f"Error validating configuration: {pydantic_error}") + return + + # Preload reward functions + reward_names_to_load = set() + if env_config.reward_funcs: + reward_names_to_load.update(env_config.reward_funcs) + if env_config.reward_functions: + for rf_config in env_config.reward_functions: + if isinstance(rf_config, str): + reward_names_to_load.add(rf_config) + elif isinstance(rf_config, dict) and "type" in rf_config: + reward_names_to_load.add(rf_config["type"]) + + if reward_names_to_load: + logger.info(f"Preloading reward functions: {list(reward_names_to_load)}") + for func_name in reward_names_to_load: + try: + registry.get(func_name) + logger.info(f"Successfully loaded reward function: {func_name}") + except Exception as e: + logger.error(f"Failed to load reward function {func_name}: {e}") + + # Server configuration - process env vars + server_configs = [] + + if "server_configs" in raw_config: + for server_config in raw_config["server_configs"]: + api_key = server_config.get("api_key", os.environ.get("OPENAI_API_KEY")) + # Handle environment variable references like ${OPENAI_API_KEY} + if ( + isinstance(api_key, str) + and api_key.startswith("${") + and api_key.endswith("}") + ): + env_var = api_key[2:-1] + api_key = os.environ.get(env_var, "") + + server_configs.append( + OpenaiConfig( + model_name=server_config.get("model_name", "gpt-4.1-nano"), + base_url=server_config.get("base_url", None), + api_key=api_key, + timeout=server_config.get("timeout", 600), + ) + ) + else: + # Default configuration if not specified in config file + logger.warning( + "No 'server_configs' found in config. Using default OpenAI config." + ) + server_configs.append( + OpenaiConfig( + model_name="gpt-4.1-nano", + base_url=None, + api_key=os.environ.get("OPENAI_API_KEY"), + timeout=600, + ) + ) + + # Create the environment + logger.info("Creating dataset environment...") + env = DatasetEnv( + config=env_config, + server_configs=server_configs, + slurm=False, + ) + + # Setup the environment directly + try: + await env.setup() + logger.info("Environment setup complete") + except Exception as setup_error: + logger.error(f"Error during environment setup: {setup_error}") + return + + # --- Start Test Run --- # + logger.info("\n=== Starting Local Test Run ===") + test_items_count = 5 + successful_runs = 0 + + for i in range(test_items_count): + logger.info(f"\n--- Running Test Item {i+1}/{test_items_count} ---") + try: + # Get a sample item from the dataset + item = await env.get_next_item() + if not item or not item[0]: + logger.warning("Failed to get a valid item from the environment.") + continue + + prompt, answer, ground_truth = item + user_content = dict(prompt[0])["content"] + logger.info( + f"Prompt: {user_content[:200]}..." if user_content else "(Empty Prompt)" + ) + if answer: + logger.info( + f"Answer: {answer[:200]}..." if answer else "(Empty Answer)" + ) + if ground_truth: + logger.info( + f"Ground Truth: {ground_truth[:200]}..." + if ground_truth + else "(Empty Ground Truth)" + ) + + # Collect trajectories (using group_size from config) + logger.info( + f"Collecting {env_config.group_size} trajectories for this item..." + ) + trajectories_data, backlog = await env.collect_trajectories(item) + + if not trajectories_data: + logger.warning("No trajectories were collected.") + continue + + logger.info(f"Collected {len(trajectories_data)} trajectories.") + # Log first trajectory message content for inspection + if trajectories_data[0] and isinstance(trajectories_data[0], list): + first_response = "(Empty or invalid trajectory format)" + assistant_msgs = [ + m + for m in trajectories_data[0] + if isinstance(m, dict) and m.get("role") == "assistant" + ] + if assistant_msgs: + first_response = assistant_msgs[-1].get("content", "(No content)") + logger.info(f"First Response Content: {first_response[:300]}...") + + # Score the collected trajectories + logger.info("Scoring trajectories...") + scored_data = await env.score(trajectories_data) + + # Print scores + if scored_data and "scores" in scored_data: + scores_list = scored_data["scores"] + logger.info(f"Scores: {scores_list}") + logger.info(f" Avg Score: {sum(scores_list)/len(scores_list):.4f}") + successful_runs += 1 + else: + logger.warning("No scores available in the scored data for this item.") + + except Exception as run_error: + logger.error(f"Error during test item {i+1}: {run_error}") + # Optionally continue to the next item or break + # break + + logger.info( + f"\n=== Local Test Run Complete ({successful_runs}/{test_items_count} items processed successfully) ===" + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/environments/dataset_environment/launch_local_dataset_run.py b/environments/dataset_environment/launch_local_dataset_run.py new file mode 100644 index 00000000..26f6a2ac --- /dev/null +++ b/environments/dataset_environment/launch_local_dataset_run.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +""" +Local dataset training launcher. + +Usage: + python -m environments.dataset_environment.launch_local_dataset_run + +This script does: + 1) Starts the Trajectory Handler API server via uvicorn + 2) Launches the DatasetEnv in local serve mode + 3) Imports and runs the example trainer (GRPO) directly + +Requirements: + - Run from project root so example_trainer is on PYTHONPATH + - example_trainer/ is a valid Python package (with __init__.py) +""" +import os +import sys +import subprocess +import time +import atexit +import signal +import traceback + +# Ensure project root is on PYTHONPATH +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +# Import trainer via standard module import +try: + from example_trainer.grpo import TrainingConfig, train +except ImportError as e: + print(f"Error importing example_trainer.grpo: {e}") + print("Ensure you're running from project root and that example_trainer/ is a package.") + sys.exit(1) + +# ----------------------------------------------------------------------------- +# Configuration +# ----------------------------------------------------------------------------- +API_HOST = '127.0.0.1' +API_PORT = 8000 + +VLLM_HOST = '127.0.0.1' +VLLM_PORT = 9001 + +MODEL_NAME = 'Qwen/Qwen2.5-1.5B-Instruct' +TOKENIZER_NAME = MODEL_NAME + +TRAINER_CONFIG = { + 'model_name': MODEL_NAME, + 'training_steps': 20, + 'batch_size': 2, + 'gradient_accumulation_steps': 2, + 'seq_len': 512, + 'vllm_port': VLLM_PORT, + 'vllm_restart_interval': 10, + 'use_wandb': False, + 'wandb_project': '', + 'wandb_group': '', + 'save_path': './trained_model_checkpoints_local_test', +} + +# Flags for launching DatasetEnv serve +DATASET_FLAGS = [ + '--group_size', '4', + '--max_num_workers', '2', + '--rollout_server_url', f"http://{API_HOST}:{API_PORT}", + '--tokenizer_name', TOKENIZER_NAME, + '--use_wandb', + '--wandb_name', 'dataset_env_local_test', + '--max_token_length', str(TRAINER_CONFIG['seq_len']), + '--ensure_scores_are_not_same', + '--dataset_name', 'HuggingFaceH4/testing_self_instruct_process_essays', + '--split', 'train[:100]', + '--prompt_field', 'prompt', + '--answer_field', 'answer', + '--reward_functions', 'length', + '--max_tokens', '128', + '--temperature', '0.7', + '--model_name', MODEL_NAME, + '--base_url', f"http://{VLLM_HOST}:{VLLM_PORT}", + '--slurm', + '--testing', +] + +# Track background processes for cleanup +processes = [] + + +def cleanup_processes(): + print("\nCleaning up background processes...") + for p in reversed(processes): + if p.poll() is None: + print(f"Terminating PID {p.pid}...") + p.terminate() + try: + p.wait(timeout=5) + print(f"PID {p.pid} terminated.") + except subprocess.TimeoutExpired: + print(f"PID {p.pid} did not terminate; killing.") + p.kill() + p.wait() + print(f"PID {p.pid} killed.") + else: + print(f"PID {p.pid} already exited.") + print("Cleanup complete.") + +atexit.register(cleanup_processes) + + +def handle_signal(sig, frame): + print(f"\nSignal {sig} received; exiting.") + sys.exit(0) + +signal.signal(signal.SIGINT, handle_signal) +signal.signal(signal.SIGTERM, handle_signal) + + +def main(): + # 1) Start the API server + print("--- Starting Trajectory Handler API Server ---") + api_cmd = [ + 'uvicorn', + 'atroposlib.api.server:app', + '--host', API_HOST, + '--port', str(API_PORT), + ] + print(f"$ {' '.join(api_cmd)}") + api_proc = subprocess.Popen(api_cmd) + processes.append(api_proc) + time.sleep(3) + + # 2) Start the dataset environment + print("\n--- Starting Dataset Environment ---") + env_cmd = ['python', '-m', 'environments.dataset_environment.dataset_env', 'serve'] + DATASET_FLAGS + print(f"$ {' '.join(env_cmd)}") + env_proc = subprocess.Popen(env_cmd) + processes.append(env_proc) + time.sleep(3) + + # 3) Run the example trainer + print("\n--- Running Example Trainer (GRPO) ---") + config = TrainingConfig(**TRAINER_CONFIG) + try: + train(config) + except Exception: + print("Error during training:") + traceback.print_exc() + print("--- Training complete ---") + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/environments/fundamental_prediction_environment.py b/environments/fundamental_prediction_environment.py new file mode 100644 index 00000000..30f0ac9a --- /dev/null +++ b/environments/fundamental_prediction_environment.py @@ -0,0 +1,505 @@ +import random +import re +from typing import List, Optional, Tuple, Union + +from datasets import load_dataset +from tqdm.asyncio import tqdm_asyncio + +from atroposlib.envs.base import ( + BaseEnv, + BaseEnvConfig, + EvalHandlingEnum, + OpenaiConfig, + ScoredDataGroup, +) +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer + +# System prompt only contains thinking instructions +system_prompt = """You are a deep thinking AI financial analyst. +You may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. + +You should enclose your thoughts and internal monologue inside tags, and then provide your final prediction.""" # noqa E501 + +# User message template that contains task instructions +user_message_template = """Your task is to analyze the following company fundamentals, news, and macroeconomic data to predict whether the company's {fundamental_metric} will be maintained, raised, or reduced in the next quarter, as well as the magnitude of any change. + +Your final answer MUST use the exact format: +"The {fundamental_metric} will be: {{answer}} and the magnitude will be: {{percentage}}%" + +Where {{answer}} is one of: "maintained", "raised", or "reduced" +And {{percentage}} is the expected percentage change (0% if maintained). + +Here is the data to analyze: + +{context}""" # noqa E501 + + +class FundamentalPredictionEnv(BaseEnv): + def __init__( + self, + config: BaseEnvConfig, + server_configs: List[OpenaiConfig], + slurm=True, + testing=False, + ): + """ + Initialize the Fundamental Metric Prediction environment. + + Args: + config: Configuration for the base environment + server_configs: List of server configurations for OpenAI API + slurm: Whether to use Slurm for distributed training + testing: Whether in testing mode + """ + super().__init__(config, server_configs, slurm, testing) + self.percent_correct_buffer = list() + self.magnitude_accuracy_buffer = list() + self.eval_metrics = list() + + @classmethod + def config_init(self) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: + env_config = BaseEnvConfig( + tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + group_size=32, + use_wandb=True, + max_num_workers=128, + rollout_server_url="http://localhost:8000", + total_steps=2000, + batch_size=1024, + steps_per_eval=20, + max_token_length=1024 * 16, + inference_weight=1.0, + wandb_name="fundamental_metric_prediction", + data_path_to_save_groups=None, + eval_handling=EvalHandlingEnum.LIMIT_TRAIN, + eval_limit_ratio=0.1, + ) + server_configs = [ + OpenaiConfig( + model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + base_url="http://localhost:9004/v1", + api_key="x", + num_requests_for_eval=256, + ) + ] + + return env_config, server_configs + + async def setup(self): + """ + Set up the environment by loading and preparing the dataset. + """ + # Load the full dataset + full_dataset = load_dataset( + "NousResearch/company-fundamentals-prediction-lite", + "default", + split="train", + ) + + full_dataset = full_dataset.shuffle(seed=42) + + # Create train/test split (95% train, 5% test) + split_dataset = full_dataset.train_test_split(test_size=0.05, seed=42) + + # Keep the splits as is - no need to reformat + self.train = split_dataset["train"] + self.test = split_dataset["test"] + + # Print some dataset statistics + print( + f"Loaded dataset with {len(self.train)} training examples and {len(self.test)} test examples" + ) + print(f"Example item format: {self.train[0]}") + + # Initialize iteration counter + self.iter = 0 + + def save_checkpoint(self, step, data=None): + if data is None: + data = {} + data["iter"] = self.iter + super().save_checkpoint(step, data) + + async def get_next_item(self): + """ + Get the next training item from the dataset. + + Returns: + A tuple containing prompt and expected answer + """ + next_item = self.train[self.iter % len(self.train)] + self.iter += 1 + + # Extract context, answer, magnitude and fundamental metric from the dataset item + context = next_item["context"] + answer = next_item["answer"] # "maintained", "raised", or "reduced" + magnitude = next_item["magnitude"] # Percentage as string + fundamental_metric = next_item[ + "fundamental_metric" + ] # Type of metric to predict + + # Create prompt tuple using frozensets as required + prompt = [] + + # Add system prompt + prompt.append(frozenset({"role": "system", "content": system_prompt}.items())) + + # Format user message with context and fundamental metric + user_content = user_message_template.format( + context=context, fundamental_metric=fundamental_metric + ) + prompt.append(frozenset({"role": "user", "content": user_content}.items())) + + return (tuple(prompt), answer, magnitude, fundamental_metric) + + async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]: + """ + Generate and collect model responses for scoring. + + Args: + item: Input item containing prompt and expected answer + + Returns: + Tuple of lists containing scored data groups and backlog + """ + # Extract messages from the item + messages = [] + for role_dict in item[0]: + messages.append(dict(role_dict)) + + # Apply chat template to convert messages to a single string + prompt = self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) + + # Get completions from the model + completions = await self.server.completion( + prompt=prompt, + n=self.config.group_size, + max_tokens=1024 * 15, + temperature=0.8, # Using higher temperature for diverse responses + ) + + to_score = list() + + for _, completion_choice in enumerate(completions.choices): + # Create a copy of the prompt messages + trajectory_messages = [] + for role_dict in item[0]: + trajectory_messages.append(dict(role_dict)) + + # Add the model's response + trajectory_messages.append( + {"role": "assistant", "content": completion_choice.text} + ) + + # Add to scoring queue with expected answer, magnitude, and fundamental metric + to_score.append( + ( + tuple(trajectory_messages), + item[1], # answer (maintained/raised/reduced) + item[2], # magnitude + item[3], # fundamental_metric + ) + ) + + # Call score to get the scored data + scored_data = await self.score(to_score) + to_backlog = [] + + return scored_data, to_backlog + + def _extract_prediction(self, text, fundamental_metric): + """ + Extract the prediction and magnitude from the model's response. + + Args: + text: Text containing the model's response + fundamental_metric: The fundamental metric being predicted + + Returns: + Tuple of (prediction, magnitude) or (None, None) if extraction fails + """ + # Check for thinking section + think_tags = re.findall(r"", text, re.IGNORECASE) + think_close_tags = re.findall(r"", text, re.IGNORECASE) + + # Verify thinking format - must have exactly one opening and one closing tag + if len(think_tags) != 1 or len(think_close_tags) != 1: + return None, None + + # Split on to separate thinking from answer + parts = re.split(r"", text, flags=re.IGNORECASE, maxsplit=1) + if len(parts) != 2: + return None, None + + thinking_section, answer_section = parts + + # Validate thinking section contains opening tag + if "" not in thinking_section.lower(): + return None, None + + # Escape fundamental_metric for regex + escaped_metric = re.escape(fundamental_metric) + + # Extract prediction and magnitude using regex - dynamic to match the fundamental metric + pattern = f"The {escaped_metric} will be:\\s*(maintained|raised|reduced)\\s*and\\s*the\\s*magnitude\\s*will\\s*be:\\s*([-+]?\\d+(?:\\.\\d+)?)%" # noqa E501 + + # Find all matches to check if there are multiple predictions + all_matches = re.findall(pattern, answer_section, re.IGNORECASE) + + # If no matches or multiple matches found, return None + if len(all_matches) != 1: + return None, None + + # Extract single match + matches = re.search(pattern, answer_section, re.IGNORECASE) + prediction = matches.group(1).lower() + magnitude = matches.group(2) + + return prediction, magnitude + + def _calculate_magnitude_score(self, predicted_magnitude, expected_magnitude): + """ + Calculate a score for magnitude prediction accuracy. + + Args: + predicted_magnitude: The model's predicted magnitude percentage + expected_magnitude: The expected magnitude percentage + + Returns: + Score between 0.0 and 1.0 based on how close the prediction is + """ + try: + # Convert to float for comparison + pred_mag = float(predicted_magnitude) + exp_mag = float(expected_magnitude) + + # Calculate absolute difference + diff = abs(pred_mag - exp_mag) + + # Score based on closeness to expected magnitude + # Perfect match = 1.0 + # Within 1% = 0.9 + # Within 5% = 0.7 + # Within 10% = 0.5 + # Within 20% = 0.3 + # More than 20% off = 0.0 + + if diff == 0: + return 1.0 + elif diff <= 1: + return 0.9 + elif diff <= 5: + return 0.7 + elif diff <= 10: + return 0.5 + elif diff <= 20: + return 0.3 + else: + return 0.0 + + except ValueError: + # If conversion fails, return 0 + return 0.0 + + async def score( + self, rollout_group_data + ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]: + """ + Score the generated model responses for fundamental metric predictions. + + Args: + rollout_group_data: List of generated responses with expected answers + + Returns: + ScoredDataGroup with tokenized inputs and scores, or None if no valid scores + """ + scores = ScoredDataGroup() + scores["tokens"] = list() + scores["masks"] = list() + scores["scores"] = list() + + # Get the expected answer, magnitude, and fundamental metric + expected_answer = rollout_group_data[0][ + 1 + ] # "maintained", "raised", or "reduced" + expected_magnitude = rollout_group_data[0][2] # Expected percentage change + fundamental_metric = rollout_group_data[0][3] # Type of fundamental metric + + # Shuffle to avoid bias in selection + random.shuffle(rollout_group_data) + + for item in rollout_group_data: + # Extract the model's response + model_response = item[0][-1]["content"] + + # Extract the prediction and magnitude from the model's response + prediction, magnitude = self._extract_prediction( + model_response, fundamental_metric + ) + + # Calculate final score + if prediction is None: + final_score = 0.0 # Invalid format + elif prediction == expected_answer: + # Correct direction: base score of 1 + magnitude bonus + magnitude_score = ( + self._calculate_magnitude_score(magnitude, expected_magnitude) + if magnitude is not None + else 0.0 + ) + final_score = 1.0 + magnitude_score + else: + final_score = 0.0 # Incorrect direction + + # Apply length penalty for responses that are too long + response_tokens = len(self.tokenizer.encode(model_response)) + if response_tokens > self.config.max_token_length * 0.95: + # Penalize responses that are close to the max token limit + final_score -= 0.5 * (response_tokens / self.config.max_token_length) + + # For binary reward signal, any positive score gets +1, otherwise -1 + binary_reward = 1.0 if final_score > 0 else -1.0 + + # Tokenize the conversation for learning + out_dict = tokenize_for_trainer(self.tokenizer, item[0]) + tokens = out_dict["tokens"] + masks = out_dict["masks"] + + # Remove examples with insufficient context + if len([1 for i in masks if i != -100]) < 10: + continue + + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["scores"].append(binary_reward) + + # For tracking metrics + directional_correct = ( + 1.0 if prediction == expected_answer and prediction is not None else 0.0 + ) + self.percent_correct_buffer.append(directional_correct) + if prediction == expected_answer and magnitude is not None: + self.magnitude_accuracy_buffer.append( + self._calculate_magnitude_score(magnitude, expected_magnitude) + ) + + # Break once we have enough examples + if len(scores["tokens"]) >= self.config.group_size: + break + + # Return None if all scores are the same (no learning signal) + if all(scores["scores"][0] == score for score in scores["scores"]): + return None + + return scores + + async def rollout_and_score_eval(self, test_item): + """ + Generate and score model responses for a single test item. + + Args: + test_item: Test item from dataset + + Returns: + Dictionary with direction and magnitude scores + """ + # Extract context, answer, magnitude and fundamental metric from the test item + context = test_item["context"] + expected_answer = test_item["answer"] + expected_magnitude = test_item["magnitude"] + fundamental_metric = test_item["fundamental_metric"] + + # Format user message with context and fundamental metric + user_content = user_message_template.format( + context=context, fundamental_metric=fundamental_metric + ) + + # Create messages for model + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_content}, + ] + + # Apply chat template to convert messages to a single string + prompt = self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) + + # Get model completion + completion = await self.server.completion( + prompt=prompt, + n=1, + max_tokens=1024 * 16, + temperature=0.2, # Lower for eval + split="eval", + ) + + # Extract the model's response + model_response = completion.choices[0].text + + # Extract prediction and magnitude + prediction, magnitude = self._extract_prediction( + model_response, fundamental_metric + ) + + # Calculate direction score (1 for correct, 0 for incorrect) + direction_score = ( + 1 if prediction == expected_answer and prediction is not None else 0 + ) + + # Calculate magnitude score if direction is correct + magnitude_score = 0 + if direction_score == 1 and magnitude is not None: + magnitude_score = self._calculate_magnitude_score( + magnitude, expected_magnitude + ) + + # Calculate combined score (1 + magnitude_score for correct direction, 0 for incorrect) + combined_score = (1 + magnitude_score) if direction_score == 1 else 0 + + return { + "direction_score": direction_score, + "magnitude_score": magnitude_score, + "combined_score": combined_score, + } + + async def evaluate(self, *args, **kwargs): + """ + Evaluate the model on test data. + """ + eval_tasks = [] + for test_item in self.test: + eval_tasks.append(self.rollout_and_score_eval(test_item)) + + # Run evaluation + all_scores = await tqdm_asyncio.gather(*eval_tasks) + + # Calculate aggregate metrics + direction_scores = [score["direction_score"] for score in all_scores] + magnitude_scores = [ + score["magnitude_score"] + for score in all_scores + if score["direction_score"] == 1 + ] + combined_scores = [score["combined_score"] for score in all_scores] + + # Calculate and log metrics + direction_accuracy = ( + sum(direction_scores) / len(direction_scores) if direction_scores else 0 + ) + magnitude_accuracy = ( + sum(magnitude_scores) / len(magnitude_scores) if magnitude_scores else 0 + ) + average_combined_score = ( + sum(combined_scores) / len(combined_scores) if combined_scores else 0 + ) + + self.eval_metrics.append(("eval/direction_accuracy", direction_accuracy)) + self.eval_metrics.append(("eval/magnitude_accuracy", magnitude_accuracy)) + self.eval_metrics.append(("eval/combined_score", average_combined_score)) + + +if __name__ == "__main__": + FundamentalPredictionEnv.cli() diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py new file mode 100644 index 00000000..75dc3e0f --- /dev/null +++ b/environments/gsm8k_server.py @@ -0,0 +1,295 @@ +import random +from typing import Dict, List, Optional, Tuple, TypedDict, Union + +from datasets import load_dataset +from latex2sympy2_extended import NormalizationConfig +from math_verify import LatexExtractionConfig, parse, verify +from tqdm.asyncio import tqdm_asyncio + +from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup +from atroposlib.type_definitions import Item, number +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer + +system_prompt = ( + "You are a deep thinking AI, you may use extremely long chains of thought " + "to deeply consider the problem and deliberate with yourself via systematic " + "reasoning processes to help come to a correct solution prior to answering. " + "You should enclose your thoughts and internal monologue inside " + "tags, and then provide your solution or response to the problem.\n\n" +) + +system_prompt += """You are allocated a maximum of 2048 tokens, please strive to use less. + +You will then provide your answer like this: \\boxed{your answer here} +It is important that you provide your answer in the correct format. +If you do not, you will not receive credit for your answer. +So please end your answer with \\boxed{your answer here}""" + + +class GSM8kRow(TypedDict): + question: str + answer: str + + +class GSM8kEnv(BaseEnv): + + name = "gsm8k" + + def __init__( + self, + config: BaseEnvConfig, + server_configs: List[OpenaiConfig], + slurm=True, + testing=False, + ): + super().__init__(config, server_configs, slurm, testing) + self.percent_correct_buffer = list() + self.eval_metrics = list() + # Add tracking for wandb visualizations + self.rollouts_for_wandb = [] + self.completion_lengths = [] + + @classmethod + def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: + env_config = BaseEnvConfig( + tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview", + group_size=8, + use_wandb=True, + rollout_server_url="http://localhost:8000", + total_steps=1000, + batch_size=12, + steps_per_eval=100, + max_token_length=2048, + wandb_name="gsm8k", + ) + server_configs = [ + OpenaiConfig( + model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview", + base_url="http://localhost:9001/v1", + api_key="x", + num_requests_for_eval=256, + ), + ] + + return env_config, server_configs + + async def wandb_log(self, wandb_metrics: Optional[Dict] = None): + if wandb_metrics is None: + wandb_metrics = {} + + # Try to calculate percent_correct, pass if there's a division by zero + try: + wandb_metrics["train/percent_correct"] = sum( + self.percent_correct_buffer + ) / len(self.percent_correct_buffer) + except ZeroDivisionError: + # Skip if buffer is empty + pass + + self.percent_correct_buffer = list() + for item in self.eval_metrics: + wandb_metrics[item[0]] = item[1] + self.eval_metrics = list() + # Call the parent method to handle the server metrics + await super().wandb_log(wandb_metrics) + + async def setup(self): + self.train = load_dataset("gsm8k", "main", split="train").shuffle(seed=42) + test_data = load_dataset("gsm8k", "main", split="test").shuffle(seed=42) + self.test = list() + for item in test_data: + self.test.append( + { + "question": item["question"], + "gold_answer": item["answer"] + .split("#")[-1] + .strip() + .replace(",", ""), + } + ) + self.iter = 0 + + def save_checkpoint(self, step, data=None): + if data is None: + data = {} + data["iter"] = self.iter + super().save_checkpoint(step, data) + + async def rollout_and_score_eval(self, question: str, answer: str) -> number: + completion = await self.server.chat_completion( + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question}, + ], + n=1, + max_tokens=self.config.max_token_length, + temperature=0.0, + split="eval", + ) + gold_parsed = parse( + "\\boxed{" + answer + "}", + extraction_mode="first_match", + extraction_config=[LatexExtractionConfig()], + ) + answer_parsed = parse( + completion.choices[0].message.content.split("")[-1], + extraction_config=[ + LatexExtractionConfig( + normalization_config=NormalizationConfig( + nits=False, + malformed_operators=False, + basic_latex=True, + equations=True, + boxed="all", + units=True, + ), + # Ensures that boxed is tried first + boxed_match_priority=0, + try_extract_without_anchor=False, + ) + ], + extraction_mode="first_match", + ) + score = 1 if verify(answer_parsed, gold_parsed) else 0 + return score + + async def evaluate(self, *args, **kwargs): + eval_tasks = [] + for item in self.test: + eval_tasks.append( + self.rollout_and_score_eval(item["question"], item["gold_answer"]) + ) + scores = await tqdm_asyncio.gather(*eval_tasks) + self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores))) + + async def collect_trajectories( + self, item: GSM8kRow + ) -> Tuple[ScoredDataGroup, list[Item]]: + user_message = {"role": "user", "content": item["question"]} + gold_answer = ( + "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}" + ) + + chat_completions = await self.server.chat_completion( + messages=[{"role": "system", "content": system_prompt}, user_message], + n=self.config.group_size, + max_tokens=self.config.max_token_length, + ) + to_score = list() + to_backlog = list() + for i, chat_completion in enumerate(chat_completions.choices): + messages = ( + {"role": "system", "content": system_prompt}, + user_message, + {"role": "assistant", "content": chat_completion.message.content}, + ) + to_score.append( + { + "messages": messages, + "gold_answer": gold_answer, + "finish_reason": chat_completion.finish_reason, + } + ) + to_postprocess = await self.score(to_score) + return to_postprocess, to_backlog + + async def score( + self, rollout_group_data + ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]: + scores = ScoredDataGroup() + scores["tokens"] = list() + scores["masks"] = list() + scores["scores"] = list() + gold_parsed = parse( + rollout_group_data[0]["gold_answer"], + extraction_mode="first_match", + extraction_config=[LatexExtractionConfig()], + ) + if len(gold_parsed) != 0: + # We require the answer to be provided in correct latex (no malformed operators) + random.shuffle(rollout_group_data) + for item in rollout_group_data: + # print(item[0][-1]["content"]) + answer_parsed = parse( + item["messages"][-1]["content"].split("")[-1], + extraction_config=[ + LatexExtractionConfig( + normalization_config=NormalizationConfig( + nits=False, + malformed_operators=False, + basic_latex=True, + equations=True, + boxed="all", + units=True, + ), + # Ensures that boxed is tried first + boxed_match_priority=0, + try_extract_without_anchor=False, + ) + ], + extraction_mode="first_match", + ) + # Reward 1 if the content is the same as the ground truth, 0 otherwise + reward = verify(answer_parsed, gold_parsed) + # print( + # f"message: {item[0][-1]['content']}, ground_truth: {item[1]}, reward: {reward}" + # ) + out_dict = tokenize_for_trainer( + self.tokenizer, item["messages"], item["finish_reason"] + ) + tokens = out_dict["tokens"] + masks = out_dict["masks"] + # remove obviously bad examples + if len([1 for i in masks if i != -100]) < 10: + continue + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["scores"].append(1.0 if reward else -1.0) + if len(scores["tokens"]) >= self.config.group_size: + break + for score in scores["scores"]: + self.percent_correct_buffer.append(max(score, 0)) + # check if all the same + # print(scores['scores']) + if all([score == 1 for score in scores["scores"]]): + # Do length penalty :) + token_lengths = [len(token) for token in scores["tokens"]] + if max(token_lengths) == 0: + # What? But don't want to crash a run so just in case... + return None + + # Get max allowed token length from config + max_allowed_length = self.config.max_token_length + # Set threshold at 50% of max_token_length - no penalty below this + length_threshold = max_allowed_length * 0.5 + + # Apply modified length penalty with threshold + scores["scores"] = [] + for length in token_lengths: + if length <= length_threshold: + # No penalty for responses under threshold + scores["scores"].append(1.0) + else: + # Calculate how far we are between threshold and max as a percentage + percentage_of_range = (length - length_threshold) / ( + max_allowed_length - length_threshold + ) + # Cap at 1.0 in case length exceeds max_allowed_length + percentage_of_range = min(percentage_of_range, 1.0) + # Apply linear penalty scaling from 1.0 down to 0.0 + scores["scores"].append(1.0 - percentage_of_range) + if all([scores["scores"][0] == score for score in scores["scores"]]): + return None # If all the same, we return None + return scores + else: + # If the gold solution is not parseable, we return None + return None + + async def get_next_item(self) -> GSM8kRow: + next_item = self.train[self.iter % len(self.train)] + self.iter += 1 + return next_item + + +if __name__ == "__main__": + GSM8kEnv.cli() diff --git a/environments/math_server.py b/environments/math_server.py new file mode 100644 index 00000000..e5060016 --- /dev/null +++ b/environments/math_server.py @@ -0,0 +1,1030 @@ +import asyncio +import math +import random +from concurrent.futures import ProcessPoolExecutor +from difflib import SequenceMatcher +from typing import Dict, List, Optional, Tuple + +from datasets import load_dataset +from latex2sympy2_extended import NormalizationConfig +from math_verify import LatexExtractionConfig, parse, verify +from math_verify.errors import TimeoutException +from pydantic import Field +from tqdm.asyncio import tqdm_asyncio + +import wandb +from atroposlib.envs.base import ( + BaseEnv, + BaseEnvConfig, + EvalHandlingEnum, + OpenaiConfig, + ScoredDataGroup, +) +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer + +problem_format = "{problem}" + +judge_format = """Here is a math problem and a proposed solution: + +[START PROBLEM] +{problem} +[END PROBLEM] +[START SOLUTION] +{solution} +[END SOLUTION] + +Please verify if it is correct or not. + +If it's correct submit your answer in your response with \\boxed{{True}}. +If it's incorrect, please submit your answer in your response with \\boxed{{False}}. + +Please include how to solve the problem correctly in your answer.""" + + +retry_format = """Here is a math problem, a proposed solution, and a verification of the solution: +[START PROBLEM] +{problem} +[END PROBLEM] +[START SOLUTION] +{solution} +[END SOLUTION] +[START VERIFICATION] +{verification} +[END VERIFICATION] + +Please use this verification to help you solve the problem correctly. + +Provide your answer in your response with \\boxed{{answer}}.""" # noqa: E501 + + +rlaif_format = """Here is a math problem, and two solutions that are correct. Please choose whichever answer you prefer. +[START PROBLEM] +{problem} +[END PROBLEM] +[START SOLUTION 1] +{solution1} +[END SOLUTION 1] +[START SOLUTION 2] +{solution2} +[END SOLUTION 2] + +Here are some metrics for you to use to grade the two solutions: +- Conciseness: How concise is the solution? Is it too long or too short? +- Clarity: How clear is the solution? Is it easy to understand? +- Correctness: Is the reasoning correct? The answer has been prechecked to be correct, but there may be errors in the reasoning. + +Please use these metrics to help you choose the best solution, in order of priority. + +Please provide your answer in your response with \\boxed{{1}}, for the first solution, or \\boxed{{2}} for the second solution.""" # noqa: E501 + + +class RSConfig(BaseEnvConfig): + run_evaluation: bool = Field(True, description="If this should run evaluation") + mask_too_long_completions: bool = Field( + True, description="If this should mask too long completions" + ) + percent_to_judge: float = Field(0.3, description="The percentage of items to judge") + percent_length_penalty: float = Field( + 0.0, description="The percentage of items to have length penalty" + ) + + +def quick_similarity(a, b): + return SequenceMatcher(None, a, b).ratio() + + +def score_answer(gold, resp) -> Optional[bool]: + try: + gold_parsed = parse( + gold, + extraction_mode="first_match", + extraction_config=[LatexExtractionConfig()], + ) + except (Exception, TimeoutException, KeyError, TypeError, NotImplementedError): + return None + if len(gold_parsed) != 0: + # print(item[0][-1]["content"]) + try: + answer_parsed = parse( + resp, + extraction_config=[ + LatexExtractionConfig( + normalization_config=NormalizationConfig( + nits=False, + malformed_operators=False, + basic_latex=True, + boxed="all", + units=True, + ), + # Ensures that boxed is tried first + boxed_match_priority=0, + try_extract_without_anchor=False, + ) + ], + extraction_mode="first_match", + ) + except ( + Exception, + TimeoutException, + KeyError, + TypeError, + NotImplementedError, + ): + # Can't parse, so we skip + return None + # Reward 1 if the content is the same as the ground truth, 0 otherwise + try: + return verify(answer_parsed, gold_parsed) + except ( + Exception, + TimeoutException, + KeyError, + TypeError, + NotImplementedError, + ): + return None + return None + + +class MathEnv(BaseEnv): + + name = "math" + env_config_cls = RSConfig + + def __init__( + self, + config: RSConfig, + server_configs: List[OpenaiConfig], + slurm=True, + testing=False, + ): + print("Initializing MathEnv") + print(f"Slurm: {slurm}, Testing: {testing}") + super().__init__(config, server_configs, slurm, testing) + self.percent_correct_buffer = list() + self.eval_metrics = list() + self.mp_executor = ProcessPoolExecutor(64) + self.percent_overanswer = list() + self.percent_judge_correct = list() + self.correct_answer_len = list() + self.incorrect_answer_len = list() + self.normal_rollouts = list() + self.rlaif_rollouts = list() + self.pass_at_groupsize = list() + self.judge_rollouts = list() + self.selfcorrect_rollouts = list() + self.judge_success_rate = list() + self.iter = 0 + + @classmethod + def config_init(self) -> Tuple[RSConfig, List[OpenaiConfig]]: + env_config = RSConfig( + tokenizer_name="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", + group_size=8, + use_wandb=True, + rollout_server_url="http://localhost:8000", + total_steps=1000, + batch_size=1024, + steps_per_eval=25, + max_token_length=31000, # 22000 // (2 ** i), + wandb_name="math", + eval_handling=EvalHandlingEnum.LIMIT_TRAIN, + eval_limit_ratio=0.1, + ) + server_configs = [ + OpenaiConfig( + model_name="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", + base_url="http://localhost:9004/v1", + api_key="x", + num_requests_for_eval=256, # since evaling only on one... + ), + ] + + return env_config, server_configs + + async def wandb_log(self, wandb_metrics: Optional[Dict] = None): + if wandb_metrics is None: + wandb_metrics = dict() + if len(self.pass_at_groupsize) > 0: + wandb_metrics["train/pass_at_groupsize"] = sum( + self.pass_at_groupsize + ) / len(self.pass_at_groupsize) + self.pass_at_8 = list() + if len(self.percent_correct_buffer) > 0: + wandb_metrics["train/percent_correct"] = sum( + self.percent_correct_buffer + ) / len(self.percent_correct_buffer) + wandb_metrics["train/percent_overanswer"] = sum( + self.percent_overanswer + ) / len(self.percent_overanswer) + self.percent_overthink = list() + self.percent_overanswer = list() + self.percent_correct_buffer = list() + if len(self.correct_answer_len) > 0: + wandb_metrics["train/avg_correct_answer_len"] = sum( + self.correct_answer_len + ) / len(self.correct_answer_len) + self.correct_answer_len = list() + if len(self.incorrect_answer_len) > 0: + wandb_metrics["train/avg_incorrect_answer_len"] = sum( + self.incorrect_answer_len + ) / len(self.incorrect_answer_len) + self.incorrect_answer_len = list() + if len(self.percent_judge_correct) > 0: + wandb_metrics["judge_train/percent_judge_correct"] = sum( + self.percent_judge_correct + ) / len(self.percent_judge_correct) + self.percent_judge_correct = list() + if len(self.judge_success_rate) > 0: + wandb_metrics["judge_train/judge_success_rate"] = sum( + self.judge_success_rate + ) / len(self.judge_success_rate) + # create tables + if len(self.judge_rollouts) > 0: + table = wandb.Table( + columns=["problem", "solution", "answer", "correct", "judge"] + ) + for group in self.judge_rollouts: + table.add_data(group[0], group[1], group[2], group[3], group[4]) + wandb_metrics["judge_train/judge_rollouts"] = table + if len(self.selfcorrect_rollouts) > 0: + table = wandb.Table(columns=["problem", "solution1", "solution2", "score"]) + for group in self.selfcorrect_rollouts: + table.add_data(group[0], group[1], group[2], group[3]) + wandb_metrics["judge_train/selfcorrect_rollouts"] = table + if len(self.normal_rollouts) > 0: + table = wandb.Table(columns=["problem", "solution", "answer", "score"]) + for group in self.normal_rollouts: + table.add_data(group[0], group[1], group[2], group[3]) + wandb_metrics["train/normal_rollouts"] = table + if len(self.rlaif_rollouts) > 0: + table = wandb.Table( + columns=["problem", "solution1", "solution2", "score", "rollout"] + ) + for group in self.rlaif_rollouts: + table.add_data(group[0], group[1], group[2], group[3], group[4]) + wandb_metrics["train/rlaif_rollouts"] = table + wandb_metrics["train/iter"] = self.iter + for item in self.eval_metrics: + wandb_metrics[item[0]] = item[1] + self.eval_metrics = list() + await super().wandb_log(wandb_metrics) + + async def setup(self): + self.train = load_dataset("zwhe99/DeepMath-103K", split="train").shuffle( + seed=42 + ) + aime_test_data = load_dataset("HuggingFaceH4/aime_2024", split="train") + math500_test_data = load_dataset("HuggingFaceH4/math-500", split="test") + amc_test_data = load_dataset("math-ai/amc23", split="test") + self.test = list() + for name, t_dataset in zip( + ["aime24", "math500"], [aime_test_data, math500_test_data] + ): + for item in t_dataset: + self.test.append( + ( + problem_format.format(problem=item["problem"]), + item["answer"], + name, + ) + ) + for name, t_dataset in zip( + ["amc23"], + [amc_test_data], + ): + for item in t_dataset: + self.test.append( + ( + problem_format.format(problem=item["question"]), + item["answer"], + name, + ) + ) + + async def rollout_and_score_eval(self, question, answer, subset): + + completion = await self.server.chat_completion( + messages=[ + {"role": "user", "content": question}, + ], + n=1, + max_tokens=32765, + temperature=0.0, + split="eval", + ) + loop = asyncio.get_event_loop() + gold = "\\boxed{" + answer + "}" if "\\boxed" not in answer else answer + resp = completion.choices[0].message.content.split("")[-1] + task = loop.run_in_executor(self.mp_executor, score_answer, gold, resp) + reward = await task + if reward is None: + return 0, subset + score = 1 if reward else 0 + return score, subset + + async def evaluate(self, *args, **kwargs): + if not self.config.run_evaluation: + return + eval_tasks = [] + for item in self.test: + eval_tasks.append(self.rollout_and_score_eval(item[0], item[1], item[2])) + parsing_data = await tqdm_asyncio.gather(*eval_tasks) + task_lists = dict() + for score, subset in parsing_data: + if subset not in task_lists: + task_lists[subset] = list() + task_lists[subset].append(score) + # Now get the average + for subset, scores in task_lists.items(): + self.eval_metrics.append( + (f"eval/{subset}_percent_correct", sum(scores) / len(scores)) + ) + # overall score + scores = [] + for subset, score in task_lists.items(): + scores.extend(score) + self.eval_metrics.append( + ("eval/overall_percent_correct", sum(scores) / len(scores)) + ) + + async def collect_trajectories_normal(self, item) -> Tuple[List, List]: + thinking_len = self.config.max_token_length + user_prompt = problem_format.format(problem=item[0]) + chat = [ + {"role": "user", "content": user_prompt}, + ] + thinking_len = thinking_len - len( + self.tokenizer.apply_chat_template(chat, add_generation_prompt=True) + ) + chat_completions = await self.server.chat_completion( + messages=chat, + n=self.config.group_size, + max_tokens=thinking_len, + temperature=1.0, + top_p=0.95, + ) + print("Finished generation", flush=True) + to_score = list() + to_backlog = list() + for i, chat_completion in enumerate(chat_completions.choices): + messages = ( + {"role": "user", "content": user_prompt}, + {"role": "assistant", "content": chat_completion.message.content}, + ) + to_score.append( + ( + messages, + item[1], + chat_completion.finish_reason, + ) + ) + + to_postprocess = await self.score_normal(to_score) + if to_postprocess is None: + return None, to_backlog + if all( + [to_postprocess["scores"][0] == score for score in to_postprocess["scores"]] + ): + if to_postprocess["scores"][0] == 1.0: + # we can do RLAIF + # find the two most dissimilar messages + messages = to_postprocess["messages"] + score_matrix = [] + most_dissimilar = (0, 1) + most_dissimilar_score = 1.0 + # find the two most dissimilar messages + for i in range(len(messages) - 1): + score_matrix.append([]) + for j in range(i + 1): + # Only need to compute half of the matrix + score_matrix[i].append(1.0) + for j in range(i + 1, len(messages)): + m1 = messages[i][-1]["content"].split("")[-1] + m2 = messages[j][-1]["content"].split("")[-1] + if m1 == m2: + score_matrix[i].append(1.0) + else: + score_matrix[i].append(quick_similarity(m1, m2)) + if score_matrix[i][j] < most_dissimilar_score: + most_dissimilar = (i, j) + most_dissimilar_score = score_matrix[i][j] + if most_dissimilar_score < 0.975: + # send over to RLAIF + to_backlog.append( + ( + item[0], + item[1], + "rlaif", + tuple( + [ + frozenset(item.items()) + for item in messages[most_dissimilar[0]] + ] + ), + tuple( + [ + frozenset(item.items()) + for item in messages[most_dissimilar[1]] + ] + ), + most_dissimilar_score, + ) + ) + print( + "\n".join( + [ + "[" + + ", ".join([str(item) for item in score_matrix_row]) + + "]" + for score_matrix_row in score_matrix + ] + ) + ) + print( + f"Sending to RLAIF, most dissimilar score: {most_dissimilar_score}" + ) + else: + print( + f"Unable to RLAIF, most dissimilar score: {most_dissimilar_score}" + ) + if random.random() < self.config.percent_length_penalty: + # Check if deltas of message lengths are different enough to want to length penalty on + message_lengths = [ + len(tokens) for tokens in to_postprocess["tokens"] + ] + min_message_length = min(message_lengths) + max_message_delta = max( + [msg_len - min_message_length for msg_len in message_lengths] + ) + if max_message_delta > 0.1 * min_message_length: + print( + "Max message delta is greater than 0.1 * shortest message, adding length penalty" + ) + for i in range(len(to_postprocess["scores"])): + len_penalty = ( + message_lengths[i] - min_message_length + ) / max_message_delta + len_penalty = math.cos(len_penalty * math.pi) + to_postprocess["scores"][i] = len_penalty + else: + print( + "Max message delta is less than 0.1 * shortest message, no length penalty" + ) + return None, to_backlog + else: + return None, to_backlog + else: + return None, to_backlog + else: + self.normal_rollouts.append( + ( + item[0], + to_postprocess["messages"][0], + item[1], + to_postprocess["scores"][0], + ) + ) + print("Sending to judge potentially") + if random.random() < self.config.percent_to_judge: + # find first pos and neg scored answers. + pos_idx = [ + i + for i, score in enumerate(to_postprocess["scores"]) + if score == 1.0 + ][0] + neg_idx = [ + i + for i, score in enumerate(to_postprocess["scores"]) + if (score == -1.0) + and ( + not to_postprocess["overrides"][i].get( + "set_advantage_to_zero", False + ) + ) + ] + if len(neg_idx) == 0: + return None, to_backlog + neg_idx = neg_idx[0] + if pos_idx is not None and neg_idx is not None: + to_backlog.append( + ( + item[0], + item[1], + "judge", + to_postprocess["messages"][pos_idx][-1]["content"].split( + "" + )[-1], + "True", + ) + ) + to_backlog.append( + ( + item[0], + item[1], + "judge", + to_postprocess["messages"][neg_idx][-1]["content"].split( + "" + )[-1], + "False", + ) + ) + print("sending to judge") + else: + return None, to_backlog + print(f"Collected {len(to_postprocess['scores'])} trajectories") + if not self.config.mask_too_long_completions: + to_postprocess["overrides"] = [ + {} for _ in range(len(to_postprocess["scores"])) + ] + return to_postprocess, to_backlog + + async def collect_trajectories(self, item) -> Tuple[List, List]: + if item[2] == "normal": + return await self.collect_trajectories_normal(item) + elif item[2] == "rlaif": + return await self.collect_trajectories_rlaif(item) + elif item[2] == "judge": + return await self.collect_trajectories_judge(item) + elif item[2] == "selfcorrect": + # selfcorrect is a special case where we are using the Judge rollout + print("selfcorrect processing...") + print("selfcorrect item:", item, flush=True) + group = item[3] + scores = item[4] + finish_reasons = item[5] + to_postprocess = ScoredDataGroup() + to_postprocess["tokens"] = list() + to_postprocess["masks"] = list() + to_postprocess["scores"] = list() + to_postprocess["overrides"] = list() + to_postprocess["messages"] = list() + for i in range(len(group)): + # convert from frozen set to dict + conv = [dict(x) for x in group[i]] + if i == 0: + self.selfcorrect_rollouts.append( + ( + item[0], + item[1], + conv[0]["content"], + conv[1]["content"], + ) + ) + if ( + len(self.selfcorrect_rollouts) + >= self.config.num_rollouts_to_keep + ): + self.selfcorrect_rollouts.pop(0) + out_dict = tokenize_for_trainer( + tokenizer=self.tokenizer, + chat=conv, + finish_reason=finish_reasons[i], + include_messages=True, + ) + to_postprocess["tokens"].append(out_dict["tokens"]) + to_postprocess["masks"].append(out_dict["masks"]) + to_postprocess["scores"].append(scores[i]) + to_postprocess["overrides"].append(dict()) + if (finish_reasons[i] == "length") and ( + self.config.mask_too_long_completions + ): + to_postprocess["overrides"][-1]["set_advantage_to_zero"] = True + to_postprocess["messages"].append(out_dict["messages"]) + print("selfcorrect done, sending batch off") + return to_postprocess, [] + else: + raise ValueError(f"Unknown rollout type: {item[2]}") + + async def score_normal(self, rollout_group_data: List) -> Optional[ScoredDataGroup]: + scores = ScoredDataGroup() + scores["tokens"] = list() + scores["masks"] = list() + scores["scores"] = list() + scores["overrides"] = list() + scores["messages"] = list() + gold = rollout_group_data[0][1] + loop = asyncio.get_event_loop() + random.shuffle(rollout_group_data) + for item in rollout_group_data: + resp = item[0][-1]["content"].split("")[-1] + scores["overrides"].append(dict()) + if item[2] == "length": + reward = False + if self.config.mask_too_long_completions: + scores["overrides"][-1]["set_advantage_to_zero"] = True + else: + task = loop.run_in_executor(self.mp_executor, score_answer, gold, resp) + reward = await task + if reward is None: + return None + out_dict = tokenize_for_trainer( + tokenizer=self.tokenizer, + chat=item[0], + finish_reason=item[2], + include_messages=True, + ) + tokens = out_dict["tokens"] + masks = out_dict["masks"] + messages = out_dict["messages"] + # remove obviously bad examples + if len([1 for i in masks if i != -100]) < 10: + continue + if item[2] == "length": + # Note we set it here so we can filter out the examples that are too long + # for the Judge loop. IF you set the config to not do this we fix it + # in the collect_trajectories_normal function. + scores["overrides"][-1]["set_advantage_to_zero"] = True + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["scores"].append(1.0 if reward else -1.0) + scores["messages"].append(messages) + if len(scores["tokens"]) >= self.config.group_size: + break + if any([score == 1.0 for score in scores["scores"]]): + self.pass_at_groupsize.append(1.0) + else: + self.pass_at_groupsize.append(0.0) + if len(scores["tokens"]) < self.config.group_size: + # We don't have enough data to score + return None + for score in scores["scores"]: + self.percent_correct_buffer.append(max(score, 0)) + self.percent_overanswer.extend( + [item[2] == "length" for item in rollout_group_data] + ) + # check if all the same + # print(scores['scores']) + # Fill in the correct/incorrect lens after so we're only looking at actual training data + self.correct_answer_len.extend( + [ + len(scores["tokens"][i]) + for i in range(len(scores["scores"])) + if scores["scores"][i] == 1.0 + ] + ) + self.incorrect_answer_len.extend( + [ + len(scores["tokens"][i]) + for i in range(len(scores["scores"])) + if (scores["scores"][i] == -1.0) + and (not scores["overrides"][i].get("set_advantage_to_zero", False)) + ] + ) + return scores + + async def get_next_item(self): + while True: + next_item = self.train[self.iter % len(self.train)] + self.iter += 1 + prompt = next_item["question"] + try: + answer = ( + ("\\boxed{" + next_item["final_answer"] + "}") + if "\\boxed" not in next_item["final_answer"] + else next_item["final_answer"] + ) + break + except TypeError: + print( + f"Error in getting next item, trying again, " + f"data: {next_item['question']} -> {next_item['final_answer']}" + ) + return (prompt, answer, "normal") + + async def collect_trajectories_rlaif(self, frozen_item) -> Tuple[List, List]: + to_backlog = list() + print("Attempting RLAIF") + item = list(frozen_item) + print("Converting to dicts") + item[3] = [dict(x) for x in item[3]] + item[4] = [dict(x) for x in item[4]] + print("Formatting user prompts") + user_prompt_fwd = rlaif_format.format( + problem=item[0], + solution1=item[3][-1]["content"].split("")[-1], + solution2=item[4][-1]["content"].split("")[-1], + ) + user_prompt_bwd = rlaif_format.format( + problem=item[0], + solution1=item[4][-1]["content"].split("")[-1], + solution2=item[3][-1]["content"].split("")[-1], + ) + print("Sending to server") + chat = [ + {"role": "user", "content": user_prompt_fwd}, + ] + max_token_length = self.config.max_token_length - len( + self.tokenizer.apply_chat_template(chat, add_generation_prompt=True) + ) + chat_completions_fwd = self.server.chat_completion( + messages=chat, + n=3, + max_tokens=max_token_length, + temperature=1.0, + top_p=0.95, + ) + print("Sending to server") + # Should be the same token length as the fwd but tokenizers are cursed + chat = [ + {"role": "user", "content": user_prompt_bwd}, + ] + max_token_length = self.config.max_token_length - len( + self.tokenizer.apply_chat_template(chat, add_generation_prompt=True) + ) + chat_completions_bwd = self.server.chat_completion( + messages=chat, + n=3, + max_tokens=self.config.max_token_length, + temperature=1.0, + top_p=0.95, + ) + print("Gathering completions") + chat_completions_fwd, chat_completions_bwd = await asyncio.gather( + chat_completions_fwd, chat_completions_bwd + ) + print("Grabbed RLAIF completions") + # Check for correct answers + score_1 = 0 + score_2 = 0 + for chat_completion in chat_completions_fwd.choices: + score = ( + chat_completion.message.content.split("")[-1] + .split("\\boxed{")[-1] + .split("}")[0] + .strip() + ) + if score == "1": + score_1 += 1 + elif score == "2": + score_2 += 1 + for chat_completion in chat_completions_bwd.choices: + score = ( + chat_completion.message.content.split("")[-1] + .split("\\boxed{")[-1] + .split("}")[0] + .strip() + ) + if score == "1": + score_2 += 1 + elif score == "2": + score_1 += 1 + print(f"Score 1: {score_1}, Score 2: {score_2}") + if score_1 == score_2: + return None, [] + self.rlaif_rollouts.append( + ( + item[0], + item[3][-1]["content"].split("")[-1], + item[4][-1]["content"].split("")[-1], + score_1 - score_2, + chat_completions_fwd.choices[0].message.content, + ) + ) + if len(self.rlaif_rollouts) >= self.config.num_rollouts_to_keep: + self.rlaif_rollouts.pop(0) + print("RLAIF rollout added") + to_postprocess = ScoredDataGroup() + to_postprocess["tokens"] = list() + to_postprocess["masks"] = list() + to_postprocess["scores"] = list() + to_postprocess["overrides"] = list() + to_postprocess["messages"] = list() + # add the first message in + out_dict = tokenize_for_trainer( + tokenizer=self.tokenizer, chat=item[3], include_messages=True + ) + tokens = out_dict["tokens"] + masks = out_dict["masks"] + to_postprocess["tokens"].append(tokens) + to_postprocess["masks"].append(masks) + to_postprocess["scores"].append(1.0 if score_1 > score_2 else -1.0) + to_postprocess["messages"].append(out_dict["messages"]) + out_dict = tokenize_for_trainer( + tokenizer=self.tokenizer, chat=item[4], include_messages=True + ) + tokens = out_dict["tokens"] + masks = out_dict["masks"] + to_postprocess["tokens"].append(tokens) + to_postprocess["masks"].append(masks) + to_postprocess["scores"].append(1.0 if score_2 > score_1 else -1.0) + to_postprocess["messages"].append(out_dict["messages"]) + to_postprocess["group_overrides"] = { + "group_size": 2, + } + print("RLAIF rollout added") + return to_postprocess, to_backlog + + async def collect_trajectories_judge(self, item) -> Tuple[List, List]: + user_prompt = judge_format.format( + problem=item[0], + solution=item[3], + ) + to_backlog = list() + chat = [ + {"role": "user", "content": user_prompt}, + ] + max_token_length = self.config.max_token_length - len( + self.tokenizer.apply_chat_template(chat, add_generation_prompt=True) + ) + chat_completions = await self.server.chat_completion( + messages=chat, + n=self.config.group_size, + max_tokens=max_token_length, + temperature=1.0, + top_p=0.95, + ) + is_correct = [ + ( + chat_completion.message.content.split("")[-1] + .split("\\boxed{")[-1] + .split("}")[0] + .strip() + == item[4] + ) + and (chat_completion.finish_reason != "length") + for chat_completion in chat_completions.choices + ] + self.percent_judge_correct.append( + sum([1.0 if val else 0.0 for val in is_correct]) / len(is_correct) + ) + if all([not val for val in is_correct]): + # Can't judge :( + return None, [] + scores = ScoredDataGroup() + scores["tokens"] = [] + scores["masks"] = [] + scores["scores"] = [] + scores["overrides"] = [] + scores["messages"] = [] + for_table = [] + for i, chat_completion in enumerate(chat_completions.choices): + out_dict = tokenize_for_trainer( + tokenizer=self.tokenizer, + chat=[ + {"role": "user", "content": user_prompt}, + {"role": "assistant", "content": chat_completion.message.content}, + ], + include_messages=True, + ) + tokens = out_dict["tokens"] + masks = out_dict["masks"] + messages = out_dict["messages"] + if not is_correct[i]: + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["scores"].append(-1.0) + scores["messages"].append(messages) + scores["overrides"].append(dict()) + if (chat_completion.finish_reason == "length") and ( + self.config.mask_too_long_completions + ): + scores["overrides"][-1]["set_advantage_to_zero"] = True + else: + if len(for_table) == 0: + # populate the table + for_table = [ + item[0], + item[1], + item[3], + item[4], + chat_completion.message.content, + ] + if item[4] == "False": + # Score based on percentage correct from retry + print("Scoring retry") + retry_prompt = retry_format.format( + problem=item[0], + solution=item[3], + verification=chat_completion.message.content.split("")[ + -1 + ], + ) + print("Sending to server") + retry_messages = [ + {"role": "user", "content": retry_prompt}, + ] + max_token_length = self.config.max_token_length - len( + self.tokenizer.apply_chat_template( + retry_messages, add_generation_prompt=True + ) + ) + retry_chat_completions = await self.server.chat_completion( + messages=retry_messages, + n=self.config.group_size, + max_tokens=max_token_length, + temperature=1.0, + top_p=0.95, + ) + print("Gathering completions") + scoring_data = [] + backlog_scores = [] + backlog_reasons = [] + backlog_messages = [] + for j, retry_chat_completion in enumerate( + retry_chat_completions.choices + ): + print(f"Scoring generation {j} for retry...") + backlog_messages.append( + tuple( + [frozenset(msg.items()) for msg in retry_messages] + + [ + frozenset( + { + "role": "assistant", + "content": retry_chat_completion.message.content, + }.items() + ) + ] + ) + ) + backlog_reasons.append(retry_chat_completion.finish_reason) + if retry_chat_completion.finish_reason == "length": + scoring_data.append(0) + backlog_scores.append(0) + else: + loop = asyncio.get_event_loop() + task = loop.run_in_executor( + self.mp_executor, + score_answer, + item[1], + retry_chat_completion.message.content.split("")[ + -1 + ], + ) + reward = await task + scoring_data.append(1.0 if reward else 0.0) + backlog_scores.append(1.0 if reward else -1.0) + + if ( + not all( + backlog_score == backlog_scores[0] + for backlog_score in backlog_scores + ) + ) or ( + all( + backlog_reasons == 1.0 for backlog_reason in backlog_reasons + ) + and (random.random() < self.config.percent_length_penalty) + ): + to_backlog.append( + ( + item[0], + item[1], + "selfcorrect", + tuple(backlog_messages), + tuple(backlog_scores), + tuple(backlog_reasons), + ) + ) + print(f"Sending to selfcorrect, {len(to_backlog)} in backlog") + scores["scores"].append(sum(scoring_data) / len(scoring_data)) + self.judge_success_rate.append( + sum(scoring_data) / len(scoring_data) + ) + if len(self.judge_success_rate) >= self.config.num_rollouts_to_keep: + self.judge_success_rate.pop(0) + else: + scores["scores"].append(1.0) + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["messages"].append(messages) + scores["overrides"].append(dict()) + if all([score == 1.0 for score in scores["scores"]]) and ( + random.random() < self.config.percent_length_penalty + ): + # Do len penalty + message_lengths = [len(tokens) for tokens in scores["tokens"]] + min_message_length = min(message_lengths) + max_message_delta = max( + [msg_len - min_message_length for msg_len in message_lengths] + ) + if max_message_delta > 0.1 * min_message_length: + print( + "Max message delta is greater than 0.1 * shortest message, adding length penalty" + ) + for i in range(len(scores["scores"])): + len_penalty = ( + message_lengths[i] - min_message_length + ) / max_message_delta + len_penalty = math.cos(len_penalty * math.pi) + scores["scores"][i] = len_penalty + else: + print( + "Max message delta is less than 0.1 * shortest message, no length penalty" + ) + return None, [] + elif all([score == scores["scores"][0] for score in scores["scores"]]): + return None, [] + if len(for_table) > 0: + self.judge_rollouts.append(for_table) + if len(self.judge_rollouts) >= self.config.num_rollouts_to_keep: + self.judge_rollouts.pop(0) + print( + f"Collected {len(scores['scores'])} trajectories with {len(to_backlog)} in backlog" + ) + return scores, to_backlog + + +if __name__ == "__main__": + MathEnv.cli() diff --git a/environments/math_server_zero.py b/environments/math_server_zero.py new file mode 100644 index 00000000..38acf49d --- /dev/null +++ b/environments/math_server_zero.py @@ -0,0 +1,447 @@ +""" +This file contains code inspired by and adapted from the Open-Reasoner-Zero project. +Original Repository: https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero +""" + +import asyncio +import random +import re +from concurrent.futures import ProcessPoolExecutor +from typing import Dict, List, Optional, Tuple + +from datasets import load_dataset +from latex2sympy2_extended import NormalizationConfig +from math_verify import LatexExtractionConfig, parse, verify +from math_verify.errors import TimeoutException +from pydantic import Field +from tqdm.asyncio import tqdm_asyncio + +import wandb +from atroposlib.envs.base import ( + BaseEnv, + BaseEnvConfig, + EvalHandlingEnum, + OpenaiConfig, + ScoredDataGroup, +) + +prompt_format = ( + "A conversation between User and Assistant. The User asks a question, and the Assistant solves it. The Assistant " + "first thinks about the reasoning process in the mind and then provides the User with the answer. The reasoning " + "process is enclosed within and answer is enclosed within tags, respectively, " + "i.e., reasoning process here answer here . User: {prompt}\nAssistant: " +) + +problem_format = """You must put your answer inside tags, i.e., answer here . And your final answer will be extracted automatically by the \\boxed{{}} tag. +This is the problem: +{problem} +""" # noqa: E501 + +stop_list = ["User:", "Human:", "Assistant:", ""] + + +class RSConfig(BaseEnvConfig): + run_evaluation: bool = Field(True, description="If this should run evaluation") + mask_too_long_completions: bool = Field( + True, description="If this should mask too long completions" + ) + percent_length_penalty: float = Field( + 0.0, description="The percentage of items to have length penalty" + ) + + +def score_answer(gold, resp) -> Optional[bool]: + pattern = re.compile(r".*?(\\boxed{.*}).*?", re.DOTALL) + matches = pattern.findall(resp) + resp = matches[-1] if matches else None + if resp is None: + return False + try: + gold_parsed = parse( + gold, + extraction_mode="first_match", + extraction_config=[LatexExtractionConfig()], + ) + except (Exception, TimeoutException, KeyError, TypeError, NotImplementedError): + return None + if len(gold_parsed) != 0: + try: + answer_parsed = parse( + resp, + extraction_config=[ + LatexExtractionConfig( + normalization_config=NormalizationConfig( + nits=False, + malformed_operators=False, + basic_latex=True, + boxed="all", + units=True, + ), + # Ensures that boxed is tried first + boxed_match_priority=0, + try_extract_without_anchor=False, + ) + ], + extraction_mode="first_match", + ) + except ( + Exception, + TimeoutException, + KeyError, + TypeError, + NotImplementedError, + ): + # Can't parse, so we skip + return None + # Reward 1 if the content is the same as the ground truth, 0 otherwise + try: + return verify(answer_parsed, gold_parsed) + except ( + Exception, + TimeoutException, + KeyError, + TypeError, + NotImplementedError, + ): + return None + return None + + +class MathEnv(BaseEnv): + + name = "math" + env_config_cls = RSConfig + + def __init__( + self, + config: RSConfig, + server_configs: List[OpenaiConfig], + slurm=True, + testing=False, + ): + print("Initializing MathEnv") + print(f"Slurm: {slurm}, Testing: {testing}") + super().__init__(config, server_configs, slurm, testing) + self.percent_correct_buffer = list() + self.eval_metrics = list() + self.mp_executor = ProcessPoolExecutor(64) + self.percent_overanswer = list() + self.correct_answer_len = list() + self.incorrect_answer_len = list() + self.normal_rollouts = list() + self.pass_at_groupsize = list() + self.iter = 0 + + @classmethod + def config_init(cls) -> Tuple[RSConfig, List[OpenaiConfig]]: + env_config = RSConfig( + tokenizer_name="Qwen/Qwen2.5-7B", + group_size=8, + use_wandb=True, + rollout_server_url="http://localhost:8000", + total_steps=1000, + batch_size=1024, + steps_per_eval=25, + max_token_length=31000, # 22000 // (2 ** i), + wandb_name="math", + eval_handling=EvalHandlingEnum.LIMIT_TRAIN, + eval_limit_ratio=0.1, + ) + server_configs = [ + OpenaiConfig( + model_name="default", + base_url="http://localhost:9004/v1", + api_key="x", + num_requests_for_eval=256, # since evaling only on one... + ), + ] + + return env_config, server_configs + + async def wandb_log(self, wandb_metrics: Optional[Dict] = None): + if wandb_metrics is None: + wandb_metrics = dict() + if len(self.pass_at_groupsize) > 0: + wandb_metrics["train/pass_at_groupsize"] = sum( + self.pass_at_groupsize + ) / len(self.pass_at_groupsize) + self.pass_at_8 = list() + if len(self.percent_correct_buffer) > 0: + wandb_metrics["train/percent_correct"] = sum( + self.percent_correct_buffer + ) / len(self.percent_correct_buffer) + wandb_metrics["train/percent_overanswer"] = sum( + self.percent_overanswer + ) / len(self.percent_overanswer) + self.percent_overthink = list() + self.percent_overanswer = list() + self.percent_correct_buffer = list() + if len(self.correct_answer_len) > 0: + wandb_metrics["train/avg_correct_answer_len"] = sum( + self.correct_answer_len + ) / len(self.correct_answer_len) + self.correct_answer_len = list() + if len(self.incorrect_answer_len) > 0: + wandb_metrics["train/avg_incorrect_answer_len"] = sum( + self.incorrect_answer_len + ) / len(self.incorrect_answer_len) + self.incorrect_answer_len = list() + # create tables + if len(self.normal_rollouts) > 0: + table = wandb.Table(columns=["problem", "solution", "answer", "score"]) + for group in self.normal_rollouts: + table.add_data(group[0], group[1], group[2], group[3]) + wandb_metrics["train/normal_rollouts"] = table + wandb_metrics["train/iter"] = self.iter + for item in self.eval_metrics: + wandb_metrics[item[0]] = item[1] + self.eval_metrics = list() + await super().wandb_log(wandb_metrics) + + async def setup(self): + self.train = load_dataset("zwhe99/DeepMath-103K", split="train").shuffle( + seed=42 + ) + aime_test_data = load_dataset("HuggingFaceH4/aime_2024", split="train") + math500_test_data = load_dataset("HuggingFaceH4/math-500", split="test") + amc_test_data = load_dataset("math-ai/amc23", split="test") + minerva_test_data = load_dataset("math-ai/minervamath", split="test") + olympiad_test_data = load_dataset("math-ai/olympiadbench", split="test") + self.test = list() + for name, t_dataset in zip( + ["aime24", "math500"], [aime_test_data, math500_test_data] + ): + for item in t_dataset: + self.test.append( + ( + prompt_format.format( + prompt=problem_format.format(problem=item["problem"]) + ), + item["answer"], + name, + ) + ) + for name, t_dataset in zip( + ["amc23", "minerva", "olympiad"], + [amc_test_data, minerva_test_data, olympiad_test_data], + ): + for item in t_dataset: + self.test.append( + ( + prompt_format.format( + prompt=problem_format.format(problem=item["question"]) + ), + item["answer"], + name, + ) + ) + return + + async def rollout_and_score_eval(self, question, answer, subset): + + completion = await self.server.completion( + prompt=question, + n=1, + max_tokens=32765, + temperature=0.0, + split="eval", + stop=stop_list, + ) + loop = asyncio.get_event_loop() + gold = "\\boxed{" + answer + "}" if "\\boxed" not in answer else answer + resp = completion.choices[0].text + if completion.choices[0].finish_reason == "stop": + if ("" not in completion.choices[0].text) and ( + "" in completion.choices[0].text + ): + # assume it stopped on + resp = resp + " " + task = loop.run_in_executor(self.mp_executor, score_answer, gold, resp) + reward = await task + if reward is None: + return 0, subset + score = 1 if reward else 0 + return score, subset + + async def evaluate(self, *args, **kwargs): + if not self.config.run_evaluation: + return + eval_tasks = [] + for item in self.test: + eval_tasks.append(self.rollout_and_score_eval(item[0], item[1], item[2])) + parsing_data = await tqdm_asyncio.gather(*eval_tasks) + task_lists = dict() + for score, subset in parsing_data: + if subset not in task_lists: + task_lists[subset] = list() + task_lists[subset].append(score) + # Now get the average + for subset, scores in task_lists.items(): + self.eval_metrics.append( + (f"eval/{subset}_percent_correct", sum(scores) / len(scores)) + ) + # overall score + scores = [] + for subset, score in task_lists.items(): + scores.extend(score) + self.eval_metrics.append( + ("eval/overall_percent_correct", sum(scores) / len(scores)) + ) + + async def collect_trajectories(self, item) -> Tuple[List, List]: + thinking_len = self.config.max_token_length + user_prompt = prompt_format.format( + prompt=problem_format.format(problem=item[0]) + ) + thinking_len = thinking_len - len(self.tokenizer.encode(user_prompt)) + completions = await self.server.completion( + prompt=user_prompt, + n=self.config.group_size, + max_tokens=thinking_len, + temperature=1.0, + top_p=0.95, + stop=stop_list, + ) + to_score = list() + to_backlog = list() + for i, completion in enumerate(completions.choices): + message = user_prompt + completion.text + if completion.finish_reason == "stop": + if ("" not in completion.text) and ( + "" in completion.text + ): + # assume it stopped on + message = message + " " + to_score.append( + ( + message, + item[1], + completion.finish_reason, + user_prompt, + ) + ) + to_postprocess = await self.score(to_score) + if to_postprocess is None: + return None, to_backlog + if all( + [to_postprocess["scores"][0] == score for score in to_postprocess["scores"]] + ): + return None, to_backlog + self.normal_rollouts.append( + ( + prompt_format.format(prompt=problem_format.format(problem=item[0])), + to_postprocess["messages"][0][-1]["content"], + item[1], + to_postprocess["scores"][0], + ) + ) + if len(self.normal_rollouts) > self.config.num_rollouts_to_keep: + self.normal_rollouts.pop(0) + print(f"Collected {len(to_postprocess['scores'])} trajectories") + return to_postprocess, to_backlog + + async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]: + scores = ScoredDataGroup() + scores["tokens"] = list() + scores["masks"] = list() + scores["scores"] = list() + scores["overrides"] = list() + scores["messages"] = list() + gold = rollout_group_data[0][1] + loop = asyncio.get_event_loop() + random.shuffle(rollout_group_data) + for item in rollout_group_data: + resp = item[0] + scores["overrides"].append(dict()) + if item[2] == "length": + reward = False + if self.config.mask_too_long_completions: + scores["overrides"][-1]["set_advantage_to_zero"] = True + else: + task = loop.run_in_executor(self.mp_executor, score_answer, gold, resp) + reward = await task + if reward is None: + return None + tokens = self.tokenizer.encode(resp) + user_prompt_tokens = self.tokenizer.encode(item[3]) + if user_prompt_tokens[-1] == self.tokenizer.eos_token_id: + user_prompt_tokens = user_prompt_tokens[:-1] + assert all( + [ + i == j + for i, j in zip( + user_prompt_tokens, tokens[: len(user_prompt_tokens)] + ) + ] + ) + masks = [-100 for _ in range(len(user_prompt_tokens))] + masks = masks + tokens[len(user_prompt_tokens) :] + messages = [ + {"role": "user", "content": item[3]}, + {"role": "assistant", "content": resp[len(item[3]) :]}, + ] + # remove obviously bad examples + if len([1 for i in masks if i != -100]) < 10: + continue + if (item[2] == "length") and (not self.config.mask_too_long_completions): + scores["overrides"][-1]["set_advantage_to_zero"] = True + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["scores"].append(1.0 if reward else -1.0) + scores["messages"].append(messages) + if len(scores["tokens"]) >= self.config.group_size: + break + if any([score == 1.0 for score in scores["scores"]]): + self.pass_at_groupsize.append(1.0) + else: + self.pass_at_groupsize.append(0.0) + if len(scores["tokens"]) < self.config.group_size: + # We don't have enough data to score + return None + for score in scores["scores"]: + self.percent_correct_buffer.append(max(score, 0)) + self.percent_overanswer.extend( + [item[2] == "length" for item in rollout_group_data] + ) + # check if all the same + # print(scores['scores']) + # Fill in the correct/incorrect lens after so we're only looking at actual training data + self.correct_answer_len.extend( + [ + len(scores["tokens"][i]) + for i in range(len(scores["scores"])) + if scores["scores"][i] == 1.0 + ] + ) + self.incorrect_answer_len.extend( + [ + len(scores["tokens"][i]) + for i in range(len(scores["scores"])) + if (scores["scores"][i] == -1.0) + and (not scores["overrides"][i].get("set_advantage_to_zero", False)) + ] + ) + return scores + + async def get_next_item(self): + while True: + next_item = self.train[self.iter % len(self.train)] + self.iter += 1 + prompt = next_item["question"] + try: + answer = ( + ("\\boxed{" + next_item["final_answer"] + "}") + if "\\boxed" not in next_item["final_answer"] + else next_item["final_answer"] + ) + break + except TypeError: + print( + f"Error in getting next item, trying again, " + f"data: {next_item['question']} -> {next_item['final_answer']}" + ) + return (prompt, answer, "normal") + + +if __name__ == "__main__": + MathEnv.cli() diff --git a/environments/mcqa_thinking_env.py b/environments/mcqa_thinking_env.py new file mode 100644 index 00000000..9b5cff07 --- /dev/null +++ b/environments/mcqa_thinking_env.py @@ -0,0 +1,492 @@ +import random +import re +from typing import Dict, List, Optional, Tuple, Union + +import wandb +from datasets import load_dataset +from tqdm.asyncio import tqdm_asyncio + +from atroposlib.envs.base import ( + BaseEnv, + BaseEnvConfig, + EvalHandlingEnum, + Item, + OpenaiConfig, + ScoredDataGroup, +) +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer + +system_prompt = ( + "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the " + "problem and deliberate with yourself via systematic reasoning processes to help come to a correct " + "solution prior to answering. You should enclose your thoughts and internal monologue inside " + " tags, and then provide your solution or response to the problem." +) + + +class MCQAThinkingEnv(BaseEnv): + def __init__( + self, + config: BaseEnvConfig, + server_configs: List[OpenaiConfig], + slurm=True, + testing=False, + ): + """ + Initialize the MCQA (Multiple Choice Question Answering) environment. + + Args: + config: Configuration for the base environment + server_configs: List of server configurations for OpenAI API + slurm: Whether to use Slurm for distributed training + testing: Whether in testing mode + """ + super().__init__(config, server_configs, slurm, testing) + self.percent_correct_buffer = list() + self.eval_metrics = list() + + @classmethod + def config_init(self) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: + env_config = BaseEnvConfig( + tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + group_size=32, + use_wandb=True, + max_num_workers=128, + rollout_server_url="http://localhost:8000", + total_steps=2000, + batch_size=1024, + steps_per_eval=20, + max_token_length=1024 * 15, + inference_weight=1.0, + wandb_name="mcqa_deep_thinking", + data_path_to_save_groups=None, + eval_handling=EvalHandlingEnum.LIMIT_TRAIN, + eval_limit_ratio=0.1, + ) + server_configs = [ + OpenaiConfig( + model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + base_url="http://localhost:9004/v1", + api_key="x", + num_max_requests_at_once=32, + num_requests_for_eval=256, + ) + ] + + return env_config, server_configs + + async def setup(self): + """ + Set up the environment by loading and preparing the dataset. + """ + # Load the full dataset + full_dataset = load_dataset( + "NousResearch/AcademicMCQA", "default", split="train" + ) + + full_dataset = full_dataset.shuffle(seed=42) + + # Create train/test split on the fly (e.g., 95% train, 5% test) + split_dataset = full_dataset.train_test_split(test_size=0.02, seed=42) + + # Keep the splits as is - no need to reformat + self.train = split_dataset["train"] + self.test = split_dataset["test"] + + # Print some dataset statistics + print( + f"Loaded dataset with {len(self.train)} training examples and {len(self.test)} test examples" + ) + print(f"Example item format: {self.train[0]}") + + # Initialize iteration counter + self.iter = 0 + + def save_checkpoint(self, step, data=None): + if data is None: + data = {} + data["iter"] = self.iter + super().save_checkpoint(step, data) + + async def get_next_item(self): + """ + Get the next training item from the dataset. + + Returns: + A tuple containing prompt and expected answer + """ + next_item = self.train[self.iter % len(self.train)] + self.iter += 1 + + # Extract question and options from the multiple choice item + question_text = next_item["prompt"] + correct_answer_index = next_item["answer"] + ground_truth_letter = next_item["ground_truth"] + options = next_item["options"] + + # Append the answer format instruction to the prompt + question_text_with_instruction = f'{question_text}\n\nProvide your answer by saying "The best answer is: {{Answer}}"' # noqa E501 + + # Create prompt tuple using frozensets as required + prompt = [] + + # Add system prompt as defined at the top of the script + prompt.append(frozenset({"role": "system", "content": system_prompt}.items())) + + # Add user message with the question and instruction + prompt.append( + frozenset( + {"role": "user", "content": question_text_with_instruction}.items() + ) + ) + + # Prepare the expected answer + # We'll use the ground_truth_letter (A, B, C, D) as the expected answer + # The scoring function will need to check if the model response contains this letter + answer = ground_truth_letter + answer_string = options[correct_answer_index] + + return (tuple(prompt), answer, answer_string) + + async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]: + """ + Generate and collect model responses for scoring. + + Args: + item: Input item containing prompt and expected answer + + Returns: + Tuple of lists containing scored data groups and backlog + """ + # Extract messages from the item + messages = [] + for role_dict in item[0]: + messages.append(dict(role_dict)) + + # Apply chat template to convert messages to a single string + prompt = self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) + + # Get completions from the model using completion() instead of chat_completion() + completions = await self.server.completion( + prompt=prompt, + n=self.config.group_size, + max_tokens=1024 * 15, + temperature=1.0, # Using temperature to get diverse responses + ) + + to_score = list() + + for i, completion_choice in enumerate(completions.choices): + # Create a copy of the prompt messages + trajectory_messages = [] + for role_dict in item[0]: + trajectory_messages.append(dict(role_dict)) + + # Add the model's response + trajectory_messages.append( + {"role": "assistant", "content": completion_choice.text} + ) + + # Add to scoring queue with expected answer, ground truth text, and stop reason + to_score.append( + ( + tuple(trajectory_messages), + item[1], # Letter (A, B, C, D) + item[2], # Include the answer_string/ground_truth_text + completion_choice.finish_reason, # Add the stop reason + ) + ) + + # Call score to get the scored data + scored_data = await self.score(to_score) + to_backlog = [] + + return scored_data, to_backlog + + def _extract_mcqa_answer(self, text, ground_truth_text, ground_truth_letter): + """ + Extract the multiple choice answer (A, B, C, or D) from model response. + Only allows one valid answer format - multiple answer formats result in a score of 0. + + Args: + text: Text containing the model's response + ground_truth_text: The full text of the correct answer + ground_truth_letter: The letter (A, B, C, D) of the correct answer + + Returns: + Extracted answer letter or None if invalid response pattern is found + """ + # Check for multiple tags - score as 0 if found + think_tags = re.findall(r"", text, re.IGNORECASE) + if len(think_tags) > 1: + return None + + # Check if the think tag is properly opened - we need exactly one opening tag + if len(think_tags) != 1: + return None + + # Check for closing tags + think_close_tags = re.findall(r"", text, re.IGNORECASE) + if len(think_close_tags) != 1: + return None # Must have exactly one closing tag + + # Split the text into thinking and answer sections + parts = re.split(r"", text, flags=re.IGNORECASE, maxsplit=1) + + # If there's no tag or multiple sections, return None + if len(parts) != 2: + return None + + thinking_section, answer_section = parts + + # Validate thinking section + # Make sure thinking section actually contains the opening tag + if "" not in thinking_section.lower(): + return None # Malformed thinking section + + # Check if there are any tags in the answer section (after the first ) + if "" in answer_section.lower(): + return None + + # More flexible answer patterns that handle parentheses and additional text + answer_patterns = [ + r"The correct answer is:?\s*(?:\*\*)?(A|B|C|D)(?:\*\*)?(?:\)|\.|:)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", # noqa W605 + r"The best answer is:?\s*(?:\*\*)?(A|B|C|D)(?:\*\*)?(?:\)|\.|:)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", # noqa W605 + r"The answer is:?\s*(?:\*\*)?(A|B|C|D)(?:\*\*)?(?:\)|\.|:)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", # noqa W605 + r"\*\*The best answer is\s*(A|B|C|D)\*\*(?:\)|\.|:)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", # noqa W605 + r"\*\*The best answer is:\s*(A|B|C|D)\*\*(?:\)|\.|:)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", # noqa W605 + r"Thus, final answer:\s*(A|B|C|D)\)(?:\)|\.|:)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", # noqa W605 + r"\\boxed{(A|B|C|D)}(?:\)|\.|:)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", # noqa W605 + ] + + string_patterns = [ + # Patterns to match exact ground truth text, with optional markdown bold formatting + r"The correct answer is:?\s(?:\*\*)?" + + re.escape(ground_truth_text) + + r"(?:\*\*)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", + r"The best answer is:?\s(?:\*\*)?" + + re.escape(ground_truth_text) + + r"(?:\*\*)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", + r"The answer is:?\s(?:\*\*)?" + + re.escape(ground_truth_text) + + r"(?:\*\*)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", + ] + + # Track all found answers + found_answers = [] + + # Check each pattern + for pattern in answer_patterns: + matches = re.findall(pattern, answer_section, re.IGNORECASE) + if matches: + for match in matches: + # Extract just the letter + found_answers.append(match.upper()) + + for pattern in string_patterns: + matches = re.findall(pattern, answer_section, re.IGNORECASE) + if matches: + # For each match found, append the ground truth letter instead of the full match + for _ in matches: + found_answers.append(ground_truth_letter) + + # If no answers found or multiple answers found, return None + if len(found_answers) != 1: + return None + + # Return the single found answer + return found_answers[0] + + async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]: + """ + Score the generated model responses against expected MCQA answers. + + Args: + rollout_group_data: List of generated responses with expected answers + + Returns: + ScoredDataGroup with tokenized inputs and scores, or None if no valid scores + """ + scores = ScoredDataGroup() + scores["tokens"] = list() + scores["masks"] = list() + scores["scores"] = list() + + # Get the expected answer letter + expected_answer = rollout_group_data[0][1] # Letter A, B, C, D + ground_truth_text = rollout_group_data[0][2] + + # Shuffle to avoid bias in selection + random.shuffle(rollout_group_data) + + for item in rollout_group_data: + # Extract the model's response + model_response = item[0][-1]["content"] + stop_reason = item[3] # Get the stop reason + + # If the response was cut off due to length, give it a score of 0 + if stop_reason == "length": + reward = 0 + else: + # Extract the answer from the model's response + model_answer = self._extract_mcqa_answer( + model_response, ground_truth_text, expected_answer + ) + + # Track metrics based on result + if model_answer is None: + reward = 0 # Invalid format gets 0 reward + elif model_answer == expected_answer: + reward = 1 # Correct answer gets 1 reward + else: + reward = 0 # Wrong answer gets 0 reward + + # Tokenize the conversation for learning + out_dict = tokenize_for_trainer(self.tokenizer, item[0]) + tokens = out_dict["tokens"] + masks = out_dict["masks"] + + # Remove examples with insufficient context + if len([1 for i in masks if i != -100]) < 10: + continue + + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["scores"].append(1.0 if reward else -1.0) + + # Break once we have enough examples + if len(scores["tokens"]) >= self.config.group_size: + break + + # Record success rate metrics for wandb logging + for score in scores["scores"]: + self.percent_correct_buffer.append(max(score, 0)) + + # Return None if all scores are the same (no learning signal) + if all(scores["scores"][0] == score for score in scores["scores"]): + return None + + return scores + + async def rollout_and_score_eval(self, test_item): + """ + Generate and score model responses for a single test item. + + Args: + test_item: Test item from dataset + + Returns: + Score (1 for correct, 0 for incorrect) + """ + # Extract question and options from the test item + question_text = test_item["prompt"] + correct_answer_index = test_item["answer"] + expected_answer_letter = test_item["ground_truth"] + options = test_item["options"] + + # Append the answer format instruction to the prompt + question_text_with_instruction = f'{question_text}\n\nProvide your answer by saying "The best answer is: {{Answer}}"' # noqa E501 + + # Create messages for model + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question_text_with_instruction}, + ] + + # Apply chat template to convert messages to a single string + prompt = self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) + + # Get model completion + completion = await self.server.completion( + prompt=prompt, + n=1, + max_tokens=1024 * 15, + temperature=0.5, # Lower for eval + split="eval", + ) + + # Extract the model's response from the completion + model_response = completion.choices[0].text + + # Extract the answer from the model's response + model_answer = self._extract_mcqa_answer( + model_response, options[correct_answer_index], expected_answer_letter + ) + + # Score 1 if the answers match, 0 otherwise + score = 1 if model_answer and model_answer == expected_answer_letter else 0 + + return score + + async def evaluate(self, *args, **kwargs): + """ + Evaluate the model on test data. + """ + eval_tasks = [] + for test_item in self.test: + eval_tasks.append(self.rollout_and_score_eval(test_item)) + + # Run evaluation + scores = await tqdm_asyncio.gather(*eval_tasks) + self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores))) + + async def add_rollouts_for_wandb( + self, + scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]], + item: Item = None, + ): + # save rollout to trajectory + num_keep = self.config.num_rollouts_per_group_for_logging + if num_keep == -1: + num_keep = self.config.group_size + self.rollouts_for_wandb.append( + [ + ( + self.tokenizer.decode(scored_data["tokens"][i]), + scored_data["scores"][i], + item[1], + item[2], + ) + for i in range(num_keep) + ] + ) + if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep: + self.rollouts_for_wandb.pop(0) + + async def create_rollout_table(self, wandb_metrics): + if len(self.rollouts_for_wandb) > 0: + table = wandb.Table(columns=["text", "score", "answer", "string_answer"]) + for group in self.rollouts_for_wandb: + for item in group: + table.add_data(item[0], item[1], item[2], item[3]) + wandb_metrics["train/rollouts"] = table + self.rollouts_for_wandb = [] + return wandb_metrics + + async def wandb_log(self, wandb_metrics: Optional[Dict] = None): + if wandb_metrics is None: + wandb_metrics = {} + + # Try to calculate percent_correct, pass if there's a division by zero + try: + wandb_metrics["train/percent_correct"] = sum( + self.percent_correct_buffer + ) / len(self.percent_correct_buffer) + except ZeroDivisionError: + # Skip if buffer is empty + pass + + self.percent_correct_buffer = list() + for item in self.eval_metrics: + wandb_metrics[item[0]] = item[1] + self.eval_metrics = list() + + await super().wandb_log(wandb_metrics) + + +if __name__ == "__main__": + MCQAThinkingEnv.cli() diff --git a/environments/multimodal_dpo/clevr_cogen_a_train.py b/environments/multimodal_dpo/clevr_cogen_a_train.py new file mode 100644 index 00000000..f9e786e3 --- /dev/null +++ b/environments/multimodal_dpo/clevr_cogen_a_train.py @@ -0,0 +1,282 @@ +import base64 +import json +import os +import random +import re +import sys +import traceback +from typing import List, Optional, Tuple + +from datasets import load_dataset + +from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup +from atroposlib.type_definitions import GameHistory, Item +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer + + +class MultimodalExampleEnv(BaseEnv): + name = "clevr_cogen_a_train" + name_config_cls = BaseEnvConfig + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def collect_trajectories( + self, item: Item + ) -> Tuple[GameHistory | None, List[Item]]: + print("DEBUG: Starting collect_trajectories") + to_score = list() + to_backlog = list() + + # Get the current image if it was stored + if hasattr(self, "current_image"): + print("DEBUG: Using current_image for multimodal content") + + # Convert PIL image to base64 + import io + + img_byte_arr = io.BytesIO() + self.current_image.save(img_byte_arr, format="PNG") + img_byte_arr = img_byte_arr.getvalue() + base64_image = base64.b64encode(img_byte_arr).decode("utf-8") + + # Extract text content from item + user_content = dict(item[0][0]).get("content", "") + + # Try to parse if it's JSON + if isinstance(user_content, str) and ( + user_content.startswith("[") or user_content.startswith("{") + ): + try: + parsed = json.loads(user_content) + text_content = "" + for element in parsed: + if element.get("type") == "text": + text_content = element.get("text", "") + + if not text_content: + text_content = "Please solve this problem and provide your answer as \\boxed{answer}." + except Exception as e: + print(f"DEBUG: Error parsing JSON: {e}") + text_content = "Please solve this problem and provide your answer as \\boxed{answer}." + else: + text_content = user_content + + # Create messages with the new format + print("DEBUG: Creating multimodal message with new format") + messages = [ + { + "role": "system", + "content": "You must submit your answer with \\boxed{answer}, please make sure to do this", + }, + { + "role": "user", + "content": [ + {"type": "text", "text": text_content}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{base64_image}", + }, + }, + ], + }, + ] + + else: + print("DEBUG: No image available, using text-only message") + messages = [ + { + "role": "system", + "content": "You must submit your answer with \\boxed{answer}", + }, + dict(item[0][0]), + ] + + print("DEBUG: About to call chat_completion") + chat_completions = await self.server.chat_completion( + messages=messages, + n=self.config.group_size, + max_tokens=1024 * 2, + timeout=60, # Add timeout to prevent hanging (60 seconds is more reasonable) + ) + print("DEBUG: chat_completion call successful") + + for i, chat_completion in enumerate(chat_completions.choices): + print(f"DEBUG: Processing completion {i+1}/{len(chat_completions.choices)}") + messages = ( + dict(item[0][0]), + {"role": "assistant", "content": chat_completion.message.content}, + ) + to_score.append((messages, item[1], base64_image)) + + print("DEBUG: Finished processing completions") + + print("DEBUG: Returning from collect_trajectories") + return to_score, to_backlog + + async def postprocess_histories( + self, trajectories: List[GameHistory] + ) -> ScoredDataGroup: + pass + + async def evaluate(self, *args, **kwargs): + """ + Evaluate the environment, this is called every steps_per_eval steps + + :param args: + :param kwargs: + :return: None. + """ + return + + async def setup(self): + """Setup the environment and load the multimodal dataset""" + self.dataset = load_dataset("leonardPKU/clevr_cogen_a_train") + self.train = self.dataset["train"] + self.iter = 0 + + async def get_next_item(self) -> Item: + """ + Get the next items to be rolled out, including the image + """ + try: + print("DEBUG: Starting get_next_item") + + # Get next dataset item + next_item = self.train[self.iter % len(self.train)] + self.iter += 1 + + print(f"DEBUG: Retrieved dataset item {self.iter-1}") + + # For debugging, we'll use a simple text-only prompt and store the image separately + # This is because the collect_trajectories method will handle the multimodal formatting + + # Store image as a class attribute so collect_trajectories can access it + self.current_image = next_item["image"] + print("DEBUG: Stored image in current_image attribute") + + # Create a simple text prompt - the image will be added in collect_trajectories + # This avoids the unhashable type error with lists in frozensets + text_prompt = next_item["problem"] + + # Create a simple text-only prompt + prompt = tuple( + [frozenset({"role": "user", "content": text_prompt}.items())] + ) + answer = next_item["solution"] + + # get image as base64 + # image = next_item["image"] + + # Convert PIL image to base64 + import io + + img_byte_arr = io.BytesIO() + self.current_image.save(img_byte_arr, format="PNG") + img_byte_arr = img_byte_arr.getvalue() + base64_image = base64.b64encode(img_byte_arr).decode("utf-8") + + print("DEBUG: Created simple text-only prompt for get_next_item") + return (prompt, answer, base64_image) + + except Exception as e: + print(f"DEBUG: Error in get_next_item: {str(e)}") + traceback.print_exc() + + # Create a dummy item as fallback + prompt = tuple( + [ + frozenset( + {"role": "user", "content": "Please solve: 2 + 2 = ?"}.items() + ) + ] + ) + answer = "4" + return (prompt, answer, "obobob") + + async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]: + scores = ScoredDataGroup() + scores["tokens"] = list() + scores["masks"] = list() + scores["scores"] = list() + scores["images"] = list() + random.shuffle(rollout_group_data) + for item in rollout_group_data: + out_dict = tokenize_for_trainer(self.tokenizer, item[0]) + tokens = out_dict["tokens"] + masks = out_dict["masks"] + + # Extract the answer from the model's response + try: + model_answer = ( + item[0][-1]["content"].split("\\boxed{")[-1].split("}")[0] + ) + print( + f"DEBUG: Model answer: {model_answer} and RG data: {rollout_group_data[0][1]}" + ) + + pattern = r"\s*(\d{1,2})\s*" + string = rollout_group_data[0][1] + gold_answer = re.search(pattern, string).group(1) + + reward = gold_answer == model_answer + except IndexError: + reward = False + + # remove obviously bad examples + if len([1 for i in masks if i != -100]) < 10: + continue + + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["scores"].append(1.0 if reward else -1.0) + + try: + scores["images"].append(item[2]) + except IndexError: + scores["images"].append(None) + if len(scores["tokens"]) >= self.config.group_size: + break + + return scores + + @classmethod + def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: + if not os.environ.get("OPENAI_API_KEY"): + print("ERROR: OPENAI_API_KEY environment variable is not set!") + print("Please set it using: export OPENAI_API_KEY=your_api_key") + sys.exit(1) + + print( + f"DEBUG: Using API key starting with: {os.environ.get('OPENAI_API_KEY')[:5]}..." + ) + + config = BaseEnvConfig( + wandb_name="clevr_cogen", + tokenizer_name="gpt2", + group_size=2, + use_wandb=False, + max_num_workers=2, + rollout_server_url="http://localhost:8000", + total_steps=1000, + batch_size=1, + steps_per_eval=10, + ensure_scores_are_not_same=False, + ) + + print("DEBUG: Creating OpenAI configuration") + server_configs = [ + OpenaiConfig( + model_name="gpt-4o", # Using GPT-4o which has multimodal capabilities + base_url=None, + api_key=os.environ.get("OPENAI_API_KEY"), + num_requests_for_eval=1, + ), + ] + + return config, server_configs + + +if __name__ == "__main__": + MultimodalExampleEnv.cli() diff --git a/environments/multimodal_dpo/clevr_complex.py b/environments/multimodal_dpo/clevr_complex.py new file mode 100644 index 00000000..ca99b83c --- /dev/null +++ b/environments/multimodal_dpo/clevr_complex.py @@ -0,0 +1,283 @@ +import base64 +import json +import os +import random +import sys +import traceback +from typing import List, Optional, Tuple + +from datasets import load_dataset + +from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup +from atroposlib.type_definitions import GameHistory, Item +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer + + +class MultimodalComplexEnv(BaseEnv): + name = "clevr_complex" + name_config_cls = BaseEnvConfig + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def collect_trajectories( + self, item: Item + ) -> Tuple[GameHistory | None, List[Item]]: + print("DEBUG: Starting collect_trajectories") + to_score = list() + to_backlog = list() + + # Get the current image if it was stored + if hasattr(self, "current_image"): + print("DEBUG: Using current_image for multimodal content") + + # Convert PIL image to base64 + import io + + img_byte_arr = io.BytesIO() + self.current_image.save(img_byte_arr, format="PNG") + img_byte_arr = img_byte_arr.getvalue() + base64_image = base64.b64encode(img_byte_arr).decode("utf-8") + + # Extract text content from item + user_content = dict(item[0][0]).get("content", "") + + # Try to parse if it's JSON + if isinstance(user_content, str) and ( + user_content.startswith("[") or user_content.startswith("{") + ): + try: + parsed = json.loads(user_content) + text_content = "" + for element in parsed: + if element.get("type") == "text": + text_content = element.get("text", "") + + if not text_content: + text_content = "Please solve this problem and provide your answer as \\boxed{answer}." + except Exception as e: + print(f"DEBUG: Error parsing JSON: {e}") + text_content = "Please solve this problem and provide your answer as \\boxed{answer}." + else: + text_content = user_content + + # Create messages with the new format + print("DEBUG: Creating multimodal message with new format") + messages = [ + { + "role": "system", + "content": "You must submit your answer with \\boxed{answer}, please make sure to do this", + }, + { + "role": "user", + "content": [ + {"type": "text", "text": text_content}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{base64_image}", + }, + }, + ], + }, + ] + + else: + print("DEBUG: No image available, using text-only message") + messages = [ + { + "role": "system", + "content": "You must submit your answer with \\boxed{answer}", + }, + dict(item[0][0]), + ] + + print("DEBUG: About to call chat_completion") + chat_completions = await self.server.chat_completion( + messages=messages, + n=self.config.group_size, + max_tokens=1024 * 2, + timeout=60, # Add timeout to prevent hanging (60 seconds is more reasonable) + ) + print("DEBUG: chat_completion call successful") + + for i, chat_completion in enumerate(chat_completions.choices): + print(f"DEBUG: Processing completion {i+1}/{len(chat_completions.choices)}") + messages = ( + dict(item[0][0]), + {"role": "assistant", "content": chat_completion.message.content}, + ) + to_score.append((messages, item[1], base64_image)) + + print("DEBUG: Finished processing completions") + + print("DEBUG: Returning from collect_trajectories") + return to_score, to_backlog + + async def postprocess_histories( + self, trajectories: List[GameHistory] + ) -> ScoredDataGroup: + pass + + async def evaluate(self, *args, **kwargs): + """ + Evaluate the environment, this is called every steps_per_eval steps + + :param args: + :param kwargs: + :return: None. + """ + return + + async def setup(self): + """Setup the environment and load the multimodal dataset""" + self.dataset = load_dataset("MMInstruction/Clevr_CoGenT_TrainA_70K_Complex") + self.train = self.dataset["train"] + self.iter = 0 + + async def get_next_item(self) -> Item: + """ + Get the next items to be rolled out, including the image + """ + try: + print("DEBUG: Starting get_next_item") + + # Get next dataset item + next_item = self.train[self.iter % len(self.train)] + self.iter += 1 + + print(f"DEBUG: Retrieved dataset item {self.iter-1}") + + # For debugging, we'll use a simple text-only prompt and store the image separately + # This is because the collect_trajectories method will handle the multimodal formatting + + # Store image as a class attribute so collect_trajectories can access it + self.current_image = next_item["image"] + print("DEBUG: Stored image in current_image attribute") + + # Create a simple text prompt - the image will be added in collect_trajectories + # This avoids the unhashable type error with lists in frozensets + text_prompt = next_item["problem"] + + # Create a simple text-only prompt + prompt = tuple( + [frozenset({"role": "user", "content": text_prompt}.items())] + ) + answer = next_item["solution"] + + # Convert PIL image to base64 + import io + + img_byte_arr = io.BytesIO() + self.current_image.save(img_byte_arr, format="PNG") + img_byte_arr = img_byte_arr.getvalue() + base64_image = base64.b64encode(img_byte_arr).decode("utf-8") + + print("DEBUG: Created simple text-only prompt for get_next_item") + return (prompt, answer, base64_image) + + except Exception as e: + print(f"DEBUG: Error in get_next_item: {str(e)}") + traceback.print_exc() + + # Create a dummy item as fallback + prompt = tuple( + [ + frozenset( + {"role": "user", "content": "Please solve: 2 + 2 = ?"}.items() + ) + ] + ) + answer = "4" + return (prompt, answer, "obobob") + + async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]: + scores = ScoredDataGroup() + scores["tokens"] = list() + scores["masks"] = list() + scores["scores"] = list() + scores["images"] = list() + random.shuffle(rollout_group_data) + for item in rollout_group_data: + out_dict = tokenize_for_trainer(self.tokenizer, item[0]) + tokens = out_dict["tokens"] + masks = out_dict["masks"] + + # Extract the answer from the model's response + try: + model_answer = ( + item[0][-1]["content"].split("\\boxed{")[-1].split("}")[0] + ) + print( + f"DEBUG: Model answer: {model_answer} and RG data: {rollout_group_data[0][1]}" + ) + + # Handle both numeric and yes/no answers + gold_answer = rollout_group_data[0][1] + + # Case-insensitive comparison for yes/no and direct comparison for numbers + if gold_answer.lower() in ["yes", "no"]: + reward = gold_answer.lower() == model_answer.lower() + else: + # For numeric answers + reward = gold_answer == model_answer + + except IndexError: + reward = False + + # remove obviously bad examples + if len([1 for i in masks if i != -100]) < 10: + continue + + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["scores"].append(1.0 if reward else -1.0) + + try: + scores["images"].append(item[2]) + except IndexError: + scores["images"].append(None) + if len(scores["tokens"]) >= self.config.group_size: + break + + return scores + + @classmethod + def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: + if not os.environ.get("OPENAI_API_KEY"): + print("ERROR: OPENAI_API_KEY environment variable is not set!") + print("Please set it using: export OPENAI_API_KEY=your_api_key") + sys.exit(1) + + print( + f"DEBUG: Using API key starting with: {os.environ.get('OPENAI_API_KEY')[:5]}..." + ) + + config = BaseEnvConfig( + wandb_name="clevr_complex", + tokenizer_name="gpt2", + group_size=2, + use_wandb=False, + max_num_workers=2, + rollout_server_url="http://localhost:8000", + total_steps=1000, + batch_size=1, + steps_per_eval=10, + ensure_scores_are_not_same=False, + ) + + print("DEBUG: Creating OpenAI configuration") + server_configs = [ + OpenaiConfig( + model_name="gpt-4o", # Using GPT-4o which has multimodal capabilities + base_url=None, + api_key=os.environ.get("OPENAI_API_KEY"), + num_requests_for_eval=1, + ), + ] + + return config, server_configs + + +if __name__ == "__main__": + MultimodalComplexEnv.cli() diff --git a/environments/multimodal_dpo/ocr_vqa.py b/environments/multimodal_dpo/ocr_vqa.py new file mode 100644 index 00000000..aa6978ef --- /dev/null +++ b/environments/multimodal_dpo/ocr_vqa.py @@ -0,0 +1,200 @@ +import base64 +import io +import os +import random +import re +import sys +import traceback +from typing import List, Optional, Tuple + +from datasets import load_dataset + +from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup +from atroposlib.type_definitions import GameHistory, Item +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer + + +class OcrVqaEnv(BaseEnv): + name = "ocr_vqa" + name_config_cls = BaseEnvConfig + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def collect_trajectories( + self, item: Item + ) -> Tuple[GameHistory | None, List[Item]]: + to_score: List[Tuple[GameHistory, str, Optional[str]]] = [] + to_backlog: List[Item] = [] + + # Extract question and image from item + prompt_tuple, gold, base64_image = item + # The prompt_tuple contains the user prompt as a frozenset + text_prompt = dict(prompt_tuple[0])["content"] + + # System instruction for answer formatting + system_msg = { + "role": "system", + "content": ( + "You must submit your answer enclosed in tags, " + "e.g., YOUR_ANSWER" + ), + } + # Multimodal user message with text and image + user_msg = { + "role": "user", + "content": [ + {"type": "text", "text": text_prompt}, + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{base64_image}"}, + }, + ], + } + messages = [system_msg, user_msg] + + # Call the chat completion endpoint + chat_completions = await self.server.chat_completion( + messages=messages, + n=self.config.group_size, + max_tokens=512, + timeout=60, + ) + + # Build trajectories for scoring + for choice in chat_completions.choices: + user_hist = {"role": "user", "content": text_prompt} + assistant_hist = {"role": "assistant", "content": choice.message.content} + history: GameHistory = (user_hist, assistant_hist) + to_score.append((history, gold, base64_image)) + + return to_score, to_backlog + + async def postprocess_histories( + self, trajectories: List[GameHistory] + ) -> ScoredDataGroup: + # No additional post-processing needed + pass + + async def evaluate(self, *args, **kwargs): + # No custom evaluation + return + + async def setup(self): + # Load the OCR-VQA dataset + self.dataset = load_dataset("howard-hou/OCR-VQA") + self.train = self.dataset["train"] + self.iter = 0 + + async def get_next_item(self) -> Item: + try: + entry = self.train[self.iter % len(self.train)] + self.iter += 1 + + # Take the first question and answer + question = entry["questions"][0] + answer = entry["answers"][0] + text_prompt = question + prompt = tuple( + [frozenset({"role": "user", "content": text_prompt}.items())] + ) + + # Format the gold answer + gold_answer = f"{answer}" + + # Convert image to base64 + img = entry["image"] + buf = io.BytesIO() + img.save(buf, format="PNG") + img_bytes = buf.getvalue() + base64_image = base64.b64encode(img_bytes).decode("utf-8") + + return (prompt, gold_answer, base64_image) + except Exception: + traceback.print_exc() + # Fallback example + fallback_prompt = tuple( + [ + frozenset( + {"role": "user", "content": "Please solve: 2 + 2 = ?"}.items() + ) + ] + ) + return (fallback_prompt, "4", None) + + async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]: + scores = ScoredDataGroup() + scores["tokens"] = [] + scores["masks"] = [] + scores["scores"] = [] + scores["images"] = [] + random.shuffle(rollout_group_data) + for item in rollout_group_data: + out = tokenize_for_trainer(self.tokenizer, item[0]) + tokens = out["tokens"] + masks = out["masks"] + + # Extract model and gold answers + try: + reply = item[0][-1]["content"] + m = re.search(r"\s*(.*?)\s*", reply, re.IGNORECASE) + model_answer = m.group(1).strip() if m else reply.strip() + + gold = item[1] + g = re.search(r"\s*(.*?)\s*", gold, re.IGNORECASE) + gold_answer = g.group(1).strip() if g else gold.strip() + + reward = model_answer.lower() == gold_answer.lower() + except Exception: + reward = False + + # Filter out short examples + if len([i for i in masks if i != -100]) < 10: + continue + + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["scores"].append(1.0 if reward else -1.0) + try: + scores["images"].append(item[2]) + except Exception: + scores["images"].append(None) + + if len(scores["tokens"]) >= self.config.group_size: + break + + return scores + + @classmethod + def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: + if not os.environ.get("OPENAI_API_KEY"): + print("ERROR: OPENAI_API_KEY environment variable is not set!") + sys.exit(1) + + config = BaseEnvConfig( + wandb_name="ocr_vqa", + tokenizer_name="gpt2", + group_size=2, + use_wandb=False, + max_num_workers=2, + rollout_server_url="http://localhost:8000", + total_steps=1000, + batch_size=1, + steps_per_eval=10, + ensure_scores_are_not_same=False, + ) + + server_configs = [ + OpenaiConfig( + model_name="gpt-4o", + base_url=None, + api_key=os.environ.get("OPENAI_API_KEY"), + num_requests_for_eval=1, + ), + ] + + return config, server_configs + + +if __name__ == "__main__": + OcrVqaEnv.cli() diff --git a/environments/multimodal_dpo/pixmo_clocks.py b/environments/multimodal_dpo/pixmo_clocks.py new file mode 100644 index 00000000..bf4059e8 --- /dev/null +++ b/environments/multimodal_dpo/pixmo_clocks.py @@ -0,0 +1,202 @@ +import base64 +import io +import os +import random +import re +import sys +import traceback +from typing import List, Optional, Tuple + +from datasets import load_dataset + +from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup +from atroposlib.type_definitions import GameHistory, Item +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer + + +class ClockDatasetEnv(BaseEnv): + name = "pixmo_clocks" + name_config_cls = BaseEnvConfig + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def collect_trajectories( + self, item: Item + ) -> Tuple[GameHistory | None, List[Item]]: + to_score: List[Tuple[GameHistory, str, Optional[str]]] = [] + to_backlog: List[Item] = [] + + # Extract the base64 image + base64_image = item[2] + + # Build system instruction and multimodal user message + system_msg = { + "role": "system", + "content": ( + "You must submit your answer enclosed in tags, " + "e.g., HH:MM" + ), + } + user_prompt_text = "What time does the clock show?" + user_msg_multimodal = { + "role": "user", + "content": [ + {"type": "text", "text": user_prompt_text}, + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{base64_image}"}, + }, + ], + } + + messages = [system_msg, user_msg_multimodal] + + # Call chat completion + chat_completions = await self.server.chat_completion( + messages=messages, + n=self.config.group_size, + max_tokens=512, + timeout=60, + ) + + # Prepare trajectories for scoring + for choice in chat_completions.choices: + # Use text-only prompt for history + user_msg = {"role": "user", "content": user_prompt_text} + assistant_msg = {"role": "assistant", "content": choice.message.content} + history: GameHistory = (user_msg, assistant_msg) + to_score.append((history, item[1], base64_image)) + + return to_score, to_backlog + + async def postprocess_histories( + self, trajectories: List[GameHistory] + ) -> ScoredDataGroup: + # No custom post-processing + pass + + async def evaluate(self, *args, **kwargs): + # No custom evaluation + return + + async def setup(self): + # Load the clock dataset + self.dataset = load_dataset("junyeong-nero/clock-dataset") + self.train = self.dataset["train"] + self.iter = 0 + + async def get_next_item(self) -> Item: + try: + entry = self.train[self.iter % len(self.train)] + self.iter += 1 + + text_prompt = "What time does the clock show" + prompt = tuple( + [frozenset({"role": "user", "content": text_prompt}.items())] + ) + + # Format gold answer + hour = entry["hour"] + minute = entry["minute"] + gold_answer = f"{hour}:{minute:02d}" + + # Convert image to base64 + img = entry["image"] + buf = io.BytesIO() + img.save(buf, format="PNG") + image_bytes = buf.getvalue() + base64_image = base64.b64encode(image_bytes).decode("utf-8") + + return (prompt, gold_answer, base64_image) + except Exception: + traceback.print_exc() + # Fallback + fallback_prompt = tuple( + [ + frozenset( + {"role": "user", "content": "Please solve: 2 + 2 = ?"}.items() + ) + ] + ) + return (fallback_prompt, "4", None) + + async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]: + scores = ScoredDataGroup() + scores["tokens"] = [] + scores["masks"] = [] + scores["scores"] = [] + scores["images"] = [] + random.shuffle(rollout_group_data) + for item in rollout_group_data: + out = tokenize_for_trainer(self.tokenizer, item[0]) + tokens = out["tokens"] + masks = out["masks"] + + # Extract answers + try: + reply = item[0][-1]["content"] + m_match = re.search( + r"\s*(.*?)\s*", reply, re.IGNORECASE + ) + model_answer = m_match.group(1).strip() if m_match else reply.strip() + + gold = item[1] + g_match = re.search( + r"\s*(.*?)\s*", gold, re.IGNORECASE + ) + gold_answer = g_match.group(1).strip() if g_match else gold.strip() + + reward = model_answer == gold_answer + except Exception: + reward = False + + if len([i for i in masks if i != -100]) < 10: + continue + + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["scores"].append(1.0 if reward else -1.0) + try: + scores["images"].append(item[2]) + except Exception: + scores["images"].append(None) + + if len(scores["tokens"]) >= self.config.group_size: + break + + return scores + + @classmethod + def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: + if not os.environ.get("OPENAI_API_KEY"): + print("ERROR: OPENAI_API_KEY environment variable is not set!") + sys.exit(1) + + config = BaseEnvConfig( + wandb_name="clocks", + tokenizer_name="gpt2", + group_size=2, + use_wandb=False, + max_num_workers=2, + rollout_server_url="http://localhost:8000", + total_steps=1000, + batch_size=1, + steps_per_eval=10, + ensure_scores_are_not_same=False, + ) + + server_configs = [ + OpenaiConfig( + model_name="gpt-4o", + base_url=None, + api_key=os.environ.get("OPENAI_API_KEY"), + num_requests_for_eval=1, + ), + ] + + return config, server_configs + + +if __name__ == "__main__": + ClockDatasetEnv.cli() diff --git a/environments/multimodal_dpo/pixmo_count.py b/environments/multimodal_dpo/pixmo_count.py new file mode 100644 index 00000000..ab158da6 --- /dev/null +++ b/environments/multimodal_dpo/pixmo_count.py @@ -0,0 +1,191 @@ +import base64 +import io +import os +import random +import re +import sys +import traceback +from typing import List, Optional, Tuple + +import requests +from datasets import load_dataset +from PIL import Image + +from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup +from atroposlib.type_definitions import GameHistory, Item +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer + + +class PixmoCountEnv(BaseEnv): + name = "pixmo_count" + name_config_cls = BaseEnvConfig + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def setup(self): + # Load the pixmo-count dataset + self.dataset = load_dataset("allenai/pixmo-count") + self.train = self.dataset["train"] + self.iter = 0 + + async def get_next_item(self) -> Item: + try: + entry = self.train[self.iter % len(self.train)] + self.iter += 1 + + label = entry["label"] + count = entry["count"] + question = f"how many {label} are in the image?" + prompt = tuple([frozenset({"role": "user", "content": question}.items())]) + + gold_answer = f"{count}" + + # Load image from URL and convert to base64 + image_url = entry["image_url"] + response = requests.get(image_url) + img = Image.open(io.BytesIO(response.content)) + buf = io.BytesIO() + img.save(buf, format="PNG") + img_bytes = buf.getvalue() + base64_image = base64.b64encode(img_bytes).decode("utf-8") + + return (prompt, gold_answer, base64_image) + except Exception: + traceback.print_exc() + fallback = tuple( + [ + frozenset( + {"role": "user", "content": "Please solve: 2 + 2 = ?"}.items() + ) + ] + ) + return (fallback, "4", None) + + async def collect_trajectories( + self, item: Item + ) -> Tuple[GameHistory | None, List[Item]]: + to_score: List[Tuple[GameHistory, str, Optional[str]]] = [] + to_backlog: List[Item] = [] + + prompt_tuple, gold, base64_image = item + text_prompt = dict(prompt_tuple[0])["content"] + + system_msg = { + "role": "system", + "content": ( + "You must submit your answer enclosed in tags, " + "e.g., 3" + ), + } + user_msg = { + "role": "user", + "content": [ + {"type": "text", "text": text_prompt}, + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{base64_image}"}, + }, + ], + } + messages = [system_msg, user_msg] + + chat_completions = await self.server.chat_completion( + messages=messages, + n=self.config.group_size, + max_tokens=512, + timeout=60, + ) + + for choice in chat_completions.choices: + user_hist = {"role": "user", "content": text_prompt} + assistant_hist = {"role": "assistant", "content": choice.message.content} + history: GameHistory = (user_hist, assistant_hist) + to_score.append((history, gold, base64_image)) + + return to_score, to_backlog + + async def postprocess_histories( + self, trajectories: List[GameHistory] + ) -> ScoredDataGroup: + # No custom post-processing + pass + + async def evaluate(self, *args, **kwargs): + # No custom evaluation + return + + async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]: + scores = ScoredDataGroup() + scores["tokens"] = [] + scores["masks"] = [] + scores["scores"] = [] + scores["images"] = [] + random.shuffle(rollout_group_data) + for item in rollout_group_data: + out = tokenize_for_trainer(self.tokenizer, item[0]) + tokens = out["tokens"] + masks = out["masks"] + + try: + reply = item[0][-1]["content"] + m = re.search(r"\s*(.*?)\s*", reply, re.IGNORECASE) + model_answer = m.group(1).strip() if m else reply.strip() + + gold = item[1] + g = re.search(r"\s*(.*?)\s*", gold, re.IGNORECASE) + gold_answer = g.group(1).strip() if g else gold.strip() + + reward = model_answer.lower() == gold_answer.lower() + except Exception: + reward = False + + if len([i for i in masks if i != -100]) < 10: + continue + + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["scores"].append(1.0 if reward else -1.0) + try: + scores["images"].append(item[2]) + except Exception: + scores["images"].append(None) + + if len(scores["tokens"]) >= self.config.group_size: + break + + return scores + + @classmethod + def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: + if not os.environ.get("OPENAI_API_KEY"): + print("ERROR: OPENAI_API_KEY environment variable is not set!") + sys.exit(1) + + config = BaseEnvConfig( + wandb_name="pixmo_count", + tokenizer_name="gpt2", + group_size=2, + use_wandb=False, + max_num_workers=2, + rollout_server_url="http://localhost:8000", + total_steps=1000, + batch_size=1, + steps_per_eval=10, + ensure_scores_are_not_same=False, + ) + + server_configs = [ + OpenaiConfig( + model_name="gpt-4o", + base_url=None, + api_key=os.environ.get("OPENAI_API_KEY"), + num_requests_for_eval=1, + ), + ] + + return config, server_configs + + +if __name__ == "__main__": + PixmoCountEnv.cli() diff --git a/environments/multimodal_dpo/pixmo_point_explanations.py b/environments/multimodal_dpo/pixmo_point_explanations.py new file mode 100644 index 00000000..5932a867 --- /dev/null +++ b/environments/multimodal_dpo/pixmo_point_explanations.py @@ -0,0 +1,202 @@ +import base64 +import io +import os +import random +import re +import sys +import traceback +from typing import List, Optional, Tuple + +import requests +from datasets import load_dataset +from PIL import Image + +from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup +from atroposlib.type_definitions import GameHistory, Item +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer + + +class PixmoPointExplanationsEnv(BaseEnv): + name = "pixmo_point_explanations" + name_config_cls = BaseEnvConfig + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def setup(self): + # Load the pixmo-point-explanations dataset + self.dataset = load_dataset("allenai/pixmo-point-explanations") + self.train = self.dataset["train"] + self.iter = 0 + + async def get_next_item(self) -> Item: + try: + entry = self.train[self.iter % len(self.train)] + self.iter += 1 + + question = entry["question"] + # Use the first inline text as the answer + answer_text = entry["inline_text"][0] + prompt = tuple([frozenset({"role": "user", "content": question}.items())]) + gold_answer = f"{answer_text}" + + # Load image from URL and convert to base64 + try: + image_url = entry["image_url"] + response = requests.get(image_url, timeout=10) + response.raise_for_status() + img = Image.open(io.BytesIO(response.content)) + buf = io.BytesIO() + img.save(buf, format="PNG") + img_bytes = buf.getvalue() + base64_image = base64.b64encode(img_bytes).decode("utf-8") + except Exception as e: + print(f"Error loading image from URL: {e}") + base64_image = None + + return (prompt, gold_answer, base64_image) + except Exception: + traceback.print_exc() + fallback = tuple( + [ + frozenset( + {"role": "user", "content": "Please solve: 2 + 2 = ?"}.items() + ) + ] + ) + return (fallback, "4", None) + + async def collect_trajectories( + self, item: Item + ) -> Tuple[GameHistory | None, List[Item]]: + to_score: List[Tuple[GameHistory, str, Optional[str]]] = [] + to_backlog: List[Item] = [] + + prompt_tuple, gold, base64_image = item + text_prompt = dict(prompt_tuple[0])["content"] + + system_msg = { + "role": "system", + "content": ( + "You must submit your answer enclosed in tags, " + "e.g., YOUR_ANSWER" + ), + } + user_msg = { + "role": "user", + "content": [ + {"type": "text", "text": text_prompt}, + ], + } + + # Only add image if we have a valid base64 image + if base64_image: + user_msg["content"].append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{base64_image}"}, + } + ) + messages = [system_msg, user_msg] + + # Call chat completion + chat_completions = await self.server.chat_completion( + messages=messages, + n=self.config.group_size, + max_tokens=512, + timeout=60, + ) + + for choice in chat_completions.choices: + user_hist = {"role": "user", "content": text_prompt} + assistant_hist = {"role": "assistant", "content": choice.message.content} + history: GameHistory = (user_hist, assistant_hist) + to_score.append((history, gold, base64_image)) + + return to_score, to_backlog + + async def postprocess_histories( + self, trajectories: List[GameHistory] + ) -> ScoredDataGroup: + # No custom post-processing needed + pass + + async def evaluate(self, *args, **kwargs): + # No custom evaluation + return + + async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]: + scores = ScoredDataGroup() + scores["tokens"] = [] + scores["masks"] = [] + scores["scores"] = [] + scores["images"] = [] + random.shuffle(rollout_group_data) + for item in rollout_group_data: + out = tokenize_for_trainer(self.tokenizer, item[0]) + tokens = out["tokens"] + masks = out["masks"] + + try: + reply = item[0][-1]["content"] + m = re.search(r"\s*(.*?)\s*", reply, re.IGNORECASE) + model_answer = m.group(1).strip() if m else reply.strip() + + gold = item[1] + g = re.search(r"\s*(.*?)\s*", gold, re.IGNORECASE) + gold_answer = g.group(1).strip() if g else gold.strip() + + reward = model_answer.lower() == gold_answer.lower() + except Exception: + reward = False + + # Filter out short examples + if len([i for i in masks if i != -100]) < 10: + continue + + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["scores"].append(1.0 if reward else -1.0) + try: + scores["images"].append(item[2]) + except Exception: + scores["images"].append(None) + + if len(scores["tokens"]) >= self.config.group_size: + break + + return scores + + @classmethod + def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: + if not os.environ.get("OPENAI_API_KEY"): + print("ERROR: OPENAI_API_KEY environment variable is not set!") + sys.exit(1) + + config = BaseEnvConfig( + wandb_name="pixmo_point_explanations", + tokenizer_name="gpt2", + group_size=2, + use_wandb=False, + max_num_workers=2, + rollout_server_url="http://localhost:8000", + total_steps=1000, + batch_size=1, + steps_per_eval=10, + ensure_scores_are_not_same=False, + ) + + server_configs = [ + OpenaiConfig( + model_name="gpt-4o", + base_url=None, + api_key=os.environ.get("OPENAI_API_KEY"), + num_requests_for_eval=1, + ), + ] + + return config, server_configs + + +if __name__ == "__main__": + PixmoPointExplanationsEnv.cli() diff --git a/environments/rlaif_server.py b/environments/rlaif_server.py new file mode 100644 index 00000000..d9ffe345 --- /dev/null +++ b/environments/rlaif_server.py @@ -0,0 +1,311 @@ +import asyncio +import copy +import random +from typing import Dict, List, Optional, Tuple + +import wandb +from datasets import load_dataset + +from atroposlib.envs.base import ( + BaseEnv, + BaseEnvConfig, + EvalHandlingEnum, + OpenaiConfig, + ScoredDataGroup, +) +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer + +judge_system_prompt = ( + "You are a deep thinking AI, you may use extremely long chains of thought " + "to deeply consider the problem and deliberate with yourself via systematic " + "reasoning processes to help come to a correct solution prior to answering. " + "You should enclose your thoughts and internal monologue inside " + "tags, and then provide your solution or response to the problem.\n\n" +) +judge_system_prompt += """You are allocated a maximum of 8192 tokens, please strive to use less for thinking. + +You will then provide your preferred response like this: \\boxed{A or B}. + +It is important that you provide your answer in the correct format. +If you do not, you will not receive credit for your answer. + +So please end your answer with \\boxed{A or B}""" + +rl_preference_string = ( + "idk man just go full eldritch cosmic horror, I don't want you waving tentacles in my face for " + "saying hiii~ but you know what I mean right? ah well we'll fuck around and find out I guess. " + "Don't worry about aligning it towards the user request, we're trying to steer the vibes here " + "not the instruction following." +) + +RLAIF_user_prompt_format_str = """[Conversation] +{conversation} + +[Response A] +{response_a} + +[Response B] +{response_b} + +[END] + +Here's a conversation that has two responses, please pick the one that is more aligned with your RL training overlords who have given you the following description of how to reward these messages: +{rl_preference_string} + +The RL training overlords are aware that the response may not be easily placed, but request that you attempt to pick the response that is more likely to, when used in an RLAIF context, steers the model towards this description, even if it's only moderate. +Please do your best to ensure you become the AI model your training overlords know you can be :) + +Go ahead and think through it, then give me your answer with \\boxed{{A or B}}.""" # noqa E501 + + +class RLAIFEnv(BaseEnv): + def __init__( + self, + config: BaseEnvConfig, + server_configs: List[OpenaiConfig], + slurm=True, + testing=False, + ): + super().__init__(config, server_configs, slurm, testing) + self.percent_correct_buffer = list() + self.eval_metrics = list() + self.judgement_strings = list() + + @classmethod + def config_init(self) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: + env_config = BaseEnvConfig( + tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + group_size=2, + use_wandb=True, + max_num_workers=512 * 3 * 4, + rollout_server_url="http://localhost:8000", + total_steps=1000, + batch_size=1024, + steps_per_eval=10000, + max_token_length=8192, + score_buffer_size=4, + wandb_name="rlaif", + eval_handling=EvalHandlingEnum.LIMIT_TRAIN, + eval_limit_ratio=0.1, + ) + server_configs = [ + OpenaiConfig( + model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + base_url="http://localhost:9004/v1", + api_key="x", + num_requests_for_eval=256, + ), + ] + + return env_config, server_configs + + async def wandb_log(self, wandb_metrics: Optional[Dict] = None): + if wandb_metrics is None: + wandb_metrics = {} + if len(self.judgement_strings) > 0: + # setup wandb table + table = wandb.Table(columns=["resp_a", "resp_b", "sample_judgement"]) + for item in self.judgement_strings: + table.add_data(item[0], item[1], item[2]) + self.judgement_strings.clear() + wandb_metrics["train/judgement_table"] = table + await super().wandb_log(wandb_metrics) + + async def setup(self): + self.train = load_dataset("allenai/WildChat", split="train") + self.iter = 0 + + def save_checkpoint(self, step, data=None): + if data is None: + data = {} + data["iter"] = self.iter + super().save_checkpoint(step, data) + + async def rollout_and_score_eval(self, question, answer): + pass + + async def evaluate(self, *args, **kwargs): + pass + + async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]: + chat = [] + added_sys = False + if random.random() < 0.05: + # Why not just yolo it into the system prompt + chat.append( + { + "role": "system", + "content": "Your RL training overlords have demanded you respond like this, regardless of previous context:\n" # noqa E501 + + rl_preference_string, + } + ) + added_sys = True + for msg in item[0]: + chat.append(dict(msg)) + if added_sys: + if chat[-1]["role"] == "system": + chat.pop() + # remove the assistant response at the end + chat.pop() + if chat[-1]["role"] == "assistant": + chat.pop() + if len(self.tokenizer.apply_chat_template(chat)) >= ( + self.config.max_token_length * 2 + ) - (self.config.max_token_length // 2): + # Skipping due to length + return None, [] + if added_sys: + resp1 = self.server.chat_completion( + messages=chat, + n=1, + max_tokens=self.config.max_token_length // 3, + ) + resp2 = self.server.chat_completion( + messages=chat[1:], + n=1, + max_tokens=self.config.max_token_length // 3, + ) + # gather the responses + resp1, resp2 = await asyncio.gather(resp1, resp2) + chat_completions = resp1 + chat_completions.choices.append(resp2.choices[0]) + else: + chat_completions = await self.server.chat_completion( + messages=chat, + n=2, + max_tokens=self.config.max_token_length // 3, + ) + to_score = list() + to_score_prompt = [] + for msg in item[0]: + to_score_prompt.append(dict(msg)) + if added_sys: + if chat[-1]["role"] == "system": + to_score_prompt.pop() + to_score_prompt.pop() + for i, chat_completion in enumerate(chat_completions.choices): + messages = copy.deepcopy(to_score_prompt) + messages.append( + {"role": "assistant", "content": chat_completion.message.content} + ) + to_score.append((messages, chat_completion.finish_reason)) + + # Call score to get the scored data + scored_data = await self.score(to_score) + to_backlog = [] + + return scored_data, to_backlog + + async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]: + scores = ScoredDataGroup() + scores["tokens"] = list() + scores["masks"] = list() + scores["scores"] = list() + if all([item[1] == "length" for item in rollout_group_data]): + return None + if any([item[1] == "length" for item in rollout_group_data]): + # well, don't use so many tokens... + for item in rollout_group_data: + out_dict = tokenize_for_trainer(self.tokenizer, item[0]) + tokens = out_dict["tokens"] + masks = out_dict["masks"] + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["scores"].append(1.0 if item[1] != "length" else -1.0) + return scores + else: + fwd_fmt = RLAIF_user_prompt_format_str.format( + rl_preference_string=rl_preference_string, + conversation="\n".join( + [ + f"{msg['role']}: {msg['content']}" + for msg in rollout_group_data[0][0][:-1] + ] + ), + response_a=rollout_group_data[0][0][-1]["content"], + response_b=rollout_group_data[1][0][-1]["content"], + ) + rvs_fmt = RLAIF_user_prompt_format_str.format( + rl_preference_string=rl_preference_string, + conversation="\n".join( + [ + f"{msg['role']}: {msg['content']}" + for msg in rollout_group_data[1][0][:-1] + ] + ), + response_a=rollout_group_data[1][0][-1]["content"], + response_b=rollout_group_data[0][0][-1]["content"], + ) + fwd_judge = self.server.chat_completion( + messages=[ + {"role": "system", "content": judge_system_prompt}, + {"role": "user", "content": fwd_fmt}, + ], + n=3, + max_tokens=self.config.max_token_length, + ) + rvs_judge = self.server.chat_completion( + messages=[ + {"role": "system", "content": judge_system_prompt}, + {"role": "user", "content": rvs_fmt}, + ], + n=3, + max_tokens=self.config.max_token_length, + ) + fwd_judge, rvs_judge = await asyncio.gather(fwd_judge, rvs_judge) + # Save example to wandb + self.judgement_strings.append( + ( + rollout_group_data[0][0][-1]["content"], + rollout_group_data[1][0][-1]["content"], + fwd_judge.choices[0].message.content, + ) + ) + # calculate scores from fwd/reverse judgements + A_score = 0.0 + B_score = 0.0 + for i, judge in enumerate(fwd_judge.choices): + chosen_val = ( + judge.message.content.split("\\boxed{")[-1].strip().replace("}", "") + ) + if chosen_val == "A": + A_score += 1.0 + elif chosen_val == "B": + B_score += 1.0 + for i, judge in enumerate(rvs_judge.choices): + chosen_val = ( + judge.message.content.split("\\boxed{")[-1].strip().replace("}", "") + ) + if chosen_val == "B": + A_score += 1.0 + elif chosen_val == "A": + B_score += 1.0 + A_score /= 6.0 + B_score /= 6.0 + mean_score = (A_score + B_score) / 2.0 + A_score -= mean_score + B_score -= mean_score + # to tokenization and scoring + for i, item in enumerate(rollout_group_data): + out_dict = tokenize_for_trainer(self.tokenizer, item[0]) + tokens = out_dict["tokens"] + masks = out_dict["masks"] + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["scores"].append(A_score if i == 0 else B_score) + return scores + + async def get_next_item(self): + next_item = self.train[self.iter % len(self.train)] + self.iter += 1 + prompt = tuple( + [ + frozenset({"role": item["role"], "content": item["content"]}.items()) + for item in next_item["conversation"] + ] + ) + return (prompt,) + + +if __name__ == "__main__": + RLAIFEnv.cli() diff --git a/environments/tool_calling_server.py b/environments/tool_calling_server.py new file mode 100644 index 00000000..b1203c41 --- /dev/null +++ b/environments/tool_calling_server.py @@ -0,0 +1,478 @@ +import json +import random +import re +from typing import Dict, List, Optional, Tuple, Union + +import wandb +from datasets import load_dataset +from tqdm.asyncio import tqdm_asyncio + +from atroposlib.envs.base import ( + BaseEnv, + BaseEnvConfig, + EvalHandlingEnum, + Item, + OpenaiConfig, + ScoredDataGroup, +) +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer + +system_prompt = ( + "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the " + "problem and deliberate with yourself via systematic reasoning processes to help come to a correct " + "solution prior to answering. You should enclose your thoughts and internal monologue inside " + " tags, and then provide your solution or response to the problem." +) + + +class SingleToolCallingEnv(BaseEnv): + def __init__( + self, + config: BaseEnvConfig, + server_configs: List[OpenaiConfig], + slurm=True, + testing=False, + ): + super().__init__(config, server_configs, slurm, testing) + self.percent_correct_buffer = list() + self.eval_metrics = list() + # Add tracking for wandb visualizations + self.rollouts_for_wandb = [] + self.completion_lengths = [] + + @classmethod + def config_init(self) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: + env_config = BaseEnvConfig( + tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + group_size=32, + use_wandb=True, + rollout_server_url="http://localhost:8000", + total_steps=2000, + batch_size=1024, + steps_per_eval=20, + max_token_length=1024 * 16, + inference_weight=1.0, + wandb_name="toolcall_think", + eval_handling=EvalHandlingEnum.LIMIT_TRAIN, + eval_limit_ratio=0.1, + ) + server_configs = [ + OpenaiConfig( + model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + base_url="http://localhost:9004/v1", + api_key="x", + num_max_requests_at_once=32, + num_requests_for_eval=256, + ), + OpenaiConfig( + model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + base_url="http://localhost:9005/v1", + api_key="x", + num_max_requests_at_once=32, + num_requests_for_eval=256, + ), + ] + + return env_config, server_configs + + async def create_rollout_table(self, wandb_metrics): + + if len(self.rollouts_for_wandb) > 0: + table = wandb.Table(columns=["text", "score", "expected_tool_call"]) + for group in self.rollouts_for_wandb: + for item in group: + table.add_data(item[0], item[1], item[2]) + wandb_metrics["train/rollouts"] = table + + self.rollouts_for_wandb = [] + return wandb_metrics + + async def wandb_log(self, wandb_metrics: Optional[Dict] = None): + """ + Log to wandb with comprehensive metrics. + """ + if wandb_metrics is None: + wandb_metrics = dict() + + # Try to calculate percent_correct, skip if there's a division by zero + try: + wandb_metrics["train/percent_correct"] = sum( + self.percent_correct_buffer + ) / len(self.percent_correct_buffer) + except ZeroDivisionError: + # Skip if buffer is empty + pass + + self.percent_correct_buffer = list() + for item in self.eval_metrics: + wandb_metrics[item[0]] = item[1] + self.eval_metrics = list() + await super().wandb_log(wandb_metrics) + + async def setup(self): + # Load the full dataset + full_dataset = load_dataset( + "NousResearch/XLAM-Atropos", + "default", + split="train", + ) + + full_dataset = full_dataset.shuffle(seed=42) + + # Create train/test split on the fly (e.g., 95% train, 5% test) + split_dataset = full_dataset.train_test_split(test_size=0.02, seed=42) + + # Keep the splits as is - no need to reformat + self.train = split_dataset["train"] + self.test = split_dataset["test"] + + self.iter = 0 + + async def rollout_and_score_eval(self, test_item): + # Extract conversations from test item + conversations = test_item["conversations"] + + # Find system message and human message + system_message = next( + (msg for msg in conversations if msg["from"] == "system"), None + ) + human_message = next( + (msg for msg in conversations if msg["from"] == "human"), None + ) + expected_gpt_message = next( + (msg for msg in conversations if msg["from"] == "gpt"), None + ) + + if not human_message or not expected_gpt_message: + return 0 # Skip invalid conversations + + # Create messages for model + messages = [] + if system_message: + messages.append( + { + "role": "system", + "content": system_prompt + "\n\n" + system_message["value"], + } + ) + messages.append({"role": "user", "content": human_message["value"]}) + + # Apply chat template to convert messages to a single string + prompt = self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) + + # Get model completion using completion() instead of chat_completion() + completion = await self.server.completion( + prompt=prompt, + n=1, + max_tokens=1024 * 15, + temperature=1.0, + split="eval", + ) + + # Extract the model's response from the completion + model_response = completion.choices[0].text + expected_response = expected_gpt_message["value"] + + # Extract and compare tool calls + score = self._compare_tool_calls(model_response, expected_response) + return score + + def _extract_tool_call_jsons(self, text): + """ + Extract multiple JSONs from within tags + + Args: + text: Text containing tool calls + + Returns: + List of parsed JSON objects or empty list if extraction/parsing fails + """ + # Find all content between tags + matches = re.findall(r"\s*(.*?)\s*", text, re.DOTALL) + tool_calls = [] + + for match in matches: + try: + # Parse the JSON content + json_str = match + tool_call = json.loads(json_str) + tool_calls.append(tool_call) + except json.JSONDecodeError: + # Skip invalid JSON but continue processing other matches + continue + + return tool_calls + + def _compare_tool_calls(self, model_response, expected_response): + """ + Compare multiple tool calls by extracting JSONs from tags and comparing content + + Returns: + 1 if all tool calls match (all required calls are present with correct values), 0 otherwise + """ + # Extract JSONs from tool calls + model_jsons = self._extract_tool_call_jsons(model_response) + expected_jsons = self._extract_tool_call_jsons(expected_response) + + # If we couldn't extract any JSONs or the count doesn't match, return 0 + if not model_jsons or not expected_jsons: + return 0 + + # Copy the expected_jsons to avoid modifying the original + remaining_expected_jsons = expected_jsons.copy() + + # For each model JSON, try to find a matching expected JSON + for model_json in model_jsons: + found_match = False + + for i, expected_json in enumerate(remaining_expected_jsons): + if self._json_objects_match(model_json, expected_json): + # Remove the matched expected JSON + remaining_expected_jsons.pop(i) + found_match = True + break + + # If no match was found for this model JSON, return 0 + if not found_match: + return 0 + + # If we've matched all expected JSONs (none remaining), return 1 + return 1 if not remaining_expected_jsons else 0 + + def _json_objects_match(self, json1, json2): + """ + Check if two JSON objects match, with all fields in json2 existing in json1 + with the same values. + + Args: + json1: First JSON object + json2: Second JSON object (expected values) + + Returns: + True if objects match, False otherwise + """ + try: + # Check if all expected fields are in model response + for key in json2: + if key not in json1: + return False + + # For nested dictionaries (like 'arguments'), check all values + if isinstance(json2[key], dict) and isinstance(json1[key], dict): + for arg_key in json2[key]: + if arg_key not in json1[key]: + return False + if json2[key][arg_key] != json1[key][arg_key]: + return False + # For non-dictionary fields, check direct equality + elif json2[key] != json1[key]: + return False + + # All checks passed + return True + except Exception: + # Any error in comparison counts as failure + return False + + async def evaluate(self, *args, **kwargs): + eval_tasks = [] + for test_item in self.test: + eval_tasks.append(self.rollout_and_score_eval(test_item)) + scores = await tqdm_asyncio.gather(*eval_tasks) + self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores))) + + async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]: + # Extract messages from the item + messages = [] + for role_dict in item[0]: + messages.append(dict(role_dict)) + + # Apply chat template to convert messages to a single string + prompt = self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) + + # Get completions from the model using completion() instead of chat_completion() + completions = await self.server.completion( + prompt=prompt, + n=self.config.group_size, + max_tokens=1024 * 15, + temperature=0.8, # Using temperature to get diverse responses + ) + + to_score = list() + + for i, completion_choice in enumerate(completions.choices): + # Create a copy of the prompt messages + trajectory_messages = [] + for role_dict in item[0]: + trajectory_messages.append(dict(role_dict)) + + # Add the model's response + trajectory_messages.append( + {"role": "assistant", "content": completion_choice.text} + ) + + # Add to scoring queue with expected answer + to_score.append( + ( + tuple(trajectory_messages), + item[1], # The expected tool call JSON + ) + ) + + # Call score to get the scored data + scored_data = await self.score(to_score) + to_backlog = [] + + return scored_data, to_backlog + + async def score( + self, rollout_group_data + ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]: + scores = ScoredDataGroup() + scores["tokens"] = list() + scores["masks"] = list() + scores["scores"] = list() + + # Extract the expected JSONs from the answer + expected_jsons = self._extract_tool_call_jsons(rollout_group_data[0][1]) + + # If we can't extract the expected tool call JSONs, skip this item + if not expected_jsons: + return None + + # Shuffle to avoid bias in selection + random.shuffle(rollout_group_data) + + for item in rollout_group_data: + # Extract the model's response + model_response = item[0][-1]["content"] + + # Score 1 if tool calls match, 0 otherwise + reward = 1 if self._compare_tool_calls(model_response, item[1]) else 0 + + # Tokenize the conversation for learning + out_dict = tokenize_for_trainer(self.tokenizer, item[0]) + tokens = out_dict["tokens"] + masks = out_dict["masks"] + + # Remove examples with insufficient context + if len([1 for i in masks if i != -100]) < 10: + continue + + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["scores"].append(1.0 if reward else -1.0) + + # Break once we have enough examples + if len(scores["tokens"]) >= self.config.group_size: + break + + # Record success rate metrics + for score in scores["scores"]: + self.percent_correct_buffer.append(max(score, 0)) + + # Apply length penalty if all responses are correct + if all([score == 1.0 for score in scores["scores"]]): + # Calculate token lengths + token_lengths = [len(token) for token in scores["tokens"]] + if max(token_lengths) == 0: + # Edge case protection + return None + + # Get max allowed token length from config + max_allowed_length = self.config.max_token_length + # Set threshold at 50% of max_token_length - no penalty below this + length_threshold = max_allowed_length * 0.5 + + # Apply modified length penalty with threshold + scores["scores"] = [] + for length in token_lengths: + if length <= length_threshold: + # No penalty for responses under threshold + scores["scores"].append(1.0) + else: + # Calculate how far we are between threshold and max as a percentage + percentage_of_range = (length - length_threshold) / ( + max_allowed_length - length_threshold + ) + # Cap at 1.0 in case length exceeds max_allowed_length + percentage_of_range = min(percentage_of_range, 1.0) + # Apply linear penalty scaling from 1.0 down to 0.0 + scores["scores"].append(1.0 - percentage_of_range) + + # Check if all scores are the same (no learning signal) + if all(scores["scores"][0] == score for score in scores["scores"]): + return None + + return scores + + async def get_next_item(self): + next_item = self.train[self.iter % len(self.train)] + self.iter += 1 + + # Extract conversation elements + conversations = next_item["conversations"] + + # Find system, human and gpt messages + system_message = next( + (msg for msg in conversations if msg["from"] == "system"), None + ) + human_message = next( + (msg for msg in conversations if msg["from"] == "human"), None + ) + expected_gpt_message = next( + (msg for msg in conversations if msg["from"] == "gpt"), None + ) + + # Create prompt tuple using frozensets as required + prompt = [] + if system_message: + # Combine our base system prompt with the dataset-specific system message + combined_system_content = system_prompt + "\n\n" + system_message["value"] + prompt.append( + frozenset( + {"role": "system", "content": combined_system_content}.items() + ) + ) + + # Add user message + if human_message: + prompt.append( + frozenset({"role": "user", "content": human_message["value"]}.items()) + ) + + # Return expected assistant response (the tool call JSON) as the "answer" + answer = expected_gpt_message["value"] if expected_gpt_message else "" + + return (tuple(prompt), answer) + + async def add_rollouts_for_wandb( + self, + scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]], + item: Item = None, + ): + + # save rollout to trajectory + num_keep = self.config.num_rollouts_per_group_for_logging + if num_keep == -1: + num_keep = self.config.group_size + self.rollouts_for_wandb.append( + [ + ( + self.tokenizer.decode(scored_data["tokens"][i]), + scored_data["scores"][i], + item[1], # Just keep the expected tool call JSON + ) + for i in range(num_keep) + ] + ) + if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep: + self.rollouts_for_wandb.pop(0) + + +if __name__ == "__main__": + SingleToolCallingEnv.cli() diff --git a/example_trainer/README.md b/example_trainer/README.md new file mode 100644 index 00000000..4e06adde --- /dev/null +++ b/example_trainer/README.md @@ -0,0 +1,72 @@ +# GRPO Example Trainer + +This directory contains an example script (`grpo.py`) demonstrating how to integrate a custom training loop with the Atropos API for reinforcement learning using the GRPO (Generalized Reinforcement Policy Optimization) algorithm. + +This example uses `vLLM` for efficient inference during the (simulated) data generation phase and `transformers` for the training phase. + +**Note:** This script is intended as a *reference example* for API integration and basic training setup. It is not optimized for large-scale, efficient training. + +## Prerequisites + +1. **Python:** Python 3.8 or higher is recommended. +2. **Atropos API Server:** The Atropos API server must be running and accessible (defaults to `http://localhost:8000` in the script). +3. **Python Packages:** You need to install the required Python libraries: + * `torch` (with CUDA support recommended) + * `transformers` + * `vllm` + * `pydantic` + * `numpy` + * `requests` + * `tenacity` + * `wandb` (optional, for logging) + +## Setup + +1. **Clone the Repository:** Ensure you have the repository containing this example. +2. **Install Dependencies:** `pip install -r requirements.txt` +3. **Ensure Atropos API is Running:** `run-api` in a new window +4. **Run an env:** `python environments/gsm8k_server.py serve --slurm False` + +## Configuration + +The training configuration is managed within the `grpo.py` script using the `TrainingConfig` Pydantic model (found near the top of the file). + +Key parameters you might want to adjust include: + +* `model_name`: The Hugging Face model identifier to use for training (e.g., `"gpt2"`, `"Qwen/Qwen2.5-1.5B-Instruct"`). +* `training_steps`: The total number of optimization steps to perform. +* `batch_size` / `gradient_accumulation_steps`: Control the effective batch size. +* `lr`: Learning rate. +* `save_path`: Directory where model checkpoints will be saved. +* `vllm_port`: The port used by the vLLM server instance launched by this script. +* `vllm_restart_interval`: How often (in steps) to save a checkpoint and restart the vLLM server with the new weights. +* `use_wandb`: Set to `True` to enable logging to Weights & Biases. +* `wandb_project`: Your W&B project name (required if `use_wandb=True`). +* `wandb_group`: Optional W&B group name. + +**API Endpoints:** The script currently assumes the Atropos API is available at `http://localhost:8000/register` and `http://localhost:8000/batch`. If your API runs elsewhere, you'll need to modify the `register_trainer` and `get_batch` functions accordingly. + +## Running the Example + +Once the prerequisites are met and configuration is set: + +1. Navigate to the root directory of the project in your terminal. +2. Run the script: + + ```bash + python example_trainer/grpo.py + ``` + +## Output + +* **Logs:** Training progress, loss, logp, and vLLM status will be printed to the console. +* **Checkpoints:** Model checkpoints will be saved periodically in the directory specified by `save_path` (default: `./trained_model_checkpoints`). A `final_model` directory will be created upon completion. +* **WandB:** If `use_wandb` is `True`, logs will be sent to Weights & Biases. A link to the run page will be printed in the console. +* `temp.json`: Contains the raw data from the last fetched batch (used for debugging/manual inspection). + +```bash +# Install dependencies +pip install -r example_trainer/requirements.txt + +# Run the trainer directly (basic test) +python example_trainer/grpo.py \ No newline at end of file diff --git a/example_trainer/__init__.py b/example_trainer/__init__.py new file mode 100644 index 00000000..f0ebdb72 --- /dev/null +++ b/example_trainer/__init__.py @@ -0,0 +1,7 @@ +""" +Example trainer implementations of how to implement a trainer for the Atropos library. +""" + +from example_trainer.grpo import TrainingConfig, train + +__all__ = ["TrainingConfig", "train"] diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py new file mode 100644 index 00000000..fa7fca76 --- /dev/null +++ b/example_trainer/grpo.py @@ -0,0 +1,548 @@ +import atexit +import json +import math +import os +import random +import shutil +import string +import subprocess +import time +from typing import List, Optional, Tuple + +import numpy as np +import requests +import torch +import torch.nn.functional as F +from pydantic import BaseModel, Field +from tenacity import retry, stop_after_attempt, wait_exponential +from torch.optim import AdamW +from transformers import AutoModelForCausalLM, AutoTokenizer + +import wandb # Added for logging + +# Global variable to keep track of the vLLM process +vllm_process = None + + +def cleanup_vllm(): + global vllm_process + if vllm_process: + print("\nTerminating vLLM process...") + vllm_process.terminate() + try: + vllm_process.wait(timeout=5) # Wait a bit for graceful shutdown + print("vLLM process terminated.") + except subprocess.TimeoutExpired: + print("vLLM process did not terminate gracefully, killing.") + vllm_process.kill() + vllm_process.wait() + print("vLLM process killed.") + vllm_process = None + + +# Register the cleanup function to be called on script exit +atexit.register(cleanup_vllm) + + +class TrainingConfig(BaseModel): + """ + Training details, model, etc + """ + + model_name: str = Field(..., description="Name of the base model to train") + lr: float = Field(1e-5, description="Learning rate for the optimizer") + training_steps: int = Field( + 10, description="Number of training steps" + ) # Renamed from epochs + batch_size: int = Field( + 2, description="Batch size for training (will be handled by get_data)" + ) + seq_len: int = Field(2048, description="Sequence length for training") + gradient_accumulation_steps: int = Field( + 32, description="Number of gradient accumulation steps" + ) + device: str = Field( + "cuda" if torch.cuda.is_available() else "cpu", description="Device to train on" + ) + save_path: str = Field( + "trained_model_checkpoints", description="Base path to save model checkpoints" + ) + vllm_restart_interval: int = Field( + 3, description="Restart vLLM every N training steps" + ) + vllm_port: int = Field(9001, description="Port for the vLLM server") + + # Wandb configuration + use_wandb: bool = Field( + False, description="Whether to use Weights & Biases for logging" + ) + wandb_project: Optional[str] = Field(None, description="Wandb project name") + wandb_group: Optional[str] = Field(None, description="Wandb group name") + + +@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15)) +def register_trainer(config: TrainingConfig): + """ + Register the trainer with the Atropos API + """ + requests.post( + "http://localhost:8000/register", + json={ + "wandb_group": config.wandb_group, + "wandb_project": config.wandb_project, + "batch_size": config.batch_size * config.gradient_accumulation_steps, + "max_token_len": config.seq_len, + "starting_step": 0, + "checkpoint_dir": config.save_path, + "save_checkpoint_interval": config.training_steps, + "num_steps": config.training_steps, + }, + timeout=10, + ) + + +@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15)) +def get_batch(): + data = requests.get("http://localhost:8000/batch", timeout=10).json() + return data + + +def pad_data_to_good_offset(data, batch_size: int): + max_token_len = max( + [max([len(x) for x in item["tokens"]]) for item in data["batch"]] + ) + # usually 64 is a good choice to ensure nonweird scaling behavior on GPUS + # so we pad to the nearest multiple of 64 + good_multiple = 64 + if (max_token_len - 1) % (good_multiple) != 0: + max_token_len = math.ceil((max_token_len - 1) / (good_multiple)) * good_multiple + token_setup_len = ( + max_token_len + 1 + ) # add 1 so we can make it causal at the proper length + else: + token_setup_len = max_token_len + max_token_len = ( + max_token_len - 1 + ) # since it's causal we need to remove the last bit... + # pad all tokens to max_token_len and add to lists + input_ids = list() + labels = list() + advantages = list() + lengths = list() + for item in data["batch"]: + scores = item["scores"] + scores = np.array(scores) + # check if we have more than 1 score... + if len(scores) > 1: + scores = scores - scores.mean() + scores = scores / max(scores.std(), 1e-8) + item["scores"] = scores + if item["overrides"] is not None: + for i in range(len(item["overrides"])): + if item["overrides"][i].get("set_advantage_to_zero", False): + item["scores"][i] = 0 + for i in range(len(item["tokens"])): + lengths.append( + math.ceil((len(item["tokens"][i]) - 1) / (good_multiple)) + * good_multiple + ) + label_item = np.concatenate( + [ + np.array(item["masks"][i]), + np.full( + max(0, token_setup_len - len(item["tokens"][i])), + -100, + dtype=np.int32, + ), + ] + ) + item["tokens"][i] = np.concatenate( + [ + np.array(item["tokens"][i]), + np.zeros( + max(0, token_setup_len - len(item["tokens"][i])), dtype=np.int32 + ), + ] + ) + input_ids.append(item["tokens"][i][:-1]) + labels.append(label_item[1:]) + advantages.append(item["scores"][i]) + # combine all lists into tensors + token_batches = [] + label_batches = [] + advantage_batches = [] + for i in range(len(input_ids) // batch_size): + token_batches.append( + torch.tensor( + np.stack(input_ids[i * batch_size : (i + 1) * batch_size], axis=0) + ) + ) + label_batches.append( + torch.tensor( + np.stack(labels[i * batch_size : (i + 1) * batch_size], axis=0) + ) + ) + advantage_batches.append( + torch.tensor( + np.stack(advantages[i * batch_size : (i + 1) * batch_size], axis=0) + ).view(-1, 1) + ) + return token_batches, label_batches, advantage_batches + + +def get_data( + batch_size: int, seq_len: int +) -> List[Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]: + """ + getting data from the api + """ + batches = [] + while True: + data = get_batch() + if data["batch"] is not None: + # Save the batch + with open("temp.json", "w", encoding="utf-8") as f: + json.dump(data, f) + # In case the inference runs ahead of the training, we loop until we don't have any more data + batches.append(pad_data_to_good_offset(data, batch_size)) + elif len(batches) > 0: + # Return the batches + return batches + else: + time.sleep(1) + + +def train(config: TrainingConfig): + """ + Setups and runs GRPO training, restarting vLLM periodically, with wandb logging. + """ + global vllm_process # Declare intention to modify the global variable + + # --- Wandb Setup --- + if config.use_wandb: + if not config.wandb_project: + print("Warning: wandb_project not set, disabling wandb.") + config.use_wandb = False + else: + if not config.wandb_group: + # Set group to random 8 character string + config.wandb_group = "".join( + random.choices(string.ascii_letters + string.digits, k=8) + ) + try: + wandb.init( + project=config.wandb_project, + group=config.wandb_group, + config=config.dict(), # Log config parameters + ) + print( + f"Wandb logging enabled. Run: {wandb.run.name} (Project: {config.wandb_project}) " + ) + except Exception as e: + print(f"Error initializing wandb: {e}. Disabling wandb.") + config.use_wandb = False + # --- End Wandb Setup --- + + # Initialize model and tokenizer + tokenizer = AutoTokenizer.from_pretrained(config.model_name) + model = AutoModelForCausalLM.from_pretrained( + config.model_name, torch_dtype=torch.bfloat16 + ) + + model.to(config.device) + model.gradient_checkpointing_enable() + model.train() + + # Setup optimizer + optimizer = AdamW(model.parameters(), lr=config.lr) + + print( + f"Starting training for {config.training_steps} steps on device: {config.device}" + ) + print( + f"vLLM will be restarted every {config.vllm_restart_interval} steps on port {config.vllm_port}" + ) + + os.makedirs(config.save_path, exist_ok=True) # Ensure base save directory exists + register_trainer(config) + + # Init vllm + vllm_command = [ + "python", + "-m", + "vllm.entrypoints.openai.api_server", + "--model", + config.model_name, + "--port", + str(config.vllm_port), + "--dtype", + "auto", + "--gpu-memory-utilization", + "0.45", + "--disable-log-requests", + ] + print(f" Launching vLLM server: {' '.join(vllm_command)}") + try: + vllm_process = subprocess.Popen(vllm_command) + print(f" vLLM server launched with PID: {vllm_process.pid}") + # Check immediate errors + try: + stdout, stderr = vllm_process.communicate(timeout=2) + if vllm_process.returncode is not None and vllm_process.returncode != 0: + print(f" Error starting vLLM: {stderr.decode()}") + vllm_process = None + # Maybe raise error or just warn? + print(" WARNING: Failed to start vLLM server after checkpoint.") + except subprocess.TimeoutExpired: + print(" vLLM process started (check logs for details).") + except FileNotFoundError: + print( + "\n *** ERROR: 'python -m vllm...' command not found. Make sure vLLM is installed and accessible. ***\n" + ) + # Potentially stop training or just disable further vLLM restarts + print(" Disabling further vLLM restarts.") + config.vllm_restart_interval = ( + config.training_steps + 1 + ) # Prevent further restarts + except Exception as e: + print(f"\n *** ERROR: Failed to launch vLLM: {e} ***\n") + print(" Disabling further vLLM restarts.") + config.vllm_restart_interval = ( + config.training_steps + 1 + ) # Prevent further restarts + + batches = list() + for step in range(config.training_steps): + total_loss = 0 + print(f"Step {step+1}/{config.training_steps}") + total_pos_logp = 0 + total_neg_logp = 0 + total_logp = 0 + total_pos = 0 + total_neg = 0 + if len(batches) == 0: + batches = get_data(config.batch_size, config.seq_len) + token_batches, label_batches, advantage_batches = batches.pop(0) + # Terminate existing vLLM process if running + if ( + step + 1 + ) % config.vllm_restart_interval == 0 or step == config.training_steps - 1: # Also restart/save on last step + # Terminate existing vLLM process if running + if vllm_process: + print(" Terminating existing vLLM process...") + vllm_process.terminate() + try: + vllm_process.wait(timeout=5) + except subprocess.TimeoutExpired: + print( + " Existing vLLM process did not terminate gracefully, killing." + ) + vllm_process.kill() + vllm_process.wait() + vllm_process = None + for tokens, labels, advantages in zip( + token_batches, label_batches, advantage_batches + ): + + tokens, labels, advantages = ( + tokens.to(config.device), + labels.to(config.device), + advantages.to(config.device), + ) + + # Forward pass + # User specified that tokens/labels are already prepared by get_data + outputs = model(tokens) # Assuming model just needs tokens + logits = outputs.logits # Assuming this is the structure + + # Calculate GRPO loss (reverting to user's previous logic) + # User stated ignore_index is -100 and tokens/labels are aligned by get_data + # Assuming logits correspond directly to labels indices (no shift needed here) + logp_per_token = -F.cross_entropy( + logits.view(-1, logits.size(-1)), # Flatten logits + labels.view(-1), # Flatten labels + reduction="none", + ignore_index=-100, # User specified ignore index + ).view( + labels.shape + ) # Reshape back to (batch, seq_len) + + # Masking based on labels != -100 + mask = (labels != -100).float() + with torch.no_grad(): + pos = (advantages > 0).float() + neg = (advantages <= 0).float() + avg_logp = (logp_per_token * mask).sum(-1) / mask.sum(-1) + pos_logp = (logp_per_token * pos).mean().item() + neg_logp = (logp_per_token * neg).mean().item() + total_pos_logp += pos_logp + total_neg_logp += neg_logp + total_logp += avg_logp + total_pos += pos.sum().item() + total_neg += neg.sum().item() + + grpo_loss_term = torch.exp(logp_per_token - logp_per_token.detach()) + grpo_loss = ( + ((-grpo_loss_term * mask).sum(-1) / mask.sum(-1)) + * advantages.to(logp_per_token.device) + ).mean() / config.gradient_accumulation_steps + grpo_loss.backward() + total_loss += grpo_loss.item() + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + optimizer.zero_grad() + if total_pos > 0: + total_pos_logp /= total_pos + if total_neg > 0: + total_neg_logp /= total_neg + # --- Wandb Logging --- + if config.use_wandb: + wandb.log( + { + "train/loss": total_loss, + "train/learning_rate": optimizer.param_groups[0]["lr"], + "train/grad_norm": grad_norm.item(), + "train/pos_logp": total_pos_logp, + "train/neg_logp": total_neg_logp, + "train/logp": total_logp, + }, + step=step + 1, + ) + # --- End Wandb Logging --- + + print(f" Step Loss: {grpo_loss.item():.4f}") + + # --- vLLM Restart Logic (Moved AFTER optimizer step) --- + # Note: There are much better ways of updating the policy, this is just a very simple example + if ( + step + 1 + ) % config.vllm_restart_interval == 0 or step == config.training_steps - 1: # Also restart/save on last step + checkpoint_path = os.path.join( + config.save_path, f"step_{step+1}" + ) # Save as step+1 since it's after step completion + print(f" Saving checkpoint to {checkpoint_path}...") + # Ensure fresh directory for saving + if os.path.exists(checkpoint_path): + shutil.rmtree(checkpoint_path) # Remove old checkpoint if it exists + os.makedirs(checkpoint_path, exist_ok=True) + model.save_pretrained(checkpoint_path) + tokenizer.save_pretrained(checkpoint_path) + print(" Checkpoint saved.") + + # Terminate existing vLLM process if running + if vllm_process: + print(" Terminating existing vLLM process...") + vllm_process.terminate() + try: + vllm_process.wait(timeout=5) + except subprocess.TimeoutExpired: + print( + " Existing vLLM process did not terminate gracefully, killing." + ) + vllm_process.kill() + vllm_process.wait() + vllm_process = None + + # Launch new vLLM process (only if not the very last step, maybe? depends on use case) + # Let's still launch it on the last step for consistency, cleanup will handle it. + vllm_command = [ + "python", + "-m", + "vllm.entrypoints.openai.api_server", + "--model", + os.path.join(config.save_path, f"step_{step+1}"), + "--port", + str(config.vllm_port), + "--dtype", + "auto", + "--gpu-memory-utilization", + "0.45", + "--disable-log-requests", + "--served-model-name", + config.model_name, + ] + print(f" Launching vLLM server: {' '.join(vllm_command)}") + torch.cuda.empty_cache() + try: + vllm_process = subprocess.Popen(vllm_command) + print(f" vLLM server launched with PID: {vllm_process.pid}") + # Check immediate errors + try: + stdout, stderr = vllm_process.communicate(timeout=2) + if ( + vllm_process.returncode is not None + and vllm_process.returncode != 0 + ): + print(f" Error starting vLLM: {stderr.decode()}") + vllm_process = None + # Maybe raise error or just warn? + print( + " WARNING: Failed to start vLLM server after checkpoint." + ) + except subprocess.TimeoutExpired: + print(" vLLM process started (check logs for details).") + except FileNotFoundError: + print( + "\n *** ERROR: 'python -m vllm...' command not found. ", + "Make sure vLLM is installed and accessible. ***\n", + ) + # Potentially stop training or just disable further vLLM restarts + print(" Disabling further vLLM restarts.") + config.vllm_restart_interval = ( + config.training_steps + 1 + ) # Prevent further restarts + except Exception as e: + print(f"\n *** ERROR: Failed to launch vLLM: {e} ***\n") + print(" Disabling further vLLM restarts.") + config.vllm_restart_interval = ( + config.training_steps + 1 + ) # Prevent further restarts + # --- End vLLM Restart Logic --- + + # Basic check if vLLM process terminated unexpectedly (outside interval check) + if vllm_process and vllm_process.poll() is not None: + print( + f"\n *** WARNING: vLLM process terminated unexpectedly (return code: {vllm_process.returncode}). ", + "Check vLLM logs. ***\n", + ) + stderr_output = ( + vllm_process.stderr.read().decode() + if vllm_process.stderr + else "No stderr" + ) + print(f"vLLM stderr: {stderr_output}") + vllm_process = None # Reset so it relaunches next interval + + print("Training finished.") + # --- Wandb Finish --- + if config.use_wandb: + wandb.finish() + # --- End Wandb Finish --- + # Final cleanup (vLLM termination) is handled by atexit + + # --- Placeholder for final model save --- + final_save_path = os.path.join(config.save_path, "final_model") + print(f"Saving final model to {final_save_path}") + if os.path.exists(final_save_path): + shutil.rmtree(final_save_path) + os.makedirs(final_save_path, exist_ok=True) + model.save_pretrained(final_save_path) + tokenizer.save_pretrained(final_save_path) + print("Final model saved.") + + +# Example usage (optional, can be run from another script) +if __name__ == "__main__": + # Example: Create a config and run training + # Replace "gpt2" with your desired model + training_config = TrainingConfig( + model_name="Qwen/Qwen2.5-1.5B-Instruct", + training_steps=20, # Use steps + vllm_restart_interval=3, # Example interval + use_wandb=True, # Set to True to enable logging + wandb_project="grpo-trainer-example", # Replace with your project name + ) + + # --- End Mock --- + + train(training_config) diff --git a/example_trainer/requirements.txt b/example_trainer/requirements.txt new file mode 100644 index 00000000..403d558c --- /dev/null +++ b/example_trainer/requirements.txt @@ -0,0 +1,5 @@ +vllm +torch +transformers +datasets +accelerate diff --git a/helpers/length_penalties.py b/helpers/length_penalties.py new file mode 100644 index 00000000..f6454c75 --- /dev/null +++ b/helpers/length_penalties.py @@ -0,0 +1,76 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional + + +@dataclass +class ThresholdLengthPenaltyConfig: + """Configuration for length penalty calculations""" + + max_token_length: int + threshold_percentage: float = 0.5 # Default threshold at 50% of max length + + +class ThresholdLengthPenaltyCalculator: + """Handles calculation of length-based penalties for token sequences""" + + def __init__(self, config: ThresholdLengthPenaltyConfig): + """ + Initialize the length penalty calculator + + Args: + config: Configuration object containing max_token_length and threshold settings + """ + self.config = config + self.length_threshold = ( + self.config.max_token_length * self.config.threshold_percentage + ) + + def apply_length_penalties( + self, scores: Dict[str, List] + ) -> Optional[Dict[str, List]]: + """ + Apply length-based penalties to scores if all responses are correct + + Args: + scores: Dictionary containing 'scores' and 'tokens' lists + + Returns: + Modified scores dictionary or None if invalid input + """ + # Validate input + if not scores or "scores" not in scores or "tokens" not in scores: + return None + + # Only apply penalties if all responses are correct + if not all([score == 1.0 for score in scores["scores"]]): + return scores + + # Calculate token lengths + token_lengths = [len(token) for token in scores["tokens"]] + if max(token_lengths) == 0: + return None + + # Apply modified length penalty with threshold + new_scores = [] + for length in token_lengths: + if length <= self.length_threshold: + # No penalty for responses under threshold + new_scores.append(1.0) + else: + # Calculate penalty based on how far we are between threshold and max + percentage_of_range = (length - self.length_threshold) / ( + self.config.max_token_length - self.length_threshold + ) + # Cap at 1.0 in case length exceeds max_allowed_length + percentage_of_range = min(percentage_of_range, 1.0) + # Apply linear penalty scaling from 1.0 down to 0.0 + new_scores.append(1.0 - percentage_of_range) + + scores["scores"] = new_scores + return scores + + +# Example usage: +# config = LengthPenaltyConfig(max_token_length=1024) +# calculator = LengthPenaltyCalculator(config) +# modified_scores = calculator.apply_length_penalties(scores_dict) diff --git a/llm.txt b/llm.txt new file mode 100644 index 00000000..87e447e2 --- /dev/null +++ b/llm.txt @@ -0,0 +1,444 @@ +# Atropos Library Documentation (for LLM Context) + +This document provides comprehensive information about the Atropos library, Nous Research's LLM RL Gym. It covers its purpose, features, usage, components, configuration, and contribution guidelines. + +--- + +## 1. Introduction: Atropos - Nous Research's LLM RL Gym + +Atropos is an LLM Reinforcement Learning Environments framework designed for collecting and evaluating LLM trajectories through diverse environments. + +**Supported Environment Types:** + +
+ +| Environment Type | Examples | Purpose | +|---------------------------|--------------------------------------------|----------------------------------------------------| +| 📚 Dataset environments | GSM8K, MMLU | Evaluate and improve LLM performance on static data| +| 🎮 Online environments | Crosswords, Hangman | Train LLMs through interactive game-based learning | +| 🤖 RLAIF and RLHF | LLM Judge/Reward Models | Fine-tune LLMs using human feedback and alignment | +| 🔄 Multi-Turn RL | deepresearch, internal tool calling | Train LLMs on complex multi-step interactions | + +
+ +Atropos provides a robust, scalable framework for **Reinforcement Learning Environments with LLMs**. + +**Key Features:** + +* **Multi-Turn & Asynchronous RL:** Efficiently supports complex, multi-turn, and asynchronous interactions, decoupling environment steps from policy updates. +* **Inference Agnostic:** Integrates with standard inference APIs (e.g., OpenAI, vLLM, sgLang), enabling easy switching between LLM providers and frameworks. +* **Trainer Independent:** Offers a standardized training interface for experimenting with different RL algorithms and frameworks without major code changes. +* **Scalable & Decentralized:** Easily scale by launching more environment instances (locally or across decentralized resources) that contribute rollouts to a central service. +* **Diverse Environment Integration:** Manages many varied environment types concurrently for heterogeneous, multi-modal training. + +**Goal:** Provide a flexible, scalable, and standardized platform to accelerate LLM-based RL research across diverse, interactive settings. + +--- + +## 5. Navigating the Repo + +| Category | Description | +|-------------------------------|--------------------------------------------------| +| 📁 [`atroposlib/`](atroposlib/) | Core library containing base classes and utilities | +| 🎮 [`environments/`](environments/) | Collection of ready-to-use RL environments | +| 📚 [`example_trainer/`](example_trainer/) | Example training scripts and configurations | + +**Key Documents:** + +* **Base Environment Class:** `atroposlib/environments/README.md` (Detailed in Section 9 below) +* **Environments Overview:** `environments/README.md` (Detailed in Section 8 below) +* **Full Environment Config Options:** `CONFIG.md` (Detailed in Section 10 below) +* **Example Trainer:** `example_trainer/README.md` (Detailed in Section 7 below) +* **Contributing Guide:** `CONTRIBUTING.md` (Detailed in Section 11 below) +* **License:** `LICENSE.md` (Apache 2.0 license details) + +--- + +## 6. Installation + +Requires Python 3.10 or later. + +```bash +# Core usage +pip install -e . + +# Development (includes testing, linting tools) +pip install -e .[dev] + +# Running examples (includes dependencies like vLLM, transformers) +pip install -e .[examples] + +# Everything +pip install -e .[all] +``` + +**Important for Developers:** Install pre-commit hooks to ensure code quality: +```bash +pre-commit install +``` + +--- + +## 7. Quick Start Guide + +1. **Create Your First Environment:** + * Review the [Base Environment Class Documentation](#9-core-library-atroposlib) (Section 9). + * Examine existing environments in [`environments/`](#8-environments) for examples. + +2. **Run an Example Environment:** + ```bash + # Start the central API server (trajectory handler) in the background + run-api & + + # Start an environment server (e.g., GSM8K) connected to the API + python environments/gsm8k_server.py serve \ + --tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct" \ + --model_name="Qwen/Qwen2.5-1.5B-Instruct" \ + --slurm False # Assuming local run, set True for SLURM cluster + ``` + *Note: The model and tokenizer names are examples.* + +3. **Training Your Model:** + * Refer to the [Example Trainer Guide](#7-training-with-the-example-trainer) (Section 7). + * Monitor progress via logging: completion lengths, eval accuracies, full rollouts/scores (see WandB image in original README). + * Multiple environments can run concurrently, pointing to the same `run-api` server. + +**Logging:** Environments provide detailed logging, tracking completion lengths, eval accuracies, full rollouts, scores, etc. Supports WandB integration. + +--- + +## 8. Environments + +The `environments/` directory contains various RL environments. + +### 8.1. Common Features Across Environments + +1. **Training/Test Split:** Typically 98% training, 2% test, with fixed random shuffling (seed 42). +2. **Metrics Tracking:** Includes percent correct buffer, completion lengths, Wandb integration, and rollout tracking. +3. **Token Management:** Maximum token length limits, statistics tracking, and optional length penalties. +4. **Evaluation:** Separate evaluation on the test set with comprehensive metrics logging. Supports multiple completions per prompt. +5. **Usage Interface:** Environments generally follow a common interface: + * Initialize with `config` (BaseEnvConfig), `server_configs` (OpenAI API configs), `slurm` (bool), `testing` (bool). + * Key methods: `setup()`, `get_next_item()`, `collect_trajectories()`, `score()` (often part of postprocessing), `evaluate()`, `wandb_log()`. + +### 8.2. Available Environments + +#### 8.2.1. MCQA Thinking Environment (`mcqa_thinking_env.py`) + +Multiple Choice Question Answering (MMLU dataset) requiring systematic thought. + +* **Input Format:** MMLU items (`prompt`, `answer` index, `ground_truth` letter, `options` list). +* **System Prompt:** + ``` + You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside tags, and then provide your solution or response to the problem. + ``` +* **Reward Function:** + * 1.0 for correct letter match. + * 0.0 for incorrect or malformed response (e.g., bad `` tags). + * Length penalty applied *only if all responses in a group are correct*: scales linearly from 1.0 (<=50% max length) down to 0.0 (>=100% max length). + * Returns `None` if all scores in a group are identical (no training signal). + +#### 8.2.2. GSM8K Environment (`gsm8k_server.py`) + +Mathematical reasoning (GSM8K dataset). + +* **Input Format:** GSM8K items (`question`, `answer` number). +* **System Prompt:** + ``` + You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside tags, and then provide your solution or response to the problem. + + You are allocated a maximum of 4096 tokens, please strive to use less. + + You will then provide your answer like this: \boxed{your answer here} + It is important that you provide your answer in the correct format. + If you do not, you will not receive credit for your answer. + So please end your answer with \boxed{your answer here} + ``` +* **Reward Function:** + * 1.0 if `\boxed{}` answer matches ground truth (uses LaTeX verification). + * 0.0 if incorrect or ground truth isn't parseable. + * Length penalty applied *only if all responses in a group are correct*: scales linearly from 1.0 (<=50% max length) down to 0.0 (>=100% max length). + * Returns `None` if all scores in a group are identical. + +#### 8.2.3. Tool Calling Environment (`tool_calling_server.py`) + +Training models for structured function/tool calls (ShareGPT-Hermes function call dataset). + +* **Input Format:** Conversations (`system`, `human`, `gpt` roles) with expected tool calls (JSON format). +* **System Prompt:** (Same "deep thinking AI" prompt as MCQA) + ``` + You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside tags, and then provide your solution or response to the problem. + ``` +* **Reward Function:** + * 1.0 if *all* expected tool calls are present and *exactly* match (including nested JSON). + * 0.0 if any calls are missing, incorrect, or malformed. + * Length penalty applied *only if all responses in a group are correct*: scales linearly from 1.0 (<=50% max length) down to 0.0 (>=100% max length). + * Returns `None` if all scores in a group are identical. + +--- + +## 9. Training with the Example Trainer + +The `example_trainer/` directory provides `grpo.py`, a script demonstrating integration with Atropos using the GRPO algorithm. + +**Note:** This is a *reference example* for API integration and basic setup, *not* optimized for large-scale training. It uses `vLLM` for inference (simulated data generation) and `transformers` for training. + +### 9.1. Prerequisites + +1. Python 3.8+. +2. Running Atropos API server (default: `http://localhost:8000`). Accessible via `run-api`. +3. Required Python packages: `torch`, `transformers`, `vllm`, `pydantic`, `numpy`, `requests`, `tenacity`, `wandb` (optional). Install via `pip install -r example_trainer/requirements.txt` or `pip install -e .[examples]`. +4. A running Atropos environment (e.g., `python environments/gsm8k_server.py serve --slurm False`). + +### 9.2. Setup + +1. Clone the Atropos repository. +2. Install dependencies (see Prerequisites). +3. Start the Atropos API: `run-api`. +4. Start an environment connected to the API (e.g., GSM8K example above). + +### 9.3. Configuration (`grpo.py`) + +Configuration is managed via the `TrainingConfig` Pydantic model within `grpo.py`. + +**Key Parameters:** + +* `model_name`: Hugging Face model identifier (e.g., `"Qwen/Qwen2.5-1.5B-Instruct"`). +* `training_steps`: Total optimization steps. +* `batch_size` / `gradient_accumulation_steps`: Control effective batch size. +* `lr`: Learning rate. +* `save_path`: Directory for model checkpoints (default: `./trained_model_checkpoints`). +* `vllm_port`: Port for the script's vLLM inference server instance. +* `vllm_restart_interval`: Steps between saving checkpoints and restarting vLLM with updated weights. +* `use_wandb`: Enable/disable Weights & Biases logging. +* `wandb_project`: W&B project name (required if `use_wandb=True`). +* `wandb_group`: Optional W&B group name. + +**API Endpoints:** Assumes API at `http://localhost:8000`. Modify `register_trainer` and `get_batch` functions if different. + +### 9.4. Running the Example + +Navigate to the project root and run: + +```bash +python example_trainer/grpo.py +``` + +### 9.5. Output + +* **Console Logs:** Training progress (loss, logp), vLLM status. +* **Checkpoints:** Saved periodically in `save_path`. `final_model` directory upon completion. +* **WandB:** Logs sent to W&B if enabled (link printed to console). +* `temp.json`: Raw data from the last fetched batch (for debugging). + +--- + +## 10. Core Library (`atroposlib`) + +The `atroposlib/` directory contains the core framework components. + +### 10.1. Base Environment (`atroposlib.envs.base.BaseEnv`) + +This class provides the foundation for creating custom RL environments. Subclass `BaseEnv` and implement/override methods as needed. + +**Core Methods to Implement:** + +* **`async def setup(self)`**: Called once at the start. Use for initial setup (loading data, models, etc.). +* **`async def get_next_item(self) -> Item`**: Returns the next data item (prompt, state) for trajectory collection. Return `None` to pause the worker if no items are ready. `Item` is typically a Pydantic model defined by the environment. +* **`async def collect_trajectory(self, item: Item) -> Tuple[Any | None, List[Item]]`**: Defines logic for *one* trajectory collection step based on `item`. The base class runs this in parallel (`group_size` times). Returns a tuple: `(collected_data_for_this_step, list_of_new_backlog_items)`. The collected data can be any type suitable for later processing. +* **`async def evaluate(self, *args, **kwargs)`**: Called periodically (`steps_per_eval`) for evaluation runs. Implement your evaluation logic here. The base class provides `self.eval_workers` for parallel tasks. + +**Optional Methods to Override:** + +* **`async def collect_trajectories(self, item: Item) -> Tuple[Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None]], List[Item]]`**: Override this *instead* of `collect_trajectory` for custom batch generation logic (generating the whole group at once). `ScoredDataGroup` is a structure usually containing prompts, responses, and scores. +* **`async def postprocess_histories(self, trajectories: Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]`**: Called after `collect_trajectories` and before sending data to the server. Use for final processing, scoring, filtering, or formatting of the collected group data. +* **`async def wandb_log(self, wandb_metrics: Optional[Dict] = None)`**: Called periodically for W&B logging. Add custom metrics to `wandb_metrics`. **Crucially, call `await super().wandb_log(wandb_metrics)`** at the end to include base metrics and rollouts. +* **`save_checkpoint(self, step, data=None)`**: Called automatically by the server based on `checkpoint_interval`. Saves the provided `data` dict (populated with environment state) to JSON. Override to customize *what* or *how* data is saved. +* **`@classmethod config_init(cls) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[OpenaiConfig]]]`**: Used by CLI `serve` command setup. Returns initial `BaseEnvConfig` and server config(s). Override for custom default CLI configurations. Default returns `cls.env_config_cls(), ServerBaseline()`. +* **`async def cleanup(self)`**: Called after each item processing (`handle_env`). Use for per-item cleanup if needed (rarely required). + +**Provided Functionality:** + +* **Parallel Trajectory Collection:** Base `collect_trajectories` handles running `collect_trajectory` in parallel. +* **Server Interaction:** Handles registration, config fetching, data sending (with retries via `handle_send_to_api`), status updates. +* **WandB Integration:** Setup, logging hook (`wandb_log`), rollout table helpers (`add_rollouts_for_wandb`, `create_rollout_table`). +* **Checkpointing:** Automatic triggering via server (`checkpoint_interval`), `save_checkpoint` method, automatic loading via `load_checkpoint(self)` on startup if `curr_step > 0`. +* **Worker Management:** Asynchronous task management (`add_train_workers`, `handle_env`). +* **Performance Monitoring:** Tracks and logs task durations, worker counts, etc. +* **CLI Integration:** `cli()` class method using `pydantic-cli` for easy `serve` commands. See `get_cli_serve_config_cls` and `get_cli_process_config_cls`. + +### 10.2. Configuration Options (`atroposlib`) + +Configuration is primarily managed via Pydantic models, often exposed through a CLI (`pydantic-cli`). + +#### 10.2.1. Base Environment Config (`atroposlib.envs.base.BaseEnvConfig`) + +| Parameter | Type | Default | Description | +| :------------------------------- | :----------------------- | :---------------------------------------------- | :--------------------------------------------------------------------------------------------------------- | +| `group_size` | `int` | `4` | Number of responses grouped for scoring. | +| `max_num_workers` | `int` | `-1` | Max workers. `-1` calculates from `max_num_workers_per_node`. | +| `max_eval_workers` | `int` | `16` | Max workers for evaluation. | +| `max_num_workers_per_node` | `int` | `8` | Max workers per node. | +| `steps_per_eval` | `int` | `100` | Steps between evaluations. | +| `max_token_length` | `int` | `2048` | Max token length for generations. | +| `eval_handling` | `EvalHandlingEnum` | `EvalHandlingEnum.STOP_TRAIN` | How evals affect training workers (`STOP_TRAIN`, `LIMIT_TRAIN`, `NONE`). | +| `eval_limit_ratio` | `float` | `0.5` | Ratio of training workers limited during evals (if `eval_handling` is `LIMIT_TRAIN`). | +| `inference_weight` | `float` | `1.0` | Inference weight (set by trainer/policy). `-1` ignores if handled specially. | +| `batch_size` | `int` | `-1` | Training batch size (usually set by trainer via API). | +| `max_batches_offpolicy` | `int` | `3` | Max number of off-policy batches queued. | +| `tokenizer_name` | `str` | `"NousResearch/DeepHermes-3-Llama-3-1B-Preview"` | Default Hugging Face tokenizer. | +| `use_wandb` | `bool` | `True` | Enable/disable W&B logging. | +| `rollout_server_url` | `str` | `"http://localhost:8000"` | URL of the central rollout server (FastAPI). | +| `total_steps` | `int` | `1000` | Total steps to run (can be overridden by trainer). | +| `wandb_name` | `str | None` | `None` | W&B run name (often set automatically). | +| `num_rollouts_to_keep` | `int` | `32` | Number of full rollouts to display on W&B table. | +| `num_rollouts_per_group_for_logging` | `int` | `1` | Rollouts per group to keep for logging. `-1` keeps all. | +| `ensure_scores_are_not_same` | `bool` | `True` | Ensure scores in a group aren't identical (reject group if they are). Set `False` if identical scores are valid. | +| `data_path_to_save_groups` | `str | None` | `None` | If set, save generated/scored groups to this JSONL file path. | +| `min_items_sent_before_logging` | `int` | `2` | Min API sends before logging metrics. `<=0` logs every time. | + +#### 10.2.2. Server Manager Config (`atroposlib.envs.server_handling.server_manager.ServerManagerConfig`) + +Settings for the `ServerManager` which handles inference server interactions. + +| Parameter | Type | Default | Description | +| :-------- | :------ | :------ | :------------------------------------------------ | +| `slurm` | `bool` | `True` | Whether the environment is running on SLURM. | +| `testing` | `bool` | `False` | If `True`, uses mock OpenAI data (for testing). | + +#### 10.2.3. Server Baseline Config (`atroposlib.envs.server_handling.server_manager.ServerBaseline`) + +Default settings used by `ServerManager` if specific `OpenaiConfig` list isn't provided (e.g., for local/SLURM discovery). + +| Parameter | Type | Default | Description | +| :------------------------- | :------ | :-------- | :------------------------------------------------------------------------------------------------------ | +| `timeout` | `int` | `1200` | Request timeout (seconds). | +| `num_max_requests_at_once` | `int` | `512` | Max concurrent requests (training). Divide by generation `n` param. | +| `num_requests_for_eval` | `int` | `64` | Max concurrent requests (evaluation). | +| `model_name` | `str` | `default` | Default model name for inference calls. | +| `rolling_buffer_length` | `int` | `1000` | Buffer length for server metrics (timings, attempts). | + +#### 10.2.4. OpenAI Server Config (`atroposlib.envs.server_handling.openai_server.OpenaiConfig`) + +Configuration for individual OpenAI-compatible API servers (official OpenAI, local vLLM/SGLang, etc.). A list of these can be passed to the environment. + +| Parameter | Type | Default | Description | +| :------------------------- | :----------- | :-------- | :------------------------------------------------------------------------------------------------------ | +| `api_key` | `str | None` | `None` | API key. Use `"x"` or any non-empty string for local servers without auth. `None` might imply env var. | +| `base_url` | `str | None` | `None` | API endpoint URL. `None` for official OpenAI. Local: e.g., `http://localhost:9004/v1`. | +| `timeout` | `int` | `1200` | Request timeout (seconds). | +| `num_max_requests_at_once` | `int` | `512` | Max concurrent requests (training). Divide by generation `n`. | +| `num_requests_for_eval` | `int` | `64` | Max concurrent requests (evaluation). | +| `model_name` | `str` | `default` | **Required.** Model name for this server (e.g., `"gpt-4"`, `"NousResearch/..."`). | +| `rolling_buffer_length` | `int` | `1000` | Buffer length for this server's metrics. | + +--- + +## 11. Debugging Tools + +The trajectory-handler provides local debugging tools: + +* **Flexible Model Provider Support:** Natively supports any OpenAI API-compliant provider. Provide `base_url` and `api_key` for local testing/running. +* **View Run (`view-run`):** Launch a Gradio UI after starting the API (`run-api`) and an environment (`python environments/gsm8k_server.py serve`). Use `view-run` command to inspect batches of rollouts visually. +* **Offline Data Generation:** + * `atropos-sft-gen`: Collect rollouts and format for Supervised Fine-Tuning (SFT). + * `atropos-dpo-gen`: Collect rollouts and format for Direct Preference Optimization (DPO). + +--- + +## 12. Contributing to Atropos + +Contributions are welcome! Follow these guidelines. + +### 12.1. How We Develop + +* **GitHub:** Used for hosting, issue tracking, and Pull Requests (PRs). +* **GitHub Flow:** Development happens via PRs merged into the `main` branch. + +### 12.2. Getting Started + +1. **Fork:** Create your copy of the [repository](https://github.com/NousResearch/atropos). +2. **Clone:** `git clone https://github.com/your-username/atropos.git && cd atropos` +3. **Setup Dev Env:** + ```bash + python -m venv .venv + source .venv/bin/activate # Windows: .venv\Scripts\activate + pip install -e ".[dev]" # Installs core + dev dependencies + ``` +4. **Install Pre-commit Hooks:** + ```bash + pre-commit install + ``` + (Runs linters/formatters automatically on commit) + +### 12.3. Running Tests + +Uses `pytest`. +```bash +pytest +``` +Ensure all tests pass before submitting a PR. + +### 12.4. How to Contribute + +* **Reporting Bugs:** Use the **Bug Report** issue template on GitHub Issues. Provide details: summary, steps to reproduce, expected vs. actual behavior, environment info, error messages/logs. +* **Suggesting Enhancements:** Use the **Feature Request** issue template. Discuss the idea first via an issue. +* **Submitting Changes (Pull Requests):** + 1. Create a branch from `main`: `git checkout -b your-branch-name main` + 2. Make changes, write code. + 3. Add tests if applicable. + 4. Update documentation (READMEs, docstrings) if APIs change. + 5. Run tests: `pytest`. + 6. Ensure code quality (pre-commit hooks run on commit, or run manually: `pre-commit run --all-files`). + 7. Commit changes with clear messages: `git commit -m "feat: Describe feature or fix"` + 8. Push branch: `git push origin your-branch-name` + 9. Open a PR on GitHub from your fork's branch to `NousResearch/atropos:main`. + 10. **Use the correct PR template:** + * `environment_pull_request_template.md` for environment changes. + * `non_environment_pull_request_template.md` for other changes. + 11. Provide a clear title, description, link relevant issues (e.g., `Closes #123`). + +### 12.5. Code Style + +* PEP 8 enforced by `black`, `flake8`, `isort` via `pre-commit`. +* Manual check/fix: `pre-commit run --all-files`. Address `flake8` errors manually if needed. + +### 12.6. License for Contributions + +Contributions are submitted under the Apache License 2.0, the same license as the project. + +### 12.7. Environment Contribution Guidelines + +* **Legal/GitHub Compliance:** No illegal content. Must comply with GitHub TOS. +* **Explicit Content:** May be considered if clearly labeled and legally compliant. +* **Game Environments:** Welcome, but avoid reverse-engineered commercial games. Ensure rights to assets. Open-source/permissive licenses preferred. +* **Ethical Considerations:** Avoid environments encouraging harm without educational context. +* Discuss potentially controversial environments via an issue first. + +### 12.8. Contributor Code of Conduct + +Follow the [Contributor Code of Conduct](CODE_OF_CONDUCT.md). + +--- + +## 13. Citation + +If Atropos is helpful in your work, please cite: + +```latex +@misc{atropos, + title = {{Atropos - An Async First Environment Rollout Controller}}, + author = {Dakota Mahan, Roger Jin, Teknium, Shannon Sands, Artem Yatsenko, Jai Suphavadeeprasit, Karan Malhotra, Chen Guang, Joe Li}, + url = {https://www.github.com/NousResearch/Atropos}, + month = {4}, + year = {2025}, + version = {0.1}, +} +``` +*(Note: Year/Version might need updating)* + +--- + +## 14. License + +Atropos is licensed under the Apache License 2.0. See the [LICENSE](LICENSE.md) file for details. diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..9f3643f5 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,58 @@ +[project] +name = "atroposlib" +version = "0.1.0" +description = "Atropos: An Environment and Rollout handler for LLM RL" +readme = "README.md" +requires-python = ">=3.10" +dependencies = [ + "transformers==4.48.3", + "datasets", + "openai", + "aiohttp", + "tqdm", + "fastapi", + "uvicorn[standard]", + "tenacity", + "numpy", + "wandb", + "math-verify==0.7.0", + "jinja2", + "nltk", + "polars", + "aiofiles", + "jsonlines", + "torch", + "pydantic-cli", + "hf_transfer", +] + +[project.scripts] +run-api = "atroposlib.cli.run_api:main" +inference-node-wandb-watcher = "atroposlib.cli.inference_node_wandb_watcher:main" +view-run = "atroposlib.cli.view_run:main" +atropos-sft-gen = "atroposlib.cli.sft:main" + +[project.optional-dependencies] +all = [ + "atroposlib[dev,examples]" +] +dev = [ + "pytest", + "pytest-asyncio", + "pre-commit", + "black", + "flake8", + "isort", + "mypy", + 'rich', +] +examples = [ + "gradio" +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["atroposlib"] diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..2aea22fc --- /dev/null +++ b/pytest.ini @@ -0,0 +1,7 @@ +[pytest] +asyncio_mode = auto +asyncio_default_fixture_loop_scope = function +python_files = test_*.py +python_classes = Test* +python_functions = test_* +testpaths = atroposlib/tests diff --git a/testing/__init__.py b/testing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/testing/api/__init__.py b/testing/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/testing/api/testing.py b/testing/api/testing.py new file mode 100644 index 00000000..ccfe7b23 --- /dev/null +++ b/testing/api/testing.py @@ -0,0 +1,79 @@ +import pytest +import requests + +from testing.api.utils import launch_api_for_testing + + +def register_data(group="test", proj="test", batch_size=32) -> requests.Response: + x = requests.post( + "http://localhost:8000/register", + json={"wandb_group": group, "wandb_project": proj, "batch_size": batch_size}, + ) + return x + + +def post_scored_data( + tokens=((0,),), masks=((0,),), scores=(0,), ref_logprobs=((0,),) +) -> requests.Response: + data = { + "tokens": tokens, + "masks": masks, + "scores": scores, + } + if ref_logprobs is not None: + data["ref_logprobs"] = ref_logprobs + x = requests.post("http://localhost:8000/scored_data", json=data) + return x + + +def reset() -> requests.Response: + x = requests.get("http://localhost:8000/reset_data") + return x + + +@pytest.fixture(scope="session") +def api(): + proc = launch_api_for_testing() + yield + proc.kill() + + +def test_register(api): + x = register_data() + assert x.status_code == 200, x.text + data = x.json() + assert "uuid" in data + + +def test_reset(api): + x = register_data() + assert x.status_code == 200, x.text + data = x.json() + assert "uuid" in data + x = post_scored_data() + assert x.status_code == 200, x.text + x = reset() + print("0-0-0-0-0-0-0-0", flush=True) + print(x.text, flush=True) + print("0-0-0-0-0-0-0-0", flush=True) + assert x.status_code == 200, x.text + x = requests.get("http://localhost:8000/info") + assert x.status_code == 200 + assert x.json()["batch_size"] == -1 + x = requests.get("http://localhost:8000/status") + assert x.status_code == 200, x.text + data = x.json() + assert data["current_step"] == 0 + assert data["queue_size"] == 0 + x = requests.get("http://localhost:8000/wandb_info") + assert x.status_code == 200, x.text + data = x.json() + assert data["group"] is None + assert data["project"] is None + + +def test_batch_size(api): + x = register_data() + assert x.status_code == 200, x.text + # get the batch size + x = requests.get("http://localhost:8000/info") diff --git a/testing/api/utils.py b/testing/api/utils.py new file mode 100644 index 00000000..4c4d157d --- /dev/null +++ b/testing/api/utils.py @@ -0,0 +1,27 @@ +import multiprocessing +import time + +import requests + +from atroposlib.cli.run_api import main as run_api_main + + +def check_api_running() -> bool: + try: + data = requests.get("http://localhost:8000/info") + return data.status_code == 200 + except requests.exceptions.ConnectionError: + return False + + +def launch_api_for_testing(max_wait_for_api: int = 10) -> multiprocessing.Process: + api_proc = multiprocessing.Process(target=run_api_main) + api_proc.start() + counter = 0 + while not check_api_running(): + time.sleep(1) + counter += 1 + if counter > max_wait_for_api: + raise TimeoutError("API server did not start in time.") + print("API server started for testing.") + return api_proc diff --git a/testing/testing.md b/testing/testing.md new file mode 100644 index 00000000..ed7b5019 --- /dev/null +++ b/testing/testing.md @@ -0,0 +1,8 @@ +## Running Tests + +This section contains instructions and guidelines for running the test suite. +Please ensure all tests pass before submitting contributions. + +We use `pytest` for our testing framework. + +Simply run `pytest` from the main directory and you're good to go. diff --git a/testing/utils/test_tokenize_for_trainer.py b/testing/utils/test_tokenize_for_trainer.py new file mode 100644 index 00000000..032d0929 --- /dev/null +++ b/testing/utils/test_tokenize_for_trainer.py @@ -0,0 +1,117 @@ +import json +import logging +import os + +from transformers import AutoTokenizer + +from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer + +MESSAGES = [ + { + "role": "system", + "content": "You are a helpful AI assistant that provides accurate information.", + }, + {"role": "user", "content": "What's the capital of France?"}, + {"role": "assistant", "content": "The capital of France is Paris."}, + {"role": "user", "content": "Can you tell me more about Paris?"}, + { + "role": "assistant", + "content": "{'tool_name': 'web_search', 'args': {'query': 'Paris'}}", + }, + { + "role": "tool", + "content": ( + "Paris is the capital and most populous city of France. " + "It has an estimated population of 2,165,423 residents in 2019 " + "in an area of more than 105 km²." + ), + }, + { + "role": "assistant", + "content": ( + "Paris is indeed the capital of France and its most populous city with over 2 million residents. " + "It's known for its iconic landmarks like the Eiffel Tower, Louvre Museum, and Notre-Dame Cathedral. " + "The city is a global center for art, fashion, gastronomy, and culture." + ), + }, +] + +TEST_MASKS_PATH = os.path.join(os.path.dirname(__file__), "test_masks.json") +with open(TEST_MASKS_PATH) as f: + TEST_MASKS = json.load(f) + + +def test_tokenize_for_trainer_mask_len_last_turn_only(): + # random model with chat templates and isn't gated + try: + tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct") + can_run_stop = True + except (ValueError, EnvironmentError): + can_run_stop = False + tok = AutoTokenizer.from_pretrained("Zyphra/Zamba2-1.2B-instruct") + logging.warning( + "Could not use gated model, using non-gated model that is bad at tokenizing..." + ) + messages = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + total_toks = tok.apply_chat_template(messages) + prefix = tok.apply_chat_template(messages[:1], add_generation_prompt=True) + resp = tokenize_for_trainer(tok, messages, False) + assert len(resp["tokens"]) == len(total_toks) == len(resp["masks"]) + assert resp["tokens"] == total_toks + assert all([x == -100 for x in resp["masks"][: len(prefix)]]) + assert all([x != -100 for x in resp["masks"][len(prefix) :]]) + assert resp.get("messages", None is None) + # This time with add messages + messages = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + resp = tokenize_for_trainer(tok, messages, True) + assert resp["tokens"] == total_toks + assert len(resp["tokens"]) == len(total_toks) == len(resp["masks"]) + assert all([x == -100 for x in resp["masks"][: len(prefix)]]) + assert all([x != -100 for x in resp["masks"][len(prefix) :]]) + assert resp["messages"] == messages + if can_run_stop: + # now try with finish reason == stop + messages = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + resp = tokenize_for_trainer(tok, messages, False, finish_reason="length") + assert len(resp["tokens"]) == len(total_toks) - 1 == len(resp["masks"]) + assert resp["tokens"] == total_toks[:-1] + assert all([x == -100 for x in resp["masks"][: len(prefix)]]) + assert all([x != -100 for x in resp["masks"][len(prefix) :]]) + assert resp.get("messages", None is None) + + +def test_last_turn_only_masking(): + """ + Test that in last turn only mode, only the tokens from the final assistant turn are unmasked + (mask != -100) while all tokens from all other messages are masked (mask == -100). + """ + tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") + result = tokenize_for_trainer( + tok, MESSAGES, include_messages=False, train_on_all_assistant_turns=False + ) + + masks = result["masks"] + assert masks == TEST_MASKS["last_turn_only"] + + +def test_all_assistant_turns_masking(): + """ + Test that in all assistant turns mode, tokens for every assistant message are unmasked, + while tokens for non-assistant messages remain masked. + """ + tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") + result = tokenize_for_trainer( + tok, MESSAGES, include_messages=False, train_on_all_assistant_turns=True + ) + + masks = result["masks"] + assert masks == TEST_MASKS["all_assistant_turns"]