diff --git a/.env.example b/.env.example index e570b8b5..545ad9fa 100644 --- a/.env.example +++ b/.env.example @@ -1 +1,2 @@ OPENAI_API_KEY= +OPENROUTER_API_KEY= diff --git a/.github/workflows/upload_to_pypi.yml b/.github/workflows/upload_to_pypi.yml new file mode 100644 index 00000000..b88557a5 --- /dev/null +++ b/.github/workflows/upload_to_pypi.yml @@ -0,0 +1,95 @@ +name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI + +on: push + +jobs: + build: + name: Build distribution 📦 + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.x" + - name: Install pypa/build + run: >- + python3 -m + pip install + build + --user + - name: Build a binary wheel and a source tarball + run: python3 -m build + - name: Store the distribution packages + uses: actions/upload-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + publish-to-pypi: + name: >- + Publish Python 🐍 distribution 📦 to PyPI + if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes + needs: + - build + runs-on: ubuntu-latest + environment: + name: pypi + url: https://pypi.org/p/atroposlib # Replace with your PyPI project name + permissions: + id-token: write # IMPORTANT: mandatory for trusted publishing + + steps: + - name: Download all the dists + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: dist/ + - name: Publish distribution 📦 to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + + github-release: + name: >- + Sign the Python 🐍 distribution 📦 with Sigstore + and upload them to GitHub Release + needs: + - publish-to-pypi + runs-on: ubuntu-latest + + permissions: + contents: write # IMPORTANT: mandatory for making GitHub Releases + id-token: write # IMPORTANT: mandatory for sigstore + + steps: + - name: Download all the dists + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: dist/ + - name: Sign the dists with Sigstore + uses: sigstore/gh-action-sigstore-python@v3.0.0 + with: + inputs: >- + ./dist/*.tar.gz + ./dist/*.whl + - name: Create GitHub Release + env: + GITHUB_TOKEN: ${{ github.token }} + run: >- + gh release create + "$GITHUB_REF_NAME" + --repo "$GITHUB_REPOSITORY" + --notes "" + - name: Upload artifact signatures to GitHub Release + env: + GITHUB_TOKEN: ${{ github.token }} + # Upload to GitHub Release using the `gh` CLI. + # `dist/` contains the built packages, and the + # sigstore-produced signatures and certificates. + run: >- + gh release upload + "$GITHUB_REF_NAME" dist/** + --repo "$GITHUB_REPOSITORY" diff --git a/CONFIG.md b/CONFIG.md index 6f8120fc..cc02a162 100644 --- a/CONFIG.md +++ b/CONFIG.md @@ -42,7 +42,7 @@ Settings for the `ServerManager`. ## Server Baseline Configuration (`atroposlib.envs.server_handling.server_manager.ServerBaseline`) -Baseline configuration used by `ServerManager` if a list of `OpenaiConfig` is not provided, particularly for setting up local or SLURM-based server discovery. +Baseline configuration used by `ServerManager` if a list of `APIServerConfig` is not provided, particularly for setting up local or SLURM-based server discovery. 
| Parameter | Type | Default | Description |
| :------------------------- | :------ | :-------- | :------------------------------------------------------------------------------------------------------ |
@@ -52,7 +52,7 @@ Baseline configuration used by `ServerManager` if a list of `OpenaiConfig` is no
 | `model_name`               | `str`   | `default` | Model name to use when calling inference servers.                                                         |
 | `rolling_buffer_length`    | `int`   | `1000`    | Length of the rolling buffer to store server metrics (like request timings, attempts).                   |

-## OpenAI Server Configuration (`atroposlib.envs.server_handling.openai_server.OpenaiConfig`)
+## OpenAI Server Configuration (`atroposlib.envs.server_handling.openai_server.APIServerConfig`)

 Configuration for individual OpenAI-compatible API servers (including local SGLang/vLLM instances).
@@ -65,3 +65,4 @@ Configuration for individual OpenAI-compatible API servers (including local SGLa
 | `num_requests_for_eval`    | `int`   | `64`      | Maximum number of concurrent requests for evaluation.                                                     |
 | `model_name`               | `str`   | `default` | The model name to use. Required for both OpenAI and local models (e.g., `"gpt-4"`, `"NousResearch/..."`). |
 | `rolling_buffer_length`    | `int`   | `1000`    | Length of the rolling buffer to store server metrics (like request timings, attempts).                   |
+| `n_kwarg_is_ignored`       | `bool`  | `False`   | Set to `True` if the API you are using ignores the `n` kwarg; `n` is then emulated with parallel single requests. |
diff --git a/README.md b/README.md
index 2759cdc7..6c004cfb 100644
--- a/README.md
+++ b/README.md
@@ -211,6 +211,29 @@ Environments come with detailed logging and reporting support, runs track comple

 ---

+# Trainer Integrations
+## Axolotl
+
+ Atropos plugin logo
+
+Axolotl is a powerful tool for fine-tuning a wide range of AI models, supporting techniques like LoRA and QLoRA through simple YAML configurations.
+
+The [Atropos plugin for Axolotl](https://github.com/axolotl-ai-cloud/plugin-atropos) seamlessly integrates Atropos' RL environments into Axolotl's training pipelines.
+This allows you to leverage Atropos for reinforcement learning while utilizing Axolotl's extensive features for model fine-tuning.
+
+To get started, follow the README in the [plugin repository](https://github.com/axolotl-ai-cloud/plugin-atropos).
+
+## Atropos' Example Trainer
+The Atropos repo contains an example trainer, intended primarily as a reference showing how a trainer and an inference provider can be integrated with Atropos to complete the RL training loop.
+
+To use the example trainer, see the [training example guide](example_trainer/README.md).
+
+---
+
 ## Testing and Debugging Tools

 The trajectory-handler provides several debugging tools to help environment developers test and understand their environments locally without requiring the full distributed infrastructure.
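+When debugging against a third-party OpenAI-compatible provider, note that some providers silently ignore the `n` sampling kwarg. A minimal sketch of declaring such a provider up front (the provider URL, model name, and key handling here are illustrative, not prescriptive):
+
+```python
+import os
+
+from atroposlib.envs.server_handling.openai_server import APIServerConfig
+
+server_config = APIServerConfig(
+    base_url="https://openrouter.ai/api/v1",  # example provider
+    api_key=os.environ.get("OPENROUTER_API_KEY"),
+    model_name="openai/gpt-4.1-nano",
+    n_kwarg_is_ignored=True,  # emulate n via parallel single requests
+)
+```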
@@ -248,7 +271,7 @@ Rejection sampling can be controlled via `--save-top-n-per-group`, `--allow-nega If you would like to use OpenAI models, please edit your `config_init` to something like the following: ```python @classmethod - def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: + def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: env_config = BaseEnvConfig( tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct", group_size=8, @@ -261,7 +284,7 @@ If you would like to use OpenAI models, please edit your `config_init` to someth wandb_name="gsm8k", ) server_configs = [ - OpenaiConfig( + APIServerConfig( model_name="gpt-4.1-nano", base_url=None, api_key=os.environ.get("OPENAI_API_KEY"), diff --git a/SLURM.md b/SLURM.md index bd070e29..65d371ce 100644 --- a/SLURM.md +++ b/SLURM.md @@ -18,7 +18,7 @@ When you initialize `ServerManager` with `slurm=True`: * Servers run on ports starting from `9000` (`9000`, `9001`, `9002`, ...). * The number of server instances per node is determined by `8 // INFER_TP` (where `INFER_TP` is another environment variable, defaulting to 1 if not set, implying 8 servers per node). You should set `INFER_TP` according to your inference server's tensor parallelism configuration if applicable. * The URL format is `http://{node_hostname}:{port}/v1`. -5. It uses the *first* configuration object you pass in the `configs` list as a template (for settings like `timeout`, `num_max_requests_at_once`, etc.) and creates specific `OpenaiConfig` objects for each discovered URL. +5. It uses the *first* configuration object you pass in the `configs` list as a template (for settings like `timeout`, `num_max_requests_at_once`, etc.) and creates specific `APIServerConfig` objects for each discovered URL. 6. The `ServerManager` then load-balances requests across these automatically configured `OpenAIServer` instances. **Setup Steps:** @@ -32,18 +32,18 @@ When you initialize `ServerManager` with `slurm=True`: * `export INFER_TP=` (Optional, defaults to 1. Set this if your inference servers use tensor parallelism and you run fewer than 8 instances per node). 3. **Initialize `ServerManager`:** In your Python script: ```python - from atroposlib.envs.server_handling.server_manager import ServerManager, ServerBaseline, OpenaiConfig + from atroposlib.envs.server_handling.server_manager import ServerManager, ServerBaseline, APIServerConfig # Provide at least one config object. It will be used as a template # for Slurm-discovered servers if slurm=True. # If you pass ServerBaseline, ensure NUM_TRAINING_NODES and potentially INFER_TP are set. - # If you pass a list of OpenaiConfig, the first one is used as the template. + # If you pass a list of APIServerConfig, the first one is used as the template. base_config = ServerBaseline( timeout=1200, # other baseline settings... ) # OR - # base_config = OpenaiConfig( + # base_config = APIServerConfig( # base_url="http://dummy", # This URL is ignored when slurm=True finds nodes # api_key="dummy", # timeout=1200, @@ -51,7 +51,7 @@ When you initialize `ServerManager` with `slurm=True`: # ) server_manager = ServerManager( - configs=base_config, # Or [base_config] if using OpenaiConfig + configs=base_config, # Or [base_config] if using APIServerConfig slurm=True ) @@ -128,7 +128,7 @@ wait # Wait for background server processes launched with '&' * This setup relies on the `scontrol` command being available in the environment where `ServerManager` is initialized. 
* Ensure network connectivity and firewall rules allow the training node(s) to reach the inference nodes on ports 9000+.
-* The logic assumes a specific port assignment (9000+) and server count based on `INFER_TP`. If your inference server setup differs (e.g., different ports, different discovery mechanism), you would need to modify `server_manager.py` or manually provide the correct list of `OpenaiConfig` objects instead of relying on `slurm=True`.
+* The logic assumes a specific port assignment (9000+) and server count based on `INFER_TP`. If your inference server setup differs (e.g., different ports, different discovery mechanism), you would need to modify `server_manager.py` or manually provide the correct list of `APIServerConfig` objects instead of relying on `slurm=True`.

 ## Monitoring Inference Nodes with Weights & Biases
diff --git a/atroposlib/cli/run_api.py b/atroposlib/cli/run_api.py
index 5a04aa07..0e5e0a40 100644
--- a/atroposlib/cli/run_api.py
+++ b/atroposlib/cli/run_api.py
@@ -1,8 +1,28 @@
+"""
+Run the Trajectory API server.
+"""
+
+import argparse
+
 import uvicorn


 def main():
-    uvicorn.run("atroposlib.api:app", host="0.0.0.0", port=8000, reload=True)
+    """
+    Parse command-line flags and run the API server.
+
+    Flags:
+        --host: Host to bind the API server to (default: 0.0.0.0).
+        --port: Port to run the API server on (default: 8000).
+        --reload: Reload the API server on code changes.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--reload", action="store_true")
+    args = parser.parse_args()
+    uvicorn.run(
+        "atroposlib.api:app", host=args.host, port=args.port, reload=args.reload
+    )


 if __name__ == "__main__":
diff --git a/atroposlib/envs/README.md b/atroposlib/envs/README.md
index b7c609fd..cd9aa20b 100644
--- a/atroposlib/envs/README.md
+++ b/atroposlib/envs/README.md
@@ -58,10 +58,38 @@ These methods have default implementations or are optional based on your needs:

 * **`save_checkpoint(self, step, data=None)`**: The base class calls this method automatically at checkpoint intervals determined by the server. It saves the provided `data` dictionary (which you might populate with environment-specific state) to a JSON file. You can override this to customize *what* data is saved or *how* it's saved (e.g., using a different format or location), but the triggering mechanism remains automatic.

-* **`@classmethod config_init(cls) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[OpenaiConfig]]]`**: This class method is used by the default `get_cli_serve_config_cls` implementation to get the initial environment configuration (`BaseEnvConfig` subclass) and server configurations (`ServerBaseline` or `List[OpenaiConfig]`) when setting up the `serve` command. The default implementation returns `cls.env_config_cls(), ServerBaseline()`. You might override this if your environment requires different default configurations or specific server setups (like multiple `OpenaiConfig` instances) when run via the CLI `serve` command.
+* **`@classmethod config_init(cls) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[APIServerConfig]]]`**: This class method is used by the default `get_cli_serve_config_cls` implementation to get the initial environment configuration (`BaseEnvConfig` subclass) and server configurations (`ServerBaseline` or `List[APIServerConfig]`) when setting up the `serve` command. The default implementation returns `cls.env_config_cls(), ServerBaseline()`.
You might override this if your environment requires different default configurations or specific server setups (like multiple `APIServerConfig` instances) when run via the CLI `serve` command.

 * **`async def cleanup(self)`**: Called after each call to `handle_env`. You can implement this for any cleanup needed after processing a single item, though it's often not required.

+## Overrideable Class Variables
+
+These class-level variables in `BaseEnv` can be overridden in your subclass to customize its behavior:
+
+* **`name: Optional[str]`**:
+    * Default: `None`
+    * Purpose: You can set a string name for your environment. This name is used by default for `wandb_name` in the `BaseEnvConfig` if not otherwise specified, influencing how runs are grouped or named in Weights & Biases. It can also be useful for general identification or logging purposes.
+
+* **`env_config_cls: Type[BaseEnvConfig]`**:
+    * Default: `BaseEnvConfig`
+    * Purpose: This variable holds the Pydantic model class that will be used for your environment's configuration. If your environment requires custom configuration fields beyond what `BaseEnvConfig` offers, you should create a new class that inherits from `BaseEnvConfig` (or a subclass thereof) and assign it to `env_config_cls`. This allows the CLI and other parts of the system to correctly parse and manage your environment's specific settings.
+    ```python
+    from pydantic import Field
+    from atroposlib.envs import BaseEnv, BaseEnvConfig
+
+    class MyEnvConfig(BaseEnvConfig):
+        my_custom_param: str = Field(default="default_value", description="A custom parameter for MyEnv")
+
+    class MyEnv(BaseEnv):
+        env_config_cls = MyEnvConfig
+        name = "MyCustomEnvironment"
+        # ... other implementations
+    ```
+
+* **`server_cls: Type[APIServer]`**:
+    * Default: `APIServer`
+    * Purpose: Specifies the class used to manage interactions with API servers (e.g., inference endpoints). This is mostly intended for developing additional API interfaces, but if you need a nonstandard way of connecting to an existing API, you can use it to slot in whatever modifications you need, as sketched below.
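+    A minimal sketch of the override (the class name and endpoint logic here are illustrative, not part of the library; the abstract methods are those defined on `APIServer`):
+    ```python
+    from atroposlib.envs import BaseEnv
+    from atroposlib.envs.server_handling.server_baseline import APIServer
+
+    class MyCustomServer(APIServer):
+        async def check_server_status_task(self):
+            self.server_healthy = True  # e.g. poll your API's health endpoint instead
+
+        async def _chat_completion_wrapper(self, **kwargs):
+            ...  # call your nonstandard API and return a ChatCompletion
+
+        async def _completion_wrapper(self, **kwargs):
+            ...  # call your nonstandard API and return a Completion
+
+    class MyEnv(BaseEnv):
+        server_cls = MyCustomServer
+    ```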
+ ## Provided Functionality `BaseEnv` provides several helpful features: diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 719eb59a..1a6e4427 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -40,7 +40,8 @@ from atroposlib.utils.metrics import get_std_min_max_avg from ..type_definitions import Item, Message from .server_handling.server_manager import ( - OpenaiConfig, + APIServer, + APIServerConfig, ServerBaseline, ServerManager, ServerManagerConfig, @@ -163,13 +164,14 @@ class BaseEnvConfig(BaseModel): class BaseEnv(ABC): - name = None - env_config_cls = BaseEnvConfig + name: Optional[str] = None + env_config_cls: BaseEnvConfig = BaseEnvConfig + server_cls: APIServer = APIServer def __init__( self, config: BaseEnvConfig, - server_configs: Union[ServerBaseline, List[OpenaiConfig]], + server_configs: Union[ServerBaseline, List[APIServerConfig]], slurm=False, testing=False, ): @@ -184,7 +186,9 @@ class BaseEnv(ABC): self.last_loop_time = None self.last_completed_item = None self.config = config - self.server = ServerManager(server_configs, slurm=slurm, testing=testing) + self.server = ServerManager( + server_configs, slurm=slurm, testing=testing, server_class=self.server_cls + ) self.workers = set() self.eval_workers = set() self.backlog = [] @@ -234,7 +238,7 @@ class BaseEnv(ABC): @classmethod def config_init( cls, - ) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[OpenaiConfig]]]: + ) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[APIServerConfig]]]: """ Initialize the config """ @@ -1020,7 +1024,6 @@ class BaseEnv(ABC): Returns: type: The CliServeConfig class for serving commands. """ - # Get the default configurations defined by the specific environment class default_env_config, default_server_configs = cls.config_init() @@ -1032,8 +1035,8 @@ class BaseEnv(ABC): class CliServeConfig( get_prefixed_pydantic_model(type(default_env_config), env_full_prefix), get_prefixed_pydantic_model( - OpenaiConfig, openai_full_prefix - ), # Use OpenaiConfig for CLI args + APIServerConfig, openai_full_prefix + ), # Use APIServerConfig for CLI args ServerManagerConfig, # ServerManager args are not namespaced by default Cmd, ): @@ -1089,7 +1092,7 @@ class BaseEnv(ABC): oai_cli_passed_args or yaml_oai_config ): raise ValueError( - "ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use OpenaiConfig." # noqa: E501 + "ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig." 
# noqa: E501 ) if ( isinstance(default_server_configs, list) @@ -1101,11 +1104,11 @@ class BaseEnv(ABC): default_openai_config_ = default_server_configs if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1: yaml_oai_config = yaml_oai_config[0] - if isinstance(default_openai_config_, OpenaiConfig) and isinstance( + if isinstance(default_openai_config_, APIServerConfig) and isinstance( yaml_oai_config, dict ): openai_config_dict = merge_dicts( - default_openai_config_.model_dump(), # Default OpenaiConfig (or from class init) + default_openai_config_.model_dump(), # Default APIServerConfig (or from class init) yaml_oai_config, oai_cli_passed_args, ) @@ -1189,7 +1192,7 @@ class BaseEnv(ABC): data_path_to_save_groups=f"data/{cls.name or 'groups'}.jsonl", use_wandb=True, ) - PROCESS_MODE_OPENAI_DEFAULT_CONFIG = OpenaiConfig( + PROCESS_MODE_OPENAI_DEFAULT_CONFIG = APIServerConfig( model_name="gpt-4.1-nano", base_url=None, api_key=None, @@ -1200,10 +1203,7 @@ class BaseEnv(ABC): ) # Get the base default configurations from the specific environment class - ( - default_env_config, - default_server_configs, - ) = cls.config_init() + default_env_config, default_server_configs = cls.config_init() # Define namespace prefixes env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}" @@ -1215,8 +1215,7 @@ class BaseEnv(ABC): type(default_env_config), PROCESS_MODE_ENV_DEFAULT_CONFIG ) openai_config_cls_new_defaults = adjust_model_defaults( - OpenaiConfig, - PROCESS_MODE_OPENAI_DEFAULT_CONFIG, + APIServerConfig, PROCESS_MODE_OPENAI_DEFAULT_CONFIG ) server_manager_config_cls_new_defaults = adjust_model_defaults( ServerManagerConfig, @@ -1283,7 +1282,7 @@ class BaseEnv(ABC): oai_cli_passed_args or yaml_oai_config ): raise ValueError( - "ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use OpenaiConfig." # noqa: E501 + "ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig." 
# noqa: E501 ) if ( @@ -1296,11 +1295,11 @@ class BaseEnv(ABC): default_openai_config_ = default_server_configs if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1: yaml_oai_config = yaml_oai_config[0] - if isinstance(default_openai_config_, OpenaiConfig) and isinstance( + if isinstance(default_openai_config_, APIServerConfig) and isinstance( yaml_oai_config, dict ): openai_config_dict = merge_dicts( - default_openai_config_.model_dump(), # Default OpenaiConfig (or from class init) + default_openai_config_.model_dump(), # Default APIServerConfig (or from class init) PROCESS_MODE_OPENAI_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults yaml_oai_config, oai_cli_passed_args, diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index f161869e..ff353569 100644 --- a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -1,142 +1,28 @@ import asyncio -import collections -import time -from asyncio import exceptions -from typing import Optional +import warnings import aiohttp -import numpy as np import openai from openai.types.chat.chat_completion import ChatCompletion from openai.types.completion import Completion -from pydantic import BaseModel, Field from pydantic_cli import FailedExecutionException -from tenacity import retry, stop_after_attempt, wait_random_exponential from atroposlib.envs.constants import NAMESPACE_SEP, OPENAI_NAMESPACE +from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig -class OpenaiConfig(BaseModel): +class OpenAIServer(APIServer): """ - Configuration for the server manager. + OpenAI server handling. """ - api_key: Optional[str] = Field( - default=None, description="API key for OpenAI API. Use 'x' for local servers." - ) - base_url: Optional[str] = Field( - default=None, - description="URL of the API endpoint. None if using official OpenAI API, otherwise local server URL.", - ) - timeout: int = Field( - default=1200, description="Timeout for the request in seconds." - ) - num_max_requests_at_once: int = Field( - default=512, - description="Maximum number of concurrent requests. Note: You should divide this by the n kwarg.", - ) - num_requests_for_eval: int = Field( - default=64, description="Maximum number of concurrent requests for evaluation." - ) - model_name: str = Field( - default="default", - description="The model name to use. Required for both OpenAI and local models.", - ) - rolling_buffer_length: int = Field( - default=1000, description="Length of the rolling buffer to store metrics." - ) - - -class AsyncSemWithAdaptiveWeight(asyncio.Semaphore): - def __init__(self, value: int): - super().__init__(value=value) - self.max_val = value - self.weight = 1.0 - - def update_weight(self, weight: float) -> None: - self.weight = weight - - def min_val(self): - return self.max_val * (1.0 - self.weight) - - def release(self): - """Release a semaphore, incrementing the internal counter by one. - - When it was zero on entry and another coroutine is waiting for it to - become larger than zero again, wake up that coroutine. 
- - If weight is set, it'll only wake up next if the value is greater than the max_val * weight - """ - self._value += 1 - if self._value > self.min_val(): - self._wake_up_next() - - def locked(self): - """Returns True if semaphore cannot be acquired immediately.""" - return self._value <= self.min_val() or ( - any(not w.cancelled() for w in (self._waiters or ())) - ) - - async def acquire(self): - """Acquire a semaphore. - - If the internal counter is larger than zero on entry, - decrement it by one and return True immediately. If it is - zero on entry, block, waiting until some other coroutine has - called release() to make it larger than 0, and then return - True. - """ - if not self.locked(): - self._value -= 1 - return True - - if self._waiters is None: - self._waiters = collections.deque() - fut = self._get_loop().create_future() - self._waiters.append(fut) - - # Finally block should be called before the CancelledError - # handling as we don't want CancelledError to call - # _wake_up_first() and attempt to wake up itself. - try: - try: - await fut - finally: - self._waiters.remove(fut) - except exceptions.CancelledError: - if not fut.cancelled(): - self._value += 1 - self._wake_up_next() - raise - - if self._value > self.min_val(): - self._wake_up_next() - return True - - -class OpenAIServer: - def __init__(self, config: OpenaiConfig): - self.config = config + def __init__(self, config: APIServerConfig): self.openai = openai.AsyncClient( api_key=config.api_key, base_url=config.base_url, timeout=config.timeout, ) - self.sem = AsyncSemWithAdaptiveWeight(config.num_max_requests_at_once) - self.eval_sem = AsyncSemWithAdaptiveWeight(config.num_requests_for_eval) - self.server_healthy = True - self.attempts_list = [] - self.request_timings = [] - # in case eval is much different, we should keep different buffers - self.eval_attempts_list = [] - self.eval_request_timings = [] - self.check_task = None - self.initialized = False - - async def update_weight(self, weight: float) -> None: - # need to update sems - self.sem.update_weight(weight) - self.eval_sem.update_weight(weight) + super().__init__(config) async def check_server_status_task(self): while True: @@ -156,147 +42,90 @@ class OpenAIServer: self.server_healthy = False await asyncio.sleep(1) - async def wandb_metrics( - self, metrics_dict: Optional[dict], server_name: Optional[str] - ): - if server_name is None: - server_name = "server" - if len(self.request_timings) > 0: - metrics_dict[f"server/{server_name}_request_time_avg"] = np.mean( - self.request_timings + async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion: + """ + Wrapper for the chat completion using the openai client. + """ + assert ( + kwargs.get("model", None) is not None + ), "Model is required for chat completion!" + assert ( + kwargs.get("messages", None) is not None + ), "Messages are required for chat completion!" 
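+        # Some OpenAI-compatible endpoints silently ignore the `n` kwarg and
+        # return a single choice. When the flag is set (or auto-detected in the
+        # else branch below), emulate `n` by issuing n parallel single-choice
+        # requests and merging their choices into one ChatCompletion.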
+ if self.config.n_kwarg_is_ignored: + n = kwargs.pop("n", 1) + completion_list = await asyncio.gather( + *[self.openai.chat.completions.create(**kwargs) for _ in range(n)] ) - metrics_dict[f"server/{server_name}_request_time_std"] = np.std( - self.request_timings - ) - metrics_dict[f"server/{server_name}_request_time_99p"] = np.percentile( - self.request_timings, 99 - ) - if len(self.eval_request_timings) > 0: - metrics_dict[f"server/{server_name}_eval_request_time_avg"] = np.mean( - self.eval_request_timings - ) - metrics_dict[f"server/{server_name}_eval_request_time_std"] = np.std( - self.eval_request_timings - ) - metrics_dict[f"server/{server_name}_eval_request_time_99p"] = np.percentile( - self.eval_request_timings, 99 - ) - if len(self.attempts_list) > 0: - metrics_dict[f"server/{server_name}_average_num_attempts"] = np.mean( - self.attempts_list - ) - if len(self.eval_attempts_list) > 0: - metrics_dict[f"server/{server_name}_eval_retry_rate"] = np.mean( - self.eval_attempts_list - ) - return metrics_dict - - @retry( - stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) - ) - async def _chat_comp(self, stat_dict, **kwargs) -> ChatCompletion: - while not self.server_healthy: - await asyncio.sleep(1) - async with self.sem: - if stat_dict.get("start", None) is None: - stat_dict["start"] = time.time() - stat_dict["attempts"] += 1 - completions = await self.openai.chat.completions.create(**kwargs) - stat_dict["end"] = time.time() - return completions - - @retry( - stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) - ) - async def _chat_eval(self, stat_dict, **kwargs) -> ChatCompletion: - while not self.server_healthy: - await asyncio.sleep(1) - async with self.eval_sem: - if stat_dict.get("start", None) is None: - stat_dict["start"] = time.time() - stat_dict["attempts"] += 1 - completions = await self.openai.chat.completions.create(**kwargs) - stat_dict["end"] = time.time() - return completions - - @retry( - stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) - ) - async def chat_completion(self, **kwargs) -> ChatCompletion: - if not self.initialized: - if ( - self.config.base_url is not None - ): # skip health check if using OpenAI API - self.check_task = asyncio.create_task(self.check_server_status_task()) + completions = completion_list[0] + if n > 1: + for c in completion_list[1:]: + completions.choices.extend(c.choices) else: - self.server_healthy = True - self.initialized = True - kwargs["model"] = self.config.model_name - split = kwargs.pop("split", "train") - stat_dict = {} - stat_dict["attempts"] = 0 - if split == "train": - ret_data = await self._chat_comp(stat_dict, **kwargs) - self.request_timings.append(stat_dict["end"] - stat_dict["start"]) - self.attempts_list.append(stat_dict["attempts"]) + completions = await self.openai.chat.completions.create(**kwargs) else: - # Give separate eval workers, if desired, gotta go fast for those evals - ret_data = await self._chat_eval(stat_dict, **kwargs) - self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"]) - self.eval_attempts_list.append(stat_dict["attempts"]) - return ret_data - - @retry( - stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) - ) - async def _comp(self, stat_dict, **kwargs) -> Completion: - while not self.server_healthy: - await asyncio.sleep(1) - async with self.sem: - if stat_dict.get("start", None) is None: - stat_dict["start"] = time.time() - stat_dict["attempts"] += 1 - completions = 
await self.openai.completions.create(**kwargs) - stat_dict["end"] = time.time() - return completions - - @retry( - stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) - ) - async def _comp_eval(self, stat_dict, **kwargs) -> Completion: - while not self.server_healthy: - await asyncio.sleep(1) - async with self.eval_sem: - if stat_dict.get("start", None) is None: - stat_dict["start"] = time.time() - stat_dict["attempts"] += 1 - completions = await self.openai.completions.create(**kwargs) - stat_dict["end"] = time.time() - return completions - - async def completion(self, **kwargs) -> Completion: - if not self.initialized: - if ( - self.config.base_url is not None - ): # skip health check if using OpenAI API - self.check_task = asyncio.create_task(self.check_server_status_task()) + if "n" in kwargs: + n = kwargs["n"] else: - self.server_healthy = True - self.initialized = True - kwargs["model"] = self.config.model_name - split = kwargs.pop("split", "train") - stat_dict = {} - stat_dict["attempts"] = 0 - if split == "train": - ret_data = await self._comp(stat_dict, **kwargs) - self.request_timings.append(stat_dict["end"] - stat_dict["start"]) - self.attempts_list.append(stat_dict["attempts"]) + n = 1 + completions = await self.openai.chat.completions.create(**kwargs) + if len(completions.choices) != n: + if len(completions.choices) != 1: + raise ValueError( + f"Expected 1 or {n} completions, got {len(completions.choices)}!" + ) + else: + warnings.warn("n kwarg is ignored by the API, setting to True") + self.config.n_kwarg_is_ignored = True + completion_list = await asyncio.gather( + *[ + self.openai.chat.completions.create(**kwargs) + for _ in range(1, n) + ] + ) + for c in completion_list: + completions.choices.extend(c.choices) + return completions + + async def _completion_wrapper(self, **kwargs) -> Completion: + """ + Wrapper for the completion using the openai client. + """ + assert ( + kwargs.get("model", None) is not None + ), "Model is required for completion!" + assert ( + kwargs.get("prompt", None) is not None + ), "Prompt is required for completion!" + if self.config.n_kwarg_is_ignored: + n = kwargs.pop("n", 1) + completion_list = await asyncio.gather( + *[self.openai.completions.create(**kwargs) for _ in range(n)] + ) + completions = completion_list[0] + if n > 1: + for c in completion_list[1:]: + completions.choices.extend(c.choices) else: - # Give separate eval workers, if desired, gotta go fast for those evals - ret_data = await self._comp_eval(stat_dict, **kwargs) - self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"]) - self.eval_attempts_list.append(stat_dict["attempts"]) - return ret_data + if "n" in kwargs: + n = kwargs["n"] + else: + n = 1 + completions = await self.openai.completions.create(**kwargs) + if len(completions.choices) != n: + if len(completions.choices) != 1: + raise ValueError( + f"Expected 1 or {n} completions, got {len(completions.choices)}!" + ) + else: + warnings.warn("n kwarg is ignored by the API, setting to True") + self.config.n_kwarg_is_ignored = True + completion_list = await asyncio.gather( + *[self.openai.completions.create(**kwargs) for _ in range(1, n)] + ) + for c in completion_list: + completions.choices.extend(c.choices) + return completions def resolve_openai_configs( @@ -338,7 +167,7 @@ def resolve_openai_configs( f"Using multi-server configuration defined in YAML under '{OPENAI_NAMESPACE}'." 
) try: - server_configs = [OpenaiConfig(**cfg) for cfg in openai_yaml_config] + server_configs = [APIServerConfig(**cfg) for cfg in openai_yaml_config] except Exception as e: raise FailedExecutionException( f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}" @@ -354,14 +183,14 @@ def resolve_openai_configs( "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)." ) try: - final_openai_config = OpenaiConfig(**openai_config_dict) + final_openai_config = APIServerConfig(**openai_config_dict) except Exception as e: raise FailedExecutionException( f"Error creating final OpenAI configuration from merged settings: {e}\n" f"Merged Dict: {openai_config_dict}" ) from e - if isinstance(default_server_configs, OpenaiConfig): + if isinstance(default_server_configs, APIServerConfig): server_configs = final_openai_config elif isinstance(default_server_configs, list): server_configs = [final_openai_config] diff --git a/atroposlib/envs/server_handling/server_baseline.py b/atroposlib/envs/server_handling/server_baseline.py new file mode 100644 index 00000000..8a24a77b --- /dev/null +++ b/atroposlib/envs/server_handling/server_baseline.py @@ -0,0 +1,340 @@ +import asyncio +import collections +import time +from abc import ABC, abstractmethod +from asyncio import exceptions +from typing import Literal, Optional + +import numpy as np +from openai.types.chat.chat_completion import ChatCompletion +from openai.types.completion import Completion +from pydantic import BaseModel, Field +from tenacity import retry, stop_after_attempt, wait_random_exponential + + +class AsyncSemWithAdaptiveWeight(asyncio.Semaphore): + def __init__(self, value: int): + super().__init__(value=value) + self.max_val = value + self.weight = 1.0 + + def update_weight(self, weight: float) -> None: + """ + Update the weight of the semaphore. + """ + self.weight = weight + + def min_val(self): + """ + Returns the minimum value of the semaphore. + """ + return self.max_val * (1.0 - self.weight) + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + + If weight is set, it'll only wake up next if the value is greater than the max_val * weight + """ + self._value += 1 + if self._value > self.min_val(): + self._wake_up_next() + + def locked(self): + """Returns True if semaphore cannot be acquired immediately.""" + return self._value <= self.min_val() or ( + any(not w.cancelled() for w in (self._waiters or ())) + ) + + async def acquire(self): + """Acquire a semaphore. + + If the internal counter is larger than zero on entry, + decrement it by one and return True immediately. If it is + zero on entry, block, waiting until some other coroutine has + called release() to make it larger than 0, and then return + True. + """ + if not self.locked(): + self._value -= 1 + return True + + if self._waiters is None: + self._waiters = collections.deque() + fut = self._get_loop().create_future() + self._waiters.append(fut) + + # Finally block should be called before the CancelledError + # handling as we don't want CancelledError to call + # _wake_up_first() and attempt to wake up itself. 
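+        # The body below mirrors asyncio.Semaphore.acquire; the only behavioral
+        # change is the weighted wake-up threshold (min_val()) applied here and
+        # in release().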
+        try:
+            try:
+                await fut
+            finally:
+                self._waiters.remove(fut)
+        except exceptions.CancelledError:
+            if not fut.cancelled():
+                self._value += 1
+                self._wake_up_next()
+            raise
+
+        if self._value > self.min_val():
+            self._wake_up_next()
+        return True
+
+
+class ServerBaseline(BaseModel):
+    """
+    Baseline configuration for server information. If local, uses ports 9004-9007 for the servers,
+    assuming a 1:1 split of GPUs.
+    """
+
+    timeout: int = Field(
+        default=1200, description="Timeout for the request in seconds."
+    )
+    num_max_requests_at_once: int = Field(
+        default=512,
+        description="Maximum number of concurrent requests. You should divide this by the n kwarg.",
+    )
+    num_requests_for_eval: int = Field(
+        default=64, description="Maximum number of concurrent requests for evaluation."
+    )
+    model_name: str = Field(
+        default="default",
+        description="The model name to use. ServerBaseline discovery only works with sglang servers, so provide the model name they serve.",
+    )
+    rolling_buffer_length: int = Field(
+        default=1000, description="Length of the rolling buffer to store metrics."
+    )
+    server_type: Literal["openai", "trl"] = Field(
+        default="openai", description="Type of server to use: openai or trl."
+    )
+
+
+class APIServerConfig(ServerBaseline):
+    """
+    API server configuration.
+    """
+
+    api_key: Optional[str] = Field(default="", description="API key for the server.")
+    base_url: Optional[str] = Field(default="", description="Base URL for the server.")
+    n_kwarg_is_ignored: bool = Field(
+        default=False, description="Whether the n kwarg is ignored by this API server."
+    )
+
+
+class APIServer(ABC):
+    """
+    Abstract class for API servers.
+    """
+
+    def __init__(self, config: APIServerConfig):
+        self.config = config
+        self.sem = AsyncSemWithAdaptiveWeight(config.num_max_requests_at_once)
+        self.eval_sem = AsyncSemWithAdaptiveWeight(config.num_requests_for_eval)
+        self.server_healthy = True
+        self.attempts_list = []
+        self.request_timings = []
+        # in case eval is much different, we should keep different buffers
+        self.eval_attempts_list = []
+        self.eval_request_timings = []
+        self.check_task = None
+        self.initialized = False
+
+    async def update_weight(self, weight: float) -> None:
+        """
+        Update the weight of the semaphores.
+        """
+        # need to update sems
+        self.sem.update_weight(weight)
+        self.eval_sem.update_weight(weight)
+
+    @abstractmethod
+    async def check_server_status_task(self):
+        """
+        Check the status of the server. Should be overridden by the child class.
+        Set self.server_healthy to True if the server is healthy.
+        """
+        self.server_healthy = False
+
+    async def wandb_metrics(
+        self, metrics_dict: Optional[dict], server_name: Optional[str]
+    ):
+        """
+        Add metrics to the metrics dictionary.
+
+        If you want to add more metrics, you can do so by overriding this method, but make sure to call
+        super().wandb_metrics(metrics_dict, server_name) first to get the default metrics, if you still want them.
+ """ + if server_name is None: + server_name = "server" + if len(self.request_timings) > 0: + metrics_dict[f"server/{server_name}_request_time_avg"] = np.mean( + self.request_timings + ) + metrics_dict[f"server/{server_name}_request_time_std"] = np.std( + self.request_timings + ) + metrics_dict[f"server/{server_name}_request_time_99p"] = np.percentile( + self.request_timings, 99 + ) + if len(self.eval_request_timings) > 0: + metrics_dict[f"server/{server_name}_eval_request_time_avg"] = np.mean( + self.eval_request_timings + ) + metrics_dict[f"server/{server_name}_eval_request_time_std"] = np.std( + self.eval_request_timings + ) + metrics_dict[f"server/{server_name}_eval_request_time_99p"] = np.percentile( + self.eval_request_timings, 99 + ) + if len(self.attempts_list) > 0: + metrics_dict[f"server/{server_name}_average_num_attempts"] = np.mean( + self.attempts_list + ) + if len(self.eval_attempts_list) > 0: + metrics_dict[f"server/{server_name}_eval_retry_rate"] = np.mean( + self.eval_attempts_list + ) + return metrics_dict + + @abstractmethod + async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion: + """ + Wrapper for the chat completion. Should be overridden by the child class and return a ChatCompletion object. + """ + pass + + @abstractmethod + async def _completion_wrapper(self, **kwargs) -> Completion: + """ + Wrapper for the completion. Should be overridden by the child class and return a Completion object. + """ + pass + + @retry( + stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) + ) + async def _chat_comp(self, stat_dict, **kwargs) -> ChatCompletion: + """ + Simple retry and stat collection wrapper for the chat completion. + """ + while not self.server_healthy: + await asyncio.sleep(1) + async with self.sem: + if stat_dict.get("start", None) is None: + stat_dict["start"] = time.time() + stat_dict["attempts"] += 1 + completions = await self._chat_completion_wrapper(**kwargs) + stat_dict["end"] = time.time() + return completions + + @retry( + stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) + ) + async def _chat_eval(self, stat_dict, **kwargs) -> ChatCompletion: + """ + Simple retry and stat collection wrapper for the chat completion. + """ + while not self.server_healthy: + await asyncio.sleep(1) + async with self.eval_sem: + if stat_dict.get("start", None) is None: + stat_dict["start"] = time.time() + stat_dict["attempts"] += 1 + completions = await self._chat_completion_wrapper(**kwargs) + stat_dict["end"] = time.time() + return completions + + @retry( + stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) + ) + async def chat_completion(self, **kwargs) -> ChatCompletion: + """ + Chat completion handler, waits for the server to be healthy and then calls the chat completion wrapper. 
+ """ + if not self.initialized: + if ( + self.config.base_url is not None + ): # skip health check if using OpenAI API + self.check_task = asyncio.create_task(self.check_server_status_task()) + else: + self.server_healthy = True + self.initialized = True + kwargs["model"] = self.config.model_name + split = kwargs.pop("split", "train") + stat_dict = {} + stat_dict["attempts"] = 0 + if split == "train": + ret_data = await self._chat_comp(stat_dict, **kwargs) + self.request_timings.append(stat_dict["end"] - stat_dict["start"]) + self.attempts_list.append(stat_dict["attempts"]) + else: + # Give separate eval workers, if desired, gotta go fast for those evals + ret_data = await self._chat_eval(stat_dict, **kwargs) + self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"]) + self.eval_attempts_list.append(stat_dict["attempts"]) + return ret_data + + @retry( + stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) + ) + async def _comp(self, stat_dict, **kwargs) -> Completion: + """ + Simple retry and stat collection wrapper for the completion. + """ + while not self.server_healthy: + await asyncio.sleep(1) + async with self.sem: + if stat_dict.get("start", None) is None: + stat_dict["start"] = time.time() + stat_dict["attempts"] += 1 + completions = await self._completion_wrapper(**kwargs) + stat_dict["end"] = time.time() + return completions + + @retry( + stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10) + ) + async def _comp_eval(self, stat_dict, **kwargs) -> Completion: + """ + Simple retry and stat collection wrapper for the completion. + """ + while not self.server_healthy: + await asyncio.sleep(1) + async with self.eval_sem: + if stat_dict.get("start", None) is None: + stat_dict["start"] = time.time() + stat_dict["attempts"] += 1 + completions = await self._completion_wrapper(**kwargs) + stat_dict["end"] = time.time() + return completions + + async def completion(self, **kwargs) -> Completion: + """ + Completion handler, waits for the server to be healthy and then calls the completion wrapper. 
+ """ + if not self.initialized: + if ( + self.config.base_url is not None + ): # skip health check if using OpenAI API + self.check_task = asyncio.create_task(self.check_server_status_task()) + else: + self.server_healthy = True + self.initialized = True + kwargs["model"] = self.config.model_name + split = kwargs.pop("split", "train") + stat_dict = {} + stat_dict["attempts"] = 0 + if split == "train": + ret_data = await self._comp(stat_dict, **kwargs) + self.request_timings.append(stat_dict["end"] - stat_dict["start"]) + self.attempts_list.append(stat_dict["attempts"]) + else: + # Give separate eval workers, if desired, gotta go fast for those evals + ret_data = await self._comp_eval(stat_dict, **kwargs) + self.eval_request_timings.append(stat_dict["end"] - stat_dict["start"]) + self.eval_attempts_list.append(stat_dict["attempts"]) + return ret_data diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py index 9c41f493..ff402882 100644 --- a/atroposlib/envs/server_handling/server_manager.py +++ b/atroposlib/envs/server_handling/server_manager.py @@ -1,4 +1,5 @@ import asyncio +import inspect import os from contextlib import asynccontextmanager from typing import AsyncGenerator, List, Union @@ -7,56 +8,56 @@ from openai.types.chat.chat_completion import ChatCompletion from openai.types.completion import Completion from pydantic import BaseModel, Field -from atroposlib.envs.server_handling.openai_server import OpenaiConfig, OpenAIServer +from atroposlib.envs.server_handling.openai_server import OpenAIServer +from atroposlib.envs.server_handling.server_baseline import ( + APIServer, + APIServerConfig, + ServerBaseline, +) from atroposlib.envs.server_handling.server_harness import ServerHarness +from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer class ServerManagerConfig(BaseModel): slurm: bool = Field( - default=True, description="Whether environment is running on slurm or not." + default=False, description="Whether environment is running on slurm or not." ) testing: bool = Field( default=False, description="If set to True, environment uses mock OpenAI data." ) -class ServerBaseline(BaseModel): - """ - Baseline configuration for server information. If local, uses ports 9004-9007 for the servers, - assuming a 1:1 split of GPUs. - """ - - timeout: int = Field( - default=1200, description="Timeout for the request in seconds." - ) - num_max_requests_at_once: int = Field( - default=512, - description="Maximum number of concurrent requests. You should divide this by the n kwarg.", - ) - num_requests_for_eval: int = Field( - default=64, description="Maximum number of concurrent requests for evaluation." - ) - model_name: str = Field( - default="default", - description="The model name to use. Only works with sglang, please provide the model name.", - ) - rolling_buffer_length: int = Field( - default=1000, description="Length of the rolling buffer to store metrics." - ) - - class ServerManager: def __init__( self, - configs: Union[ServerBaseline, List[OpenaiConfig]], + configs: Union[ServerBaseline, List[APIServerConfig]], + server_class: APIServer = APIServer, slurm=False, testing=False, ): + # First we check to see if it's the base server class, and if so, we need to select the appropriate server class + # You can't use type() to check if it's the base server class, because it's an abstract class, it'll appear as + # an ABCMeta, not what you're expecting. 
+ if inspect.isabstract(server_class): + if not isinstance(configs, list): + if configs.server_type == "openai": + server_class = OpenAIServer + elif configs.server_type == "trl": + server_class = TrlVllmServer + else: + raise ValueError(f"Invalid server type: {configs.server_type}") + else: + if configs[0].server_type == "openai": + server_class = OpenAIServer + elif configs[0].server_type == "trl": + server_class = TrlVllmServer + else: + raise ValueError(f"Invalid server type: {configs[0].server_type}") if testing: # testing :) self.servers = [ServerHarness()] return - if isinstance(configs, ServerBaseline): + if not isinstance(configs, list): urls = [] if os.environ.get("SLURM_JOB_NODELIST", None) is not None: nodelist = ( @@ -84,7 +85,7 @@ class ServerManager: openai_configs = [] for url in urls: openai_configs.append( - OpenaiConfig( + APIServerConfig( base_url=url, timeout=configs.timeout, num_max_requests_at_once=configs.num_max_requests_at_once, @@ -94,9 +95,9 @@ class ServerManager: api_key="x", ) ) - self.servers = [OpenAIServer(config) for config in openai_configs] + self.servers = [server_class(config) for config in openai_configs] elif not slurm: - self.servers = [OpenAIServer(config) for config in configs] + self.servers = [server_class(config) for config in configs] else: nodelist = ( os.popen(f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}') @@ -109,7 +110,7 @@ class ServerManager: "Not enough nodes to distribute to, assuming single node" " and you've setup your sglang appropriately." ) - self.servers = [OpenAIServer(config) for config in configs] + self.servers = [server_class(config) for config in configs] return urls = [] num_training_nodes = int(os.environ.get("NUM_TRAINING_NODES")) @@ -124,7 +125,7 @@ class ServerManager: new_conf = configs[0].model_copy(deep=True) new_conf.base_url = urls[i] new_configs.append(new_conf) - self.servers = [OpenAIServer(config) for config in new_configs] + self.servers = [server_class(config) for config in new_configs] async def update_weight(self, weight: float): for server in self.servers: diff --git a/atroposlib/envs/server_handling/trl_vllm_server.py b/atroposlib/envs/server_handling/trl_vllm_server.py new file mode 100644 index 00000000..599d08c3 --- /dev/null +++ b/atroposlib/envs/server_handling/trl_vllm_server.py @@ -0,0 +1,126 @@ +""" +This is a server that interfaces with trl's vLLM server. + +Developed with much help from @winglian when they worked on integrating Atropos into Axolotl. +""" + +import time +import uuid + +import aiohttp +from openai.types.chat.chat_completion import ( + ChatCompletion, + ChatCompletionMessage, + Choice, +) +from transformers import AutoTokenizer + +from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig + + +class TrlVllmServer(APIServer): + """ + A server that interfaces with trl's vLLM server. + """ + + def __init__(self, config: APIServerConfig): + self.config = config + self.tokenizer = AutoTokenizer.from_pretrained(config.model_name) + super().__init__(config) + + async def check_server_status_task(self): + """ + TODO: Implement server health check for trl's vLLM server + """ + self.server_healthy = True + + async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion: + """ + Wrapper for the chat completion using the trl's vLLM server. 
+ """ + url = f"{self.config.base_url}/generate/" + prompt = kwargs.get("messages", []) + prompt = self.tokenizer.apply_chat_template( + prompt, tokenize=False, add_generation_prompt=True + ) + async with aiohttp.ClientSession() as session: + async with session.post( + url, + json={ + "prompts": [prompt], + "n": kwargs.get("n", 1), + "repetition_penalty": kwargs.get("repetition_penalty", 1.0), + "temperature": kwargs.get("temperature", 1.0), + "top_p": kwargs.get("top_p", 1.0), + "top_k": kwargs.get("top_k", -1), + "min_p": kwargs.get("min_p", 0.0), + "max_tokens": kwargs.get("max_tokens", 1024), + }, + ) as response: + completions = await response.json() + completions = ChatCompletion( + id=str(uuid.uuid4()), + object="chat.completion", + created=int(time.time()), + model=self.config.model_name, + choices=[ + Choice( + finish_reason=( + "stop" + if self.tokenizer.eos_token_id in completion + else "length" + ), + index=i, + message=ChatCompletionMessage( + content=self.tokenizer.decode(completion), + role="assistant", + ), + ) + for i, completion in enumerate(completions["completion_ids"]) + ], + ) + return completions + + async def _completion_wrapper(self, **kwargs) -> ChatCompletion: + """ + Wrapper for the completion using the trl's vLLM server. + """ + url = f"{self.config.base_url}/generate/" + prompt = kwargs.get("prompt", "") + async with aiohttp.ClientSession() as session: + async with session.post( + url, + json={ + "prompts": [prompt], + "n": kwargs.get("n", 1), + "repetition_penalty": kwargs.get("repetition_penalty", 1.0), + "temperature": kwargs.get("temperature", 1.0), + "top_p": kwargs.get("top_p", 1.0), + "top_k": kwargs.get("top_k", -1), + "min_p": kwargs.get("min_p", 0.0), + "max_tokens": kwargs.get("max_tokens", 1024), + }, + ) as response: + completions = await response.json() + completions = ChatCompletion( + id=str(uuid.uuid4()), + object="chat.completion", + created=int(time.time()), + model=self.config.model_name, + choices=[ + Choice( + finish_reason=( + "stop" + if self.tokenizer.eos_token_id in completion + else "length" + ), + index=i, + message=ChatCompletionMessage( + content=self.tokenizer.decode(completion), + role="assistant", + ), + ) + for i, completion in enumerate(completions["completion_ids"]) + ], + ) + return completions diff --git a/atroposlib/tests/conftest.py b/atroposlib/tests/conftest.py new file mode 100644 index 00000000..d122e39d --- /dev/null +++ b/atroposlib/tests/conftest.py @@ -0,0 +1,23 @@ +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--runproviders", action="store_true", default=False, help="run provider tests" + ) + + +def pytest_configure(config): + config.addinivalue_line( + "markers", "providers: mark test as requires providers api keys to run" + ) + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--runproviders"): + # --runproviders given in cli: do not skip slow tests + return + skip_providers = pytest.mark.skip(reason="need --runproviders option to run") + for item in items: + if "providers" in item.keywords: + item.add_marker(skip_providers) diff --git a/atroposlib/tests/test_openai_api_workarounds.py b/atroposlib/tests/test_openai_api_workarounds.py new file mode 100644 index 00000000..8ea91911 --- /dev/null +++ b/atroposlib/tests/test_openai_api_workarounds.py @@ -0,0 +1,110 @@ +import asyncio +import os + +import dotenv +import pytest + +from atroposlib.envs.server_handling.openai_server import APIServerConfig, OpenAIServer + + +@pytest.mark.providers +def 
test_openai_api_n_kwarg_ignore_discovery():
+    dotenv.load_dotenv()
+    openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
+    if not openrouter_api_key:
+        pytest.skip("OPENROUTER_API_KEY not set")
+    config = APIServerConfig(
+        api_key=openrouter_api_key,
+        base_url="https://openrouter.ai/api/v1",
+        model_name="openai/gpt-4.1-nano",
+        timeout=1200,
+        num_max_requests_at_once=512,
+        num_requests_for_eval=64,
+        rolling_buffer_length=1024,
+    )
+    assert not config.n_kwarg_is_ignored, "n_kwarg_is_ignored should default to False"
+    n = 4
+    server = OpenAIServer(
+        config=config,
+    )
+    response = asyncio.run(
+        server.chat_completion(
+            messages=[
+                {"role": "user", "content": "Hello, how are you?"},
+            ],
+            n=n,
+        )
+    )
+    assert server.config.n_kwarg_is_ignored, "n_kwarg_is_ignored should be set after discovery"
+    assert (
+        len(response.choices) == n
+    ), f"Expected {n} responses, got {len(response.choices)}"
+
+
+@pytest.mark.providers
+def test_openai_api_n_kwarg_ignore_use():
+    dotenv.load_dotenv()
+    openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
+    if not openrouter_api_key:
+        pytest.skip("OPENROUTER_API_KEY not set")
+    config = APIServerConfig(
+        api_key=openrouter_api_key,
+        base_url="https://openrouter.ai/api/v1",
+        model_name="openai/gpt-4.1-nano",
+        timeout=1200,
+        num_max_requests_at_once=512,
+        num_requests_for_eval=64,
+        rolling_buffer_length=1024,
+        n_kwarg_is_ignored=True,
+    )
+    server = OpenAIServer(
+        config=config,
+    )
+    n = 4
+    response = asyncio.run(
+        server.chat_completion(
+            messages=[
+                {"role": "user", "content": "Hello, how are you?"},
+            ],
+            n=n,
+        )
+    )
+    assert server.config.n_kwarg_is_ignored, "n_kwarg_is_ignored should remain set"
+    assert (
+        len(response.choices) == n
+    ), f"Expected {n} responses, got {len(response.choices)}"
+
+
+@pytest.mark.providers
+def test_openai_api_n_kwarg_supported():
+    dotenv.load_dotenv()
+    openai_api_key = os.getenv("OPENAI_API_KEY")
+    if not openai_api_key:
+        pytest.skip("OPENAI_API_KEY not set")
+    config = APIServerConfig(
+        model_name="gpt-4.1-nano",
+        timeout=1200,
+        num_max_requests_at_once=512,
+        num_requests_for_eval=64,
+        rolling_buffer_length=1024,
+        n_kwarg_is_ignored=False,
+    )
+    server = OpenAIServer(
+        config=config,
+    )
+    n = 4
+    response = asyncio.run(
+        server.chat_completion(
+            messages=[
+                {"role": "user", "content": "Hello, how are you?"},
+            ],
+            n=n,
+        )
+    )
+    assert (
+        not server.config.n_kwarg_is_ignored
+    ), "n kwarg should be used with supported models"
+    assert (
+        len(response.choices) == n
+    ), f"Expected {n} responses, got {len(response.choices)}"
diff --git a/environments/code_execution_server/Dockerfile b/environments/code_execution_server/Dockerfile
new file mode 100644
index 00000000..06ece2d3
--- /dev/null
+++ b/environments/code_execution_server/Dockerfile
@@ -0,0 +1,16 @@
+FROM python:3
+
+RUN echo "Acquire::http::Pipeline-Depth 0;" > /etc/apt/apt.conf.d/99custom && \
+    echo "Acquire::http::No-Cache true;" >> /etc/apt/apt.conf.d/99custom && \
+    echo "Acquire::BrokenProxy true;" >> /etc/apt/apt.conf.d/99custom
+
+RUN apt-get update && apt-get upgrade -y \
+    && apt-get install -y build-essential
+
+RUN pip install flask
+
+WORKDIR /tmp
+
+COPY server.py /tmp/server.py
+
+CMD ["python", "server.py"]
diff --git a/environments/code_execution_server/coding_server.py b/environments/code_execution_server/coding_server.py
new file mode 100644
index 00000000..5e4774fb
--- /dev/null
+++ b/environments/code_execution_server/coding_server.py
@@ -0,0 +1,196 @@
diff --git a/environments/code_execution_server/Dockerfile b/environments/code_execution_server/Dockerfile
new file mode 100644
index 00000000..06ece2d3
--- /dev/null
+++ b/environments/code_execution_server/Dockerfile
@@ -0,0 +1,16 @@
+FROM python:3
+
+RUN echo "Acquire::http::Pipeline-Depth 0;" > /etc/apt/apt.conf.d/99custom && \
+    echo "Acquire::http::No-Cache true;" >> /etc/apt/apt.conf.d/99custom && \
+    echo "Acquire::BrokenProxy true;" >> /etc/apt/apt.conf.d/99custom
+
+RUN apt-get update && apt-get upgrade -y \
+    && apt-get install -y build-essential
+
+RUN pip install flask
+
+WORKDIR /tmp
+
+COPY server.py /tmp/server.py
+
+CMD ["python", "server.py"]
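Once the image is built and running (coding_server.py below builds it under the tag `code-executor` and publishes port 5002), the endpoint can be smoke-tested from Python. The payload shape matches the `submit_code` helper later in this patch; the doubling program and input value are just examples:

```python
import httpx

# Expects the code execution container to be listening on port 5002.
resp = httpx.post(
    "http://localhost:5002/execute",
    json={"code": "print(int(input()) * 2)", "input": "21", "language": "python"},
    timeout=30,
)
print(resp.json()["output"])  # expected: 42
```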
diff --git a/environments/code_execution_server/coding_server.py b/environments/code_execution_server/coding_server.py
new file mode 100644
index 00000000..5e4774fb
--- /dev/null
+++ b/environments/code_execution_server/coding_server.py
@@ -0,0 +1,196 @@
+import asyncio
+import os
+import random
+from typing import List, Optional, Tuple
+
+import docker
+import httpx
+import regex as re
+from datasets import load_dataset
+
+from atroposlib.envs.base import BaseEnv, ScoredDataGroup
+from atroposlib.type_definitions import GameHistory, Item
+from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
+
+system_prompt = (
+    "You are a deep thinking AI, you may use extremely long chains of thought "
+    "to deeply consider the problem and deliberate with yourself via systematic "
+    "reasoning processes to help come to a correct solution prior to answering. "
+    "You should enclose your thoughts and internal monologue inside <think> "
+    "</think> tags, and then provide your solution or response to the problem.\n\n"
+)
+
+
+async def submit_code(client, code, test_input, language="python"):
+    """POST one program plus its stdin to the execution server; return stdout."""
+    url = "http://localhost:5002/execute"
+    payload = {"code": code, "input": test_input, "language": language}
+    response = await client.post(url, json=payload)
+    return response.json()["output"]
+
+
+async def get_results(code, test_inputs):
+    """Run `code` against every test input concurrently."""
+    async with httpx.AsyncClient() as client:
+        tasks = [submit_code(client, code, test_input) for test_input in test_inputs]
+        return await asyncio.gather(*tasks)
+
+
+def init_docker():
+    client = docker.from_env()
+
+    def build_docker_image():
+        try:
+            print("Building Docker image...")
+            # Build from the directory containing this script.
+            current_dir = os.path.dirname(os.path.abspath(__file__))
+            image, logs = client.images.build(path=current_dir, tag="code-executor")
+            for log in logs:
+                print(log.get("stream", "").strip())
+            print("Docker image built successfully.")
+            return image
+        except docker.errors.BuildError as e:
+            print(f"Error during Docker image build: {e}")
+
+    def run_docker_container():
+        try:
+            print("Running Docker container...")
+            # Detached, so the environment keeps running alongside the executor.
+            container = client.containers.run(
+                "code-executor", ports={"5002/tcp": 5002}, detach=True
+            )
+            print(f"Docker container is running with ID: {container.id}")
+            return container
+        except docker.errors.ContainerError as e:
+            print(f"Error during Docker container run: {e}")
+
+    build_docker_image()
+    return run_docker_container()
+
+
+class CodingEnv(BaseEnv):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    async def collect_trajectories(
+        self, item: Item
+    ) -> Tuple[GameHistory | None, List[Item]]:
+        chat_completions = await self.server.chat_completion(
+            messages=[
+                {
+                    "role": "system",
+                    "content": "You must submit your answer with ```python\n{code}```",
+                },
+                dict(item[0][0]),
+            ],
+            n=self.config.group_size,
+            max_tokens=1024 * 4,
+        )
+        to_score = list()
+        to_backlog = list()
+        for chat_completion in chat_completions.choices:
+            messages = (
+                dict(item[0][0]),
+                {"role": "assistant", "content": chat_completion.message.content},
+            )
+            to_score.append(
+                (
+                    messages,
+                    item[1],
+                )
+            )
+
+        to_postprocess = await self.score(to_score)
+        return to_postprocess, to_backlog
+
+    async def evaluate(self, *args, **kwargs):
+        """
+        Evaluate the environment; called every `steps_per_eval` steps.
+
+        Evaluation is not implemented for this environment.
+        """
+        return
+
+    async def setup(self):
+        """Build and start the code execution container, then load the dataset."""
+        self.container = init_docker()
+        self.train = load_dataset("deepmind/code_contests", split="train")
+        self.iter = 0
+
+    async def get_next_item(self) -> Item:
+        """
+        Get the next problem to be rolled out.
+        """
+        next_item = self.train[self.iter % len(self.train)]
+        self.iter += 1
+        prompt = tuple(
+            [frozenset({"role": "user", "content": next_item["description"]}.items())]
+        )
+        answer = (
+            tuple(next_item["private_tests"]["input"]),
+            tuple(next_item["private_tests"]["output"]),
+            tuple(next_item["generated_tests"]["input"]),
+            tuple(next_item["generated_tests"]["output"]),
+        )
+        return (prompt, answer)
+
+    def extract_python_code_blocks(self, text):
+        # Matches fenced code blocks, with or without a language tag.
+        pattern = r"^```(?:\w+)?\s*\n(.*?)(?=^```)```"
+        return re.findall(pattern, text, re.DOTALL | re.MULTILINE)
+
+    async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
+        scores = ScoredDataGroup()
+        scores["tokens"] = list()
+        scores["masks"] = list()
+        scores["scores"] = list()
+        random.shuffle(rollout_group_data)
+        for item in rollout_group_data:
+            out_dict = tokenize_for_trainer(self.tokenizer, item[0])
+            tokens = out_dict["tokens"]
+            masks = out_dict["masks"]
+            # Reward: the extracted program must reproduce every expected output.
+            code_blocks = self.extract_python_code_blocks(item[0][-1]["content"])
+            if not code_blocks:
+                # No parseable code block in the response: count as a failure.
+                reward = False
+            else:
+                test_inputs = list(item[1][0]) + list(item[1][2])
+                outputs = await get_results(code_blocks[0], test_inputs)
+                expected_outputs = list(item[1][1]) + list(item[1][3])
+                assert len(outputs) == len(expected_outputs)
+                reward = all(
+                    out == expected
+                    for out, expected in zip(outputs, expected_outputs)
+                )
+            # remove obviously bad examples
+            if len([1 for i in masks if i != -100]) < 10:
+                continue
+            scores["tokens"].append(tokens)
+            scores["masks"].append(masks)
+            scores["scores"].append(1.0 if reward else -1.0)
+            if len(scores["tokens"]) >= self.config.group_size:
+                break
+        # check if all the same
+        # if all([scores["scores"][0] == score for score in scores["scores"]]):
+        #     return None  # If all the same, we return None
+        return scores
+
+
+if __name__ == "__main__":
+    CodingEnv.cli()
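The server.py this environment talks to is truncated in the capture below, so for orientation only: a minimal Flask endpoint compatible with the `/execute` payload used by `submit_code` above might look like this sketch. It handles only Python; the real server also supports C++ (the Dockerfile installs build-essential, and the server's docstring shows a C++ example):

```python
import subprocess

from flask import Flask, jsonify, request

app = Flask(__name__)


@app.route("/execute", methods=["POST"])
def execute():
    payload = request.get_json()
    if payload.get("language", "python") != "python":
        return jsonify({"output": "unsupported language in this sketch"}), 400
    # Run the submitted program with the provided stdin, under a hard timeout.
    proc = subprocess.run(
        ["python", "-c", payload["code"]],
        input=payload.get("input", ""),
        capture_output=True,
        text=True,
        timeout=10,
    )
    return jsonify({"output": proc.stdout})


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5002)
```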
diff --git a/environments/code_execution_server/server.py b/environments/code_execution_server/server.py
new file mode 100644
index 00000000..e63f6112
--- /dev/null
+++ b/environments/code_execution_server/server.py
@@ -0,0 +1,95 @@
+"""
+Instructions:
+
+# Build the image
+docker build -t cpp-flask-executor .
+
+# Run the container
+docker run -p 5002:5002 cpp-flask-executor
+
+curl -X POST http://localhost:5002/execute \
+    -H "Content-Type: application/json" \
+    -d '{"code": "#include <iostream>\nint main(){int x; std::cin>>x; std::cout<<x;}", "input": "42", "language": "cpp"}'
+"""
diff --git a/environments/fundamental_prediction_environment.py b/environments/fundamental_prediction_environment.py
--- a/environments/fundamental_prediction_environment.py
+++ b/environments/fundamental_prediction_environment.py
@@ -65,7 +65,7 @@ class FundamentalPredictionEnv(BaseEnv):
-    def config_init(self) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
+    def config_init(self) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
         env_config = BaseEnvConfig(
             tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
             group_size=16,
@@ -77,7 +77,7 @@ class FundamentalPredictionEnv(BaseEnv):
             eval_limit_ratio=0.1,
         )
         server_configs = [
-            OpenaiConfig(
+            APIServerConfig(
                 model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
                 base_url="http://localhost:9004/v1",
                 api_key="x",
diff --git a/environments/game_environments/gymnasium/gym_taxi.py b/environments/game_environments/gymnasium/gym_taxi.py
index d4f239a1..382959bc 100644
--- a/environments/game_environments/gymnasium/gym_taxi.py
+++ b/environments/game_environments/gymnasium/gym_taxi.py
@@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Tuple
 
 import gymnasium as gym
 
-from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataItem
+from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataItem
 from atroposlib.type_definitions import Item
 
 start_msg = """### Description
@@ -164,7 +164,7 @@ class GymTaxiEnv(BaseEnv):
     def __init__(
         self,
         config: BaseEnvConfig,
-        server_configs: List[OpenaiConfig],
+        server_configs: List[APIServerConfig],
         slurm=True,
         testing=False,
     ):
@@ -178,7 +178,7 @@ class GymTaxiEnv(BaseEnv):
         self.print_this_env = False
 
     @classmethod
-    def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
+    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
         env_config = BaseEnvConfig(
             tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
             group_size=32,
@@ -188,7 +188,7 @@ class GymTaxiEnv(BaseEnv):
             wandb_name="gym_taxi",
         )
         server_configs = [
-            OpenaiConfig(
+            APIServerConfig(
                 model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
                 base_url="http://localhost:9001/v1",
                 api_key="x",
diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py
index 75dc3e0f..4fd37707 100644
--- a/environments/gsm8k_server.py
+++ b/environments/gsm8k_server.py
@@ -6,7 +6,12 @@ from latex2sympy2_extended import NormalizationConfig
 from math_verify import LatexExtractionConfig, parse, verify
 from tqdm.asyncio import tqdm_asyncio
 
-from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
+from atroposlib.envs.base import (
+    APIServerConfig,
+    BaseEnv,
+    BaseEnvConfig,
+    ScoredDataGroup,
+)
 from atroposlib.type_definitions import Item, number
 from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
 
@@ -38,7 +43,7 @@ class GSM8kEnv(BaseEnv):
     def __init__(
         self,
         config: BaseEnvConfig,
-        server_configs: List[OpenaiConfig],
+        server_configs: List[APIServerConfig],
         slurm=True,
         testing=False,
     ):
@@ -50,7 +55,7 @@ class GSM8kEnv(BaseEnv):
         self.completion_lengths = []
 
     @classmethod
-    def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
+    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
         env_config = BaseEnvConfig(
             tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
             group_size=8,
@@ -63,7 +68,7 @@ class GSM8kEnv(BaseEnv):
             wandb_name="gsm8k",
         )
         server_configs = [
-            OpenaiConfig(
+            APIServerConfig(
                 model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
                 base_url="http://localhost:9001/v1",
                 api_key="x",
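The `OpenaiConfig` to `APIServerConfig` rename repeats mechanically across every environment below. For downstream code upgrading from 0.1.x to 0.2.0, the migration is a one-line import swap (the values shown are the placeholder local-server settings these environments already use):

```python
# Before (atroposlib 0.1.x):
# from atroposlib.envs.base import OpenaiConfig
# After (atroposlib 0.2.0):
from atroposlib.envs.base import APIServerConfig

server_config = APIServerConfig(
    model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
    base_url="http://localhost:9001/v1",
    api_key="x",
)
```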
diff --git a/environments/gsm8k_server_axolotl.py b/environments/gsm8k_server_axolotl.py
new file mode 100644
index 00000000..244861f8
--- /dev/null
+++ b/environments/gsm8k_server_axolotl.py
@@ -0,0 +1,300 @@
+import random
+from typing import Dict, List, Optional, Tuple, TypedDict, Union
+
+from datasets import load_dataset
+from latex2sympy2_extended import NormalizationConfig
+from math_verify import LatexExtractionConfig, parse, verify
+from tqdm.asyncio import tqdm_asyncio
+
+from atroposlib.envs.base import (
+    APIServerConfig,
+    BaseEnv,
+    BaseEnvConfig,
+    ScoredDataGroup,
+)
+from atroposlib.type_definitions import Item, number
+from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
+
+system_prompt = (
+    "You are a deep thinking AI, you may use extremely long chains of thought "
+    "to deeply consider the problem and deliberate with yourself via systematic "
+    "reasoning processes to help come to a correct solution prior to answering. "
+    "You should enclose your thoughts and internal monologue inside <think> "
+    "</think> tags, and then provide your solution or response to the problem.\n\n"
+)
+
+system_prompt += """You are allocated a maximum of 2048 tokens, please strive to use less.
+
+You will then provide your answer like this: \\boxed{your answer here}
+It is important that you provide your answer in the correct format.
+If you do not, you will not receive credit for your answer.
+So please end your answer with \\boxed{your answer here}"""
+
+
+class GSM8kRow(TypedDict):
+    question: str
+    answer: str
+
+
+class GSM8kEnv(BaseEnv):
+
+    name = "gsm8k"
+
+    def __init__(
+        self,
+        config: BaseEnvConfig,
+        server_configs: List[APIServerConfig],
+        slurm=True,
+        testing=False,
+    ):
+        super().__init__(config, server_configs, slurm, testing)
+        self.percent_correct_buffer = list()
+        self.eval_metrics = list()
+        # Add tracking for wandb visualizations
+        self.rollouts_for_wandb = []
+        self.completion_lengths = []
+
+    @classmethod
+    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
+        env_config = BaseEnvConfig(
+            tokenizer_name="Qwen/Qwen3-1.7B",
+            group_size=8,
+            use_wandb=True,
+            rollout_server_url="http://localhost:8000",
+            total_steps=1000,
+            batch_size=12,
+            steps_per_eval=100,
+            max_token_length=2048,
+            wandb_name="gsm8k",
+        )
+        server_configs = [
+            APIServerConfig(
+                base_url="http://localhost:9001",
+                api_key="x",
+                num_requests_for_eval=256,
+                model_name="Qwen/Qwen3-1.7B",
+                server_type="trl",
+            ),
+        ]
+        return env_config, server_configs
+
+    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
+        if wandb_metrics is None:
+            wandb_metrics = {}
+
+        # Try to calculate percent_correct, pass if there's a division by zero
+        try:
+            wandb_metrics["train/percent_correct"] = sum(
+                self.percent_correct_buffer
+            ) / len(self.percent_correct_buffer)
+        except ZeroDivisionError:
+            # Skip if buffer is empty
+            pass
+
+        self.percent_correct_buffer = list()
+        for item in self.eval_metrics:
+            wandb_metrics[item[0]] = item[1]
+        self.eval_metrics = list()
+        # Call the parent method to handle the server metrics
+        await super().wandb_log(wandb_metrics)
+
+    async def setup(self):
+        self.train = load_dataset("gsm8k", "main", split="train").shuffle(seed=42)
+        test_data = load_dataset("gsm8k", "main", split="test").shuffle(seed=42)
+        self.test = list()
+        for item in test_data:
+            self.test.append(
+                {
+                    "question": item["question"],
+                    "gold_answer": item["answer"]
+                    .split("#")[-1]
+                    .strip()
+                    .replace(",", ""),
+                }
+            )
+        self.iter = 0
+
+    def save_checkpoint(self, step, data=None):
+        if data is None:
+            data = {}
data["iter"] = self.iter + super().save_checkpoint(step, data) + + async def rollout_and_score_eval(self, question: str, answer: str) -> number: + completion = await self.server.chat_completion( + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question}, + ], + n=1, + max_tokens=self.config.max_token_length, + temperature=0.0, + split="eval", + ) + gold_parsed = parse( + "\\boxed{" + answer + "}", + extraction_mode="first_match", + extraction_config=[LatexExtractionConfig()], + ) + answer_parsed = parse( + completion.choices[0].message.content.split("")[-1], + extraction_config=[ + LatexExtractionConfig( + normalization_config=NormalizationConfig( + nits=False, + malformed_operators=False, + basic_latex=True, + equations=True, + boxed="all", + units=True, + ), + # Ensures that boxed is tried first + boxed_match_priority=0, + try_extract_without_anchor=False, + ) + ], + extraction_mode="first_match", + ) + score = 1 if verify(answer_parsed, gold_parsed) else 0 + return score + + async def evaluate(self, *args, **kwargs): + eval_tasks = [] + for item in self.test: + eval_tasks.append( + self.rollout_and_score_eval(item["question"], item["gold_answer"]) + ) + scores = await tqdm_asyncio.gather(*eval_tasks) + self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores))) + + async def collect_trajectories( + self, item: GSM8kRow + ) -> Tuple[ScoredDataGroup, list[Item]]: + user_message = {"role": "user", "content": item["question"]} + gold_answer = ( + "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}" + ) + + chat_completions = await self.server.chat_completion( + messages=[{"role": "system", "content": system_prompt}, user_message], + n=self.config.group_size, + max_tokens=self.config.max_token_length, + ) + to_score = list() + to_backlog = list() + for i, chat_completion in enumerate(chat_completions.choices): + messages = ( + {"role": "system", "content": system_prompt}, + user_message, + {"role": "assistant", "content": chat_completion.message.content}, + ) + to_score.append( + { + "messages": messages, + "gold_answer": gold_answer, + "finish_reason": chat_completion.finish_reason, + } + ) + to_postprocess = await self.score(to_score) + return to_postprocess, to_backlog + + async def score( + self, rollout_group_data + ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]: + scores = ScoredDataGroup() + scores["tokens"] = list() + scores["masks"] = list() + scores["scores"] = list() + gold_parsed = parse( + rollout_group_data[0]["gold_answer"], + extraction_mode="first_match", + extraction_config=[LatexExtractionConfig()], + ) + if len(gold_parsed) != 0: + # We require the answer to be provided in correct latex (no malformed operators) + random.shuffle(rollout_group_data) + for item in rollout_group_data: + # print(item[0][-1]["content"]) + answer_parsed = parse( + item["messages"][-1]["content"].split("")[-1], + extraction_config=[ + LatexExtractionConfig( + normalization_config=NormalizationConfig( + nits=False, + malformed_operators=False, + basic_latex=True, + equations=True, + boxed="all", + units=True, + ), + # Ensures that boxed is tried first + boxed_match_priority=0, + try_extract_without_anchor=False, + ) + ], + extraction_mode="first_match", + ) + # Reward 1 if the content is the same as the ground truth, 0 otherwise + reward = verify(answer_parsed, gold_parsed) + # print( + # f"message: {item[0][-1]['content']}, ground_truth: {item[1]}, reward: {reward}" + # ) + out_dict = 
tokenize_for_trainer( + self.tokenizer, item["messages"], item["finish_reason"] + ) + tokens = out_dict["tokens"] + masks = out_dict["masks"] + # remove obviously bad examples + if len([1 for i in masks if i != -100]) < 10: + continue + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["scores"].append(1.0 if reward else -1.0) + if len(scores["tokens"]) >= self.config.group_size: + break + for score in scores["scores"]: + self.percent_correct_buffer.append(max(score, 0)) + # check if all the same + # print(scores['scores']) + if all([score == 1 for score in scores["scores"]]): + # Do length penalty :) + token_lengths = [len(token) for token in scores["tokens"]] + if max(token_lengths) == 0: + # What? But don't want to crash a run so just in case... + return None + + # Get max allowed token length from config + max_allowed_length = self.config.max_token_length + # Set threshold at 50% of max_token_length - no penalty below this + length_threshold = max_allowed_length * 0.5 + + # Apply modified length penalty with threshold + scores["scores"] = [] + for length in token_lengths: + if length <= length_threshold: + # No penalty for responses under threshold + scores["scores"].append(1.0) + else: + # Calculate how far we are between threshold and max as a percentage + percentage_of_range = (length - length_threshold) / ( + max_allowed_length - length_threshold + ) + # Cap at 1.0 in case length exceeds max_allowed_length + percentage_of_range = min(percentage_of_range, 1.0) + # Apply linear penalty scaling from 1.0 down to 0.0 + scores["scores"].append(1.0 - percentage_of_range) + if all([scores["scores"][0] == score for score in scores["scores"]]): + return None # If all the same, we return None + return scores + else: + # If the gold solution is not parseable, we return None + return None + + async def get_next_item(self) -> GSM8kRow: + next_item = self.train[self.iter % len(self.train)] + self.iter += 1 + return next_item + + +if __name__ == "__main__": + GSM8kEnv.cli() diff --git a/environments/math_server.py b/environments/math_server.py index 544f95a0..d00fc886 100644 --- a/environments/math_server.py +++ b/environments/math_server.py @@ -14,10 +14,10 @@ from pydantic import Field from tqdm.asyncio import tqdm_asyncio from atroposlib.envs.base import ( + APIServerConfig, BaseEnv, BaseEnvConfig, EvalHandlingEnum, - OpenaiConfig, ScoredDataGroup, ) from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer @@ -154,7 +154,7 @@ class MathEnv(BaseEnv): def __init__( self, config: RSConfig, - server_configs: List[OpenaiConfig], + server_configs: List[APIServerConfig], slurm=True, testing=False, ): @@ -177,7 +177,7 @@ class MathEnv(BaseEnv): self.iter = 0 @classmethod - def config_init(self) -> Tuple[RSConfig, List[OpenaiConfig]]: + def config_init(self) -> Tuple[RSConfig, List[APIServerConfig]]: env_config = RSConfig( tokenizer_name="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", group_size=8, @@ -192,7 +192,7 @@ class MathEnv(BaseEnv): eval_limit_ratio=0.1, ) server_configs = [ - OpenaiConfig( + APIServerConfig( model_name="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", base_url="http://localhost:9004/v1", api_key="x", diff --git a/environments/mcqa_thinking_env.py b/environments/mcqa_thinking_env.py index 9b5cff07..f584f715 100644 --- a/environments/mcqa_thinking_env.py +++ b/environments/mcqa_thinking_env.py @@ -7,11 +7,11 @@ from datasets import load_dataset from tqdm.asyncio import tqdm_asyncio from atroposlib.envs.base import ( + APIServerConfig, BaseEnv, 
BaseEnvConfig, EvalHandlingEnum, Item, - OpenaiConfig, ScoredDataGroup, ) from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer @@ -28,7 +28,7 @@ class MCQAThinkingEnv(BaseEnv): def __init__( self, config: BaseEnvConfig, - server_configs: List[OpenaiConfig], + server_configs: List[APIServerConfig], slurm=True, testing=False, ): @@ -46,7 +46,7 @@ class MCQAThinkingEnv(BaseEnv): self.eval_metrics = list() @classmethod - def config_init(self) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: + def config_init(self) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: env_config = BaseEnvConfig( tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", group_size=32, @@ -64,7 +64,7 @@ class MCQAThinkingEnv(BaseEnv): eval_limit_ratio=0.1, ) server_configs = [ - OpenaiConfig( + APIServerConfig( model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", base_url="http://localhost:9004/v1", api_key="x", diff --git a/environments/multimodal_dpo/clevr_cogen_a_train.py b/environments/multimodal_dpo/clevr_cogen_a_train.py index a1c60a26..8aa7f6d1 100644 --- a/environments/multimodal_dpo/clevr_cogen_a_train.py +++ b/environments/multimodal_dpo/clevr_cogen_a_train.py @@ -7,7 +7,12 @@ from typing import List, Optional, Tuple from datasets import load_dataset -from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup +from atroposlib.envs.base import ( + APIServerConfig, + BaseEnv, + BaseEnvConfig, + ScoredDataGroup, +) from atroposlib.type_definitions import GameHistory, Item from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer @@ -216,8 +221,7 @@ class MultimodalExampleEnv(BaseEnv): return scores @classmethod - def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: - + def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: config = BaseEnvConfig( wandb_name="clevr_cogen_a_train", tokenizer_name="Qwen/Qwen2-VL-2B-Instruct", @@ -232,7 +236,7 @@ class MultimodalExampleEnv(BaseEnv): ) server_configs = [ - OpenaiConfig( + APIServerConfig( model_name="Qwen/Qwen2-VL-2B-Instruct", base_url="http://localhost:9001/v1", api_key="x", diff --git a/environments/multimodal_dpo/clevr_complex.py b/environments/multimodal_dpo/clevr_complex.py index 84f7a5a9..6003fe39 100644 --- a/environments/multimodal_dpo/clevr_complex.py +++ b/environments/multimodal_dpo/clevr_complex.py @@ -6,7 +6,12 @@ from typing import List, Optional, Tuple from datasets import load_dataset -from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup +from atroposlib.envs.base import ( + APIServerConfig, + BaseEnv, + BaseEnvConfig, + ScoredDataGroup, +) from atroposlib.type_definitions import GameHistory, Item from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer @@ -215,8 +220,7 @@ class MultimodalComplexEnv(BaseEnv): return scores @classmethod - def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: - + def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: config = BaseEnvConfig( wandb_name="clevr_complex", tokenizer_name="Qwen/Qwen2-VL-2B-Instruct", @@ -231,7 +235,7 @@ class MultimodalComplexEnv(BaseEnv): ) server_configs = [ - OpenaiConfig( + APIServerConfig( model_name="Qwen/Qwen2-VL-2B-Instruct", base_url="http://localhost:9001/v1", api_key="x", diff --git a/environments/multimodal_dpo/ocr_vqa.py b/environments/multimodal_dpo/ocr_vqa.py index 1ac43556..6820c052 100644 --- a/environments/multimodal_dpo/ocr_vqa.py +++ b/environments/multimodal_dpo/ocr_vqa.py @@ -7,7 +7,12 @@ from typing 
import List, Optional, Tuple from datasets import load_dataset -from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup +from atroposlib.envs.base import ( + APIServerConfig, + BaseEnv, + BaseEnvConfig, + ScoredDataGroup, +) from atroposlib.type_definitions import GameHistory, Item from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer @@ -159,8 +164,7 @@ class OcrVqaEnv(BaseEnv): return scores @classmethod - def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: - + def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: config = BaseEnvConfig( wandb_name="ocr_vqa", tokenizer_name="Qwen/Qwen2-VL-2B-Instruct", @@ -175,7 +179,7 @@ class OcrVqaEnv(BaseEnv): ) server_configs = [ - OpenaiConfig( + APIServerConfig( model_name="Qwen/Qwen2-VL-2B-Instruct", base_url="http://localhost:9001/v1", api_key="x", diff --git a/environments/multimodal_dpo/pixmo_clocks.py b/environments/multimodal_dpo/pixmo_clocks.py index d929e9b4..02932178 100644 --- a/environments/multimodal_dpo/pixmo_clocks.py +++ b/environments/multimodal_dpo/pixmo_clocks.py @@ -7,7 +7,12 @@ from typing import List, Optional, Tuple from datasets import load_dataset -from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup +from atroposlib.envs.base import ( + APIServerConfig, + BaseEnv, + BaseEnvConfig, + ScoredDataGroup, +) from atroposlib.type_definitions import GameHistory, Item from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer @@ -164,8 +169,7 @@ class ClockDatasetEnv(BaseEnv): return scores @classmethod - def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: - + def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: config = BaseEnvConfig( wandb_name="pixmo_clocks", tokenizer_name="Qwen/Qwen2-VL-2B-Instruct", @@ -180,7 +184,7 @@ class ClockDatasetEnv(BaseEnv): ) server_configs = [ - OpenaiConfig( + APIServerConfig( model_name="Qwen/Qwen2-VL-2B-Instruct", base_url="http://localhost:9001/v1", api_key="x", diff --git a/environments/multimodal_dpo/pixmo_count.py b/environments/multimodal_dpo/pixmo_count.py index 6a23990d..26c1928d 100644 --- a/environments/multimodal_dpo/pixmo_count.py +++ b/environments/multimodal_dpo/pixmo_count.py @@ -9,7 +9,12 @@ import requests from datasets import load_dataset from PIL import Image -from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup +from atroposlib.envs.base import ( + APIServerConfig, + BaseEnv, + BaseEnvConfig, + ScoredDataGroup, +) from atroposlib.type_definitions import GameHistory, Item from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer @@ -151,7 +156,7 @@ class PixmoCountEnv(BaseEnv): return scores @classmethod - def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: + def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: config = BaseEnvConfig( wandb_name="pixmo_count", @@ -167,7 +172,7 @@ class PixmoCountEnv(BaseEnv): ) server_configs = [ - OpenaiConfig( + APIServerConfig( model_name="Qwen/Qwen2-VL-2B-Instruct", base_url="http://localhost:9001/v1", api_key="x", diff --git a/environments/multimodal_dpo/pixmo_point_explanations.py b/environments/multimodal_dpo/pixmo_point_explanations.py index 1eef6a4a..f1e68e06 100644 --- a/environments/multimodal_dpo/pixmo_point_explanations.py +++ b/environments/multimodal_dpo/pixmo_point_explanations.py @@ -9,7 +9,12 @@ import requests from datasets import load_dataset from PIL import Image -from atroposlib.envs.base import 
BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup +from atroposlib.envs.base import ( + APIServerConfig, + BaseEnv, + BaseEnvConfig, + ScoredDataGroup, +) from atroposlib.type_definitions import GameHistory, Item from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer @@ -161,7 +166,7 @@ class PixmoPointExplanationsEnv(BaseEnv): return scores @classmethod - def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: + def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: config = BaseEnvConfig( wandb_name="pixmo_point_explanations", @@ -177,7 +182,7 @@ class PixmoPointExplanationsEnv(BaseEnv): ) server_configs = [ - OpenaiConfig( + APIServerConfig( model_name="Qwen/Qwen2-VL-2B-Instruct", base_url="http://localhost:9001/v1", api_key="x", diff --git a/environments/rlaif_server.py b/environments/rlaif_server.py index d9ffe345..c50ad10a 100644 --- a/environments/rlaif_server.py +++ b/environments/rlaif_server.py @@ -7,10 +7,10 @@ import wandb from datasets import load_dataset from atroposlib.envs.base import ( + APIServerConfig, BaseEnv, BaseEnvConfig, EvalHandlingEnum, - OpenaiConfig, ScoredDataGroup, ) from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer @@ -62,7 +62,7 @@ class RLAIFEnv(BaseEnv): def __init__( self, config: BaseEnvConfig, - server_configs: List[OpenaiConfig], + server_configs: List[APIServerConfig], slurm=True, testing=False, ): @@ -72,7 +72,7 @@ class RLAIFEnv(BaseEnv): self.judgement_strings = list() @classmethod - def config_init(self) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: + def config_init(self) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: env_config = BaseEnvConfig( tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", group_size=2, @@ -89,7 +89,7 @@ class RLAIFEnv(BaseEnv): eval_limit_ratio=0.1, ) server_configs = [ - OpenaiConfig( + APIServerConfig( model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", base_url="http://localhost:9004/v1", api_key="x", diff --git a/environments/tool_calling_server.py b/environments/tool_calling_server.py index b1203c41..e14bc717 100644 --- a/environments/tool_calling_server.py +++ b/environments/tool_calling_server.py @@ -8,11 +8,11 @@ from datasets import load_dataset from tqdm.asyncio import tqdm_asyncio from atroposlib.envs.base import ( + APIServerConfig, BaseEnv, BaseEnvConfig, EvalHandlingEnum, Item, - OpenaiConfig, ScoredDataGroup, ) from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer @@ -29,7 +29,7 @@ class SingleToolCallingEnv(BaseEnv): def __init__( self, config: BaseEnvConfig, - server_configs: List[OpenaiConfig], + server_configs: List[APIServerConfig], slurm=True, testing=False, ): @@ -41,7 +41,7 @@ class SingleToolCallingEnv(BaseEnv): self.completion_lengths = [] @classmethod - def config_init(self) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: + def config_init(self) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: env_config = BaseEnvConfig( tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", group_size=32, @@ -57,14 +57,14 @@ class SingleToolCallingEnv(BaseEnv): eval_limit_ratio=0.1, ) server_configs = [ - OpenaiConfig( + APIServerConfig( model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", base_url="http://localhost:9004/v1", api_key="x", num_max_requests_at_once=32, num_requests_for_eval=256, ), - OpenaiConfig( + APIServerConfig( model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", base_url="http://localhost:9005/v1", api_key="x", diff --git a/llm.txt b/llm.txt index 72c65437..a5419504 100644 --- 
a/llm.txt +++ b/llm.txt @@ -254,7 +254,7 @@ This class provides the foundation for creating custom RL environments. Subclass * **`async def postprocess_histories(self, trajectories: Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]`**: Called after `collect_trajectories` and before sending data to the server. Use for final processing, scoring, filtering, or formatting of the collected group data. * **`async def wandb_log(self, wandb_metrics: Optional[Dict] = None)`**: Called periodically for W&B logging. Add custom metrics to `wandb_metrics`. **Crucially, call `await super().wandb_log(wandb_metrics)`** at the end to include base metrics and rollouts. * **`save_checkpoint(self, step, data=None)`**: Called automatically by the server based on `checkpoint_interval`. Saves the provided `data` dict (populated with environment state) to JSON. Override to customize *what* or *how* data is saved. -* **`@classmethod config_init(cls) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[OpenaiConfig]]]`**: Used by CLI `serve` command setup. Returns initial `BaseEnvConfig` and server config(s). Override for custom default CLI configurations. Default returns `cls.env_config_cls(), ServerBaseline()`. +* **`@classmethod config_init(cls) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[APIServerConfig]]]`**: Used by CLI `serve` command setup. Returns initial `BaseEnvConfig` and server config(s). Override for custom default CLI configurations. Default returns `cls.env_config_cls(), ServerBaseline()`. * **`async def cleanup(self)`**: Called after each item processing (`handle_env`). Use for per-item cleanup if needed (rarely required). **Provided Functionality:** @@ -308,7 +308,7 @@ Settings for the `ServerManager` which handles inference server interactions. #### 10.2.3. Server Baseline Config (`atroposlib.envs.server_handling.server_manager.ServerBaseline`) -Default settings used by `ServerManager` if specific `OpenaiConfig` list isn't provided (e.g., for local/SLURM discovery). +Default settings used by `ServerManager` if specific `APIServerConfig` list isn't provided (e.g., for local/SLURM discovery). | Parameter | Type | Default | Description | | :------------------------- | :------ | :-------- | :------------------------------------------------------------------------------------------------------ | @@ -318,7 +318,7 @@ Default settings used by `ServerManager` if specific `OpenaiConfig` list isn't p | `model_name` | `str` | `default` | Default model name for inference calls. | | `rolling_buffer_length` | `int` | `1000` | Buffer length for server metrics (timings, attempts). | -#### 10.2.4. OpenAI Server Config (`atroposlib.envs.server_handling.openai_server.OpenaiConfig`) +#### 10.2.4. OpenAI Server Config (`atroposlib.envs.server_handling.openai_server.APIServerConfig`) Configuration for individual OpenAI-compatible API servers (official OpenAI, local vLLM/SGLang, etc.). A list of these can be passed to the environment. diff --git a/pyproject.toml b/pyproject.toml index 1547aee7..740a7a07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,11 +1,11 @@ [project] name = "atroposlib" -version = "0.1.0" +version = "0.2.0" description = "Atropos: An Environment and Rollout handler for LLM RL" readme = "README.md" requires-python = ">=3.10" dependencies = [ - "transformers==4.48.3", + "transformers", "datasets", "openai", "aiohttp",