# Base Environment (`BaseEnv`)

The `BaseEnv` class (located in `atroposlib/envs/base.py`) provides a foundation for creating custom reinforcement learning environments that interact with Atropos. When creating your own environment, you will typically subclass `BaseEnv` and implement several key methods.
## Design philosophy

Every environment in Atropos is a microservice that generates rollout data async from whatever trainer you attach to it. Environments (possibly many at once) send data to the [Atropos API server](https://github.com/NousResearch/atropos/tree/main/atroposlib/api) which sequesters rollout data. Your trainer of choice grabs batches of data from the API and backpropagates.


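The data flow can be sketched as a producer/consumer loop. In the conceptual sketch below, an `asyncio.Queue` stands in for the API server's rollout buffer; this is an illustration of the decoupling, not the actual HTTP interface.

```python
import asyncio

async def environment(buffer, n_groups):
    # Environments generate scored rollout groups and push them to the buffer.
    for i in range(n_groups):
        group = {"tokens": [[i]], "scores": [1.0]}  # stand-in rollout group
        await buffer.put(group)

async def trainer(buffer, n_groups):
    # The trainer pulls batches whenever it is ready and backpropagates.
    batches = []
    for _ in range(n_groups):
        batches.append(await buffer.get())
    return batches

async def main(n_groups=4):
    buffer = asyncio.Queue()
    _, batches = await asyncio.gather(
        environment(buffer, n_groups), trainer(buffer, n_groups)
    )
    return batches
```

Because the producer and consumer only share the buffer, many environments can feed one trainer, and neither side blocks on the other's pace.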
Unlike other popular alternatives like Gymnasium, which model environments as [MDPs](https://arxiv.org/abs/2412.05265), we think of environments as dataloaders and do not make any assumptions about how a trajectory is produced. For multi-agent settings, for example, this means our design is agnostic to [AEC](https://pettingzoo.farama.org/api/aec/) vs. [POSG](https://pettingzoo.farama.org/api/parallel/): both are supported out of the box.
To achieve this generality, our environment abstraction deviates from other open source alternatives in several key ways.

- **Inference-scoring fusion**: A popular design choice in open-source LLM RL trainers is to separate inference and scoring into independent abstractions. While this makes a lot of sense for single-turn environments like one-shot MCQA, we found that this led to awkwardness in multi-turn setups with process rewards. As such, we assume the existence of a single method `collect_trajectories` which is responsible for both inference and scoring. Users are still welcome to call separate inference and scoring methods from within `collect_trajectories`.
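As an illustration of this fusion, a single method can run inference and scoring back to back. In the sketch below, `generate` and `score` are hypothetical stand-ins for this example, not part of the Atropos API.

```python
import asyncio

async def generate(item, n=2):
    # Hypothetical stand-in for a call to the inference server.
    return [f"{item}-response-{i}" for i in range(n)]

def score(item, response):
    # Hypothetical stand-in scorer (here: reward longer responses).
    return float(len(response))

async def collect_trajectories(item):
    # Inference and scoring live in the same method, which is what makes
    # multi-turn setups with process rewards straightforward.
    responses = await generate(item)
    scores = [score(item, r) for r in responses]
    backlog = []  # optional follow-up items for future processing
    return list(zip(responses, scores)), backlog
```
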
- **Groups as atomic units of data**: A natural choice for a data atom in RL is a single trajectory. However, many popular RL methods for fine-tuning LLMs such as DPO and GRPO involve packing contrastive data into the same batch. As such, the most fundamental dataloading method in our abstraction is not `collect_trajectory` (singular) but `collect_trajectories` (plural). We do not enforce any definition of what a "group" is other than a set of rollouts. Although a "group" is most commonly constructed by generating multiple rollouts starting from the same initial state (as in DPO and GRPO), a user could just as easily pack `n` similar-sounding problems with very different solutions into a group. For cases like PPO where advantages don't depend on group statistics, a user can simply use group size 1.

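For instance, GRPO-style advantages are computed from the group's own statistics. A minimal sketch (the function name and the zero-variance guard are choices of this example, not Atropos code):

```python
from statistics import mean, pstdev

def group_advantages(rewards):
    # Normalize each reward against its own group's mean and std (GRPO-style).
    mu = mean(rewards)
    sigma = pstdev(rewards) or 1.0  # guard against uniform groups
    return [(r - mu) / sigma for r in rewards]
```

With group size 1 (the PPO case mentioned above) these statistics are degenerate, which is why advantages there come from elsewhere, e.g. a value baseline.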
- **Environments return tokens (not messages!)**: One of our most peculiar design choices is that, at least for text-only environments, the environment is responsible for tokenization. This gives us the flexibility to assign token-level rewards and to mix completions-based (e.g. autocomplete suggestion accept/reject) and chat-based (e.g. instruct-model code generation) environments together in the same training run. For cases like multimodal, where an OpenAI-formatted message list needs to be passed to a transformers `AutoProcessor`, we support a `list[dict]`-valued `messages` key within our group abstraction [ScoredDataGroup](https://github.com/NousResearch/atropos/blob/a282604baac8dbb3b201f992cfc889ee1e5a0f4a/atroposlib/envs/base.py#L55).
## Working with Servers and ManagedServer

**🎯 Recommended Approach:** Use `ManagedServer` for automatic token and logprob tracking!
When implementing `collect_trajectory` or `collect_trajectories`, you need to interact with your inference server to generate completions and extract tokens/logprobs for training. The **recommended way** to do this is using `ManagedServer`, which automatically handles tokenization, masking, and logprob alignment.

### ManagedServer Overview

`ManagedServer` wraps your `APIServer` and automatically tracks:
- **Tokens**: Full unmasked token sequences
- **Masked Tokens**: Training format with `-100` for prompt positions, actual token IDs for completion
- **Logprobs**: Training format with `1.0` for masked positions, actual logprob values for completion
- **Full Text**: Complete text (prompt + completion)
- **Metadata**: Finish reasons and other information

**Why `1.0` for masked logprobs?** A logprob of `1.0` corresponds to a probability of e^1.0 ≈ 2.718 > 1, which is impossible, making masked positions easy to identify and ignore during training.

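The convention can be illustrated with a small standalone helper. This is a sketch of the masking scheme described above, not `ManagedServer`'s actual implementation:

```python
import math

def build_training_arrays(prompt_tokens, completion_tokens, completion_logprobs):
    # Full unmasked sequence: prompt followed by completion.
    tokens = list(prompt_tokens) + list(completion_tokens)
    # Masked tokens: -100 over the prompt, real token IDs over the completion.
    masked_tokens = [-100] * len(prompt_tokens) + list(completion_tokens)
    # Masked logprobs: 1.0 over the prompt; math.exp(1.0) > 1 is an impossible
    # probability, so these positions are easy to filter during training.
    logprobs = [1.0] * len(prompt_tokens) + list(completion_logprobs)
    return tokens, masked_tokens, logprobs
```
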
### Basic Usage Pattern

```python
async def collect_trajectories(self, item):
    prompt = format_prompt(item)

    # Use managed server with context manager
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completion = await managed.completion(
            prompt=prompt,
            n=self.config.group_size,
            max_tokens=4096,
            temperature=1.0,
        )

        # Get tracked sequences with aligned tokens and logprobs
        state = managed.get_state()
        nodes = state["nodes"]

        # Extract pre-computed, guaranteed-aligned data
        for choice, node in zip(completion.choices, nodes):
            tokens = node.tokens                # ✅ Automatically computed
            masked_tokens = node.masked_tokens  # ✅ Automatically masked
            logprobs = node.logprobs            # ✅ Automatically aligned
            finish_reason = node.metadata["finish_reason"]

    # Score and return...
```

### Chat Completion Pattern

For chat-based environments, use `chat_completion()`:

```python
async def collect_trajectories(self, item):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": item["question"]},
    ]

    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        chat_completion = await managed.chat_completion(
            messages=messages,
            n=self.config.group_size,
            max_tokens=4096,
        )

        state = managed.get_state()
        nodes = state["nodes"]

    # Process nodes...
```

### Benefits Over Manual Handling

❌ **Without ManagedServer:**

- Manually tokenize prompts and completions
- Manually compute prompt lengths
- Manually apply masking logic
- Manually extract and align logprobs
- Prone to off-by-one errors

✅ **With ManagedServer:**

- Automatic tokenization
- Automatic masking
- Guaranteed alignment
- Clean, simple code
- Works with both `completion()` and `chat_completion()` APIs

### Complete Documentation

For detailed examples, advanced patterns (multi-turn, RLAIF, backlog workflows), API reference, and migration guide, see:

📚 **[ManagedServer Complete Guide](server_handling/MANAGED_SERVER.md)**

## Core Methods to Implement

These methods **must** be implemented in your subclass:

* **`async def setup(self)`**: This method is called once at the beginning of the environment's lifecycle (`env_manager`). Use it for any initial setup required for your specific environment, such as loading datasets, initializing models, or connecting to external resources.

* **`async def get_next_item(self) -> Item`**: This method is responsible for generating or retrieving the next piece of data (prompt, state, etc.) that will be used to start a new trajectory collection. If no more items are available or should be generated, it can return `None` to signal the worker to pause.

* **`async def collect_trajectories(self, item: Item) -> Tuple[Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None]], List[Item]]`**: The default implementation of this method runs `collect_trajectory` (see below) multiple times in parallel (controlled by `group_size`). You can override this if you have a more efficient way to generate the entire group of responses/trajectories at once based on the input `item` (e.g. the `n` parameter in the OpenAI chat completions API) or some desired coupling of rollouts (e.g. via MCTS). It should return the collected group data and a list of backlog items.

* **`async def collect_trajectory(self, item: Item) -> Tuple[Any | ScoredDataItem | None, List[Item]]`**: If the rollouts for your environment can be sampled independently, the easiest way to implement GRPO-style grouping is to define the `collect_trajectory` method and use the default implementation of `collect_trajectories`, which runs `group_size` instances of `collect_trajectory` in parallel. This method defines the logic for a *single* logical trajectory collection step based on the input `item`.
    * **Return value**: A tuple containing:
        1. The `ScoredDataItem` for this step (one trajectory). This data can be processed further in `postprocess_histories` if you require additional filtering right before sending to the API.
        2. A list of new `Item` objects to be added to the backlog for future processing (e.g., follow-up prompts).
    * **Should I define `collect_trajectory` or override `collect_trajectories`?** If you have a way to generate your group more efficiently than a set of separate but parallel calls to `collect_trajectory`, or if your rollouts aren't independent (as in MCTS), you should override `collect_trajectories`. If simplicity and iteration speed are more valuable than efficiency (e.g. at the start of a development cycle) and your rollouts are independent, then `collect_trajectory` is for you.

* **`async def evaluate(self, *args, **kwargs)`**: This method is called periodically (controlled by `steps_per_eval` in the config) to perform evaluation runs. You define the evaluation logic here. The base class provides an example using `self.eval_workers` for parallel evaluation tasks, but you can implement any evaluation procedure suitable for your environment.

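Putting the pieces together, a minimal environment might look like the sketch below. It deliberately stubs out the base class and the rollout logic; in real code you would subclass `atroposlib.envs.BaseEnv` and implement `collect_trajectories` (or `collect_trajectory`) as well.

```python
import asyncio

class MyDatasetEnv:  # in practice: class MyDatasetEnv(BaseEnv)
    def __init__(self, dataset):
        self.dataset = dataset
        self.index = 0

    async def setup(self):
        # One-time initialization: load datasets, connect to resources, etc.
        pass

    async def get_next_item(self):
        # Hand out the next item, or None to signal the worker to pause.
        if self.index >= len(self.dataset):
            return None
        item = self.dataset[self.index]
        self.index += 1
        return item
```
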
## Optional Methods to Override

These methods have default implementations or are optional based on your needs:

* **`async def collect_trajectories(self, item: Item) -> Tuple[Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None]], List[Item]]`**: The default implementation of this method runs `collect_trajectory` multiple times in parallel (controlled by `group_size`). You can override this *instead* of `collect_trajectory` if you have a more efficient way to generate the entire group of responses/trajectories at once based on the input `item`. It should return the collected group data and a list of backlog items.

* **`async def postprocess_histories(self, trajectories: Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]`**: This method is called after `collect_trajectories` and before the data is sent to the training server. It receives the collected data from the parallel runs (or your custom `collect_trajectories` implementation). Use this to perform final processing, scoring, or formatting you may require before sending to the server. You usually won't need this.

* **`async def wandb_log(self, wandb_metrics: Optional[Dict] = None)`**: Called periodically to log metrics to Weights & Biases. If you override this to add custom metrics, **ensure you call `super().wandb_log(wandb_metrics)`** at the end of your implementation. This ensures that the base class's performance metrics and rollout tables are also logged.

```python
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
    if wandb_metrics is None:
        wandb_metrics = {}

    # Add your custom metrics
    wandb_metrics['my_custom_metric'] = calculate_my_metric()
    # ... add more metrics

    # Call the parent method to log base metrics
    await super().wandb_log(wandb_metrics)
```

* **`save_checkpoint(self, step, data=None)`**: The base class calls this method automatically at checkpoint intervals determined by the server. It saves the provided `data` dictionary (which you might populate with environment-specific state) to a JSON file. You can override this to customize *what* data is saved or *how* it's saved (e.g., using a different format or location), but the triggering mechanism remains automatic.

* **`@classmethod config_init(cls) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[APIServerConfig]]]`**: This class method is used by the default `get_cli_serve_config_cls` implementation to get the initial environment configuration (`BaseEnvConfig` subclass) and server configurations (`ServerBaseline` or `List[APIServerConfig]`) when setting up the `serve` command. The default implementation returns `cls.env_config_cls(), ServerBaseline()`. You might override this if your environment requires different default configurations or specific server setups (like multiple `APIServerConfig` instances) when run via the CLI `serve` command.

* **`async def cleanup(self)`**: Called after each call to `handle_env`. You can implement this for any cleanup needed after processing a single item, though it's often not required.

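As a rough standalone sketch of what the default `save_checkpoint` behavior amounts to (the directory layout and file naming here are assumptions of this example, not the base class's actual paths):

```python
import json
import os

def save_checkpoint(step, data=None, checkpoint_dir="checkpoints"):
    # Serialize environment-specific state to a JSON file keyed by step.
    data = data or {}
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, f"step_{step}.json")
    with open(path, "w") as f:
        json.dump(data, f)
    return path
```

Because the payload is plain JSON, a matching `load_checkpoint` can restore attributes simply by reading the file back and assigning its keys.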
## Overridable Class Variables

These class-level variables in `BaseEnv` can be overridden in your subclass to customize its behavior:

* **`name: Optional[str]`**:
    * Default: `None`
    * Purpose: You can set a string name for your environment. This name is used by default for `wandb_name` in the `BaseEnvConfig` if not otherwise specified, influencing how runs are grouped or named in Weights & Biases. It can also be useful for general identification or logging purposes.

* **`env_config_cls: Type[BaseEnvConfig]`**:
    * Default: `BaseEnvConfig`
    * Purpose: This variable holds the Pydantic model class that will be used for your environment's configuration. If your environment requires custom configuration fields beyond what `BaseEnvConfig` offers, you should create a new class that inherits from `BaseEnvConfig` (or a subclass thereof) and assign it to `env_config_cls`. This allows the CLI and other parts of the system to correctly parse and manage your environment's specific settings.

```python
from pydantic import Field
from atroposlib.envs import BaseEnv, BaseEnvConfig


class MyEnvConfig(BaseEnvConfig):
    my_custom_param: str = Field(default="default_value", description="A custom parameter for MyEnv")


class MyEnv(BaseEnv):
    env_config_cls = MyEnvConfig
    name = "MyCustomEnvironment"
    # ... other implementations
```

* **`server_cls: Type[APIServer]`**:
    * Default: `APIServer`
    * Purpose: Specifies the class used to manage interactions with API servers (e.g., inference endpoints). This should mostly be used for developing additional API interfaces, but if you need a nonstandard way of connecting to an existing API you can use it to slot in any modifications you need.
    * **Note:** In most cases, you should use the `server_type` field in your `APIServerConfig` instead of overriding this. Set `server_type` to `"openai"` (default), `"vllm"`, `"sglang"`, or `"trl"` to automatically use the appropriate server class with enhanced features like native API access and full token/logprob tracking.

## Provided Functionality

`BaseEnv` provides several helpful features:

* **Parallel Trajectory Collection (`collect_trajectories`)**: The base implementation runs your `collect_trajectory` method multiple times in parallel (based on `group_size`) and gathers the results. You can override `collect_trajectories` directly for custom group generation logic (see Optional Methods).
* **Server Interaction**: Handles registration with the rollout server, fetching configuration (like `batch_size`), sending scored data (`handle_send_to_api` with retries), and status updates.
* **WandB Integration**: Sets up WandB logging (if enabled) based on server information and provides the `wandb_log` hook for custom metrics (remember to call `super().wandb_log()`). It uses helper methods `add_rollouts_for_wandb` (to temporarily store rollout data) and `create_rollout_table` (to format the data into a `wandb.Table`). You can override either of these helpers for custom logging behavior (e.g., changing what data is stored or how the final table is structured).
* **Checkpointing**:
    * The environment automatically triggers checkpoint saves based on the `checkpoint_interval` received from the server, calling the `save_checkpoint` method (see Optional Methods).
    * `load_checkpoint(self)`: Loads data from the checkpoint file corresponding to the environment's `curr_step`. It attempts to restore attributes of the environment object based on the keys in the loaded JSON data. This is called automatically if `curr_step > 0` during registration.
* **Worker Management**: Manages asynchronous worker tasks for collecting trajectories (`add_train_workers`, `handle_env`).
* **Performance Monitoring**: Tracks and logs various performance statistics (task durations, worker counts, etc.).
* **CLI Integration**: Provides a `cli()` class method using `pydantic-cli` to easily create command-line interfaces for your environment (e.g., `python your_env_module.py serve --port 8001 ...`). See `get_cli_serve_config_cls` and `get_cli_process_config_cls`.

By implementing the required methods and optionally overriding others, you can create diverse environments that leverage the distributed training infrastructure provided by the `Atropos` framework.
|