Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-25 17:10:42 +00:00
Updated README
parent f5172b45a8
commit 826de9e283
1 changed file with 17 additions and 17 deletions
@@ -15,7 +15,7 @@ This design pattern allows the agent to be trained on manageable segments that f
## The Problem: Ultra-Long Episode Sequences
- Reinforcement learning for language models involves training on sequences of observations, thoughts, and actions. When an agent engages in detailed "thinking" steps (e.g., within `<think>...</think>` blocks), the token length of a single turn can be substantial. Over an entire episode, the total sequence length can easily exceed the maximum sequence length (e.g., 4k, 8k, or even 16k tokens) that contemporary LLMs can process in a single forward pass for training (ie, the maximum sequence length in the trainer). Even a short environment like Blackjack can blow this up in only a couple of steps if it happens to use very long thinking blocks. We could probably reasonably accomodate this for Blackjack by increasing the maximum sequence length, at the cost of more memory, but this isn't a great solution in general. Many environments can take hundreds or thousands of steps to complete - so a way of managing this is necessary. Blackjack is just used to demonstrate how this works - it can be much more simply implemented without thinking blocks as in `blackjack_env_no_thinking`, but even WITHOUT thinking blocks there's plenty of environments that will run far beyond any reasonable maximum sequence length anyway. Not to mention agents that might make use of RAG or some kind of external planning that will further bloat the token count.
+ Reinforcement learning for language models involves training on sequences of observations, thoughts, and actions. When an agent engages in detailed "thinking" steps (e.g., within `<think>...</think>` blocks), the token length of a single turn can be substantial. Over an entire episode, the total sequence length can easily exceed the maximum sequence length (e.g., 4k, 8k, or even 16k tokens) that contemporary LLMs can process in a single forward pass for training (i.e., the maximum sequence length in the trainer). Even a short environment like Blackjack can blow this up in only a couple of steps if it happens to use very long thinking blocks. We could probably accommodate this for Blackjack by increasing the maximum sequence length, at the cost of more resources for the trainer, but this isn't a great solution in general. Many environments can take hundreds or thousands of steps to complete, so a way of managing this is necessary. Blackjack is just used to demonstrate how this works; it can be implemented much more simply without thinking blocks, as in `blackjack_env_no_thinking`, but even without thinking blocks there are plenty of environments that will run far beyond any reasonable maximum sequence length anyway. Not to mention agents that might make use of RAG or some kind of external planning that will further bloat the token count.
Standard RL training loops assume that an entire episode or a significant, coherent sub-segment (rollout) can be fed into the policy network (i.e., the LLM being trained). When episodes are orders of magnitude longer than the `seq_len`, a naive approach of simply truncating or randomly sampling segments breaks the coherence needed for effective policy evaluation and improvement, especially for algorithms like GRPO (**Group Relative Policy Optimization**) that rely on comparing alternative actions from a common state.
@@ -27,8 +27,8 @@ Key components of this approach:
1. **Generating Alternatives (`_sample_response`)**: At each state \(s_t\) in an episode, the environment prompts the LLM to generate not one, but `G` (defined by `config.group_size`) different potential continuations (thoughts and actions). This is handled by the `_sample_response` method.
- 2. **Value-Guided Greedy Selection**: Playing out all `G` alternatives for the *entire* remaining episode (which could still be very long) is computationally infeasible for training and would again hit sequence length limits. Therefore, a greedy search approach is taken here (and could be modified for more sophisticated strategies like beam search or MCTS):
- * **Value Estimation (`_estimate_value`)**: The environment uses an internal `_estimate_value(s)` function. For Blackjack, this function calculates an accurate expected future reward (value) for any given game state `s`. This acts as a local critic or value function, crucial for evaluating the long-term prospects of immediate actions. It exhaustively explores all possible moves and rewards from the current state. Which, isn't going to work in every environment - but Blackjack has pretty short episodes and a small enough action & state space it's viable. Other strategies to explore and figure out future rewards are necessary for other environments.
+ 2. **Value-Guided Greedy Selection**: Playing out all `G` alternatives for the *entire* remaining episode (which could still be very long) is overkill for training and would again hit sequence length limits. Therefore, a greedy search approach is taken here (although it could be modified for more sophisticated strategies like beam search or MCTS if desired):
+ * **Value Estimation (`_estimate_value`)**: The environment uses an internal `_estimate_value(s)` function. For Blackjack, this function calculates an accurate expected future reward (value) for any given game state `s`. This acts as a local critic or value function, crucial for evaluating the long-term prospects of immediate actions. It exhaustively explores all possible moves and rewards from the current state. This isn't going to work in every environment, but Blackjack has short enough episodes and a small enough action and state space that it's viable to do here. Other strategies for exploring and estimating future rewards are necessary for other environments.
* **Advantage Calculation**: For each of the `G` alternatives (\(a_i\)) sampled from state \(s_t\), the environment simulates the action, observes the immediate reward (\(R_i\)), and the next state (\(s'_{i}\)). It then calculates an advantage for each alternative, typically using the formula:
\[ A(s_t, a_i) = R_i + \gamma V(s'_{i}) - V(s_t) \]
(In `_collect_trajectory`, `gamma` is effectively 1, and \(R_i\) is represented by `alt_combined_rewards[i]`, \(V(s'_{i})\) by `alt_value_next[i]`, and \(V(s_t)\) by `value_t`).
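For illustration, the advantage computation and greedy choice above can be sketched as follows. This is a hypothetical standalone version: the variable names mirror the text (`alt_combined_rewards`, `alt_value_next`, `value_t`), not code copied from the repo.

```python
# Hypothetical sketch of the advantage formula above; NOT the repo's code.
def compute_advantages(alt_combined_rewards, alt_value_next, value_t, gamma=1.0):
    # A(s_t, a_i) = R_i + gamma * V(s'_i) - V(s_t), with gamma = 1 in practice
    return [r + gamma * v_next - value_t
            for r, v_next in zip(alt_combined_rewards, alt_value_next)]

def select_best_index(advantages):
    # Greedy choice: the alternative with the highest advantage gets "played"
    return max(range(len(advantages)), key=lambda i: advantages[i])
```

With `gamma = 1` this reduces exactly to the `alt_combined_rewards[i] + alt_value_next[i] - value_t` expression described above.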
@@ -38,15 +38,15 @@ Key components of this approach:
* **Choosing the Path (`select_best_index`)**: The `select_best_index` function is then used to pick the alternative with the highest calculated advantage. This chosen alternative's action is what is actually "played" in the environment, advancing the episode to the next state `s_{t+1}`. The other `G-1` alternatives serve as counterfactual data for training. So, we end up with a "canonical" trajectory through the environment. For more comprehensive exploration of alternatives, we'd need some more thorough form of search like MCTS, which is overkill for something like Blackjack (but we'll demo it in other, more complex environments to be added).
3. **Managing Historical Context Length (`truncate_thinking`, `ensure_trajectory_token_limit`)**: As an episode progresses, the history of observations, thoughts, and actions accumulates. To ensure that the prompt fed to the LLM for generating the *next* `G` alternatives remains within the operational context window (e.g., `max_prompt_tokens`), the environment employs truncation strategies:
- * Currently, we use simple truncation (e.g., removing the oldest messages or earlier parts of messages that probably don't have the LLMs final thoughts about it's decisions). More sophisticated context management techniques, such as summarization of earlier parts of the episode, could be implemented in the future to retain relevant historical information more effectively within the token constraints.
* The `truncate_thinking` utility (from `atroposlib/utils/message_history_utils.py`) is used to shorten the content within `<think>...</think>` blocks in the message history if they exceed a certain token budget. This helps manage the verbosity of past thinking steps.
* We use `ensure_trajectory_token_limit` in `blackjack_env_thinking.py` for truncating the overall message history (list of observations, thoughts, and chosen actions from previous turns) before it's used to prompt the LLM for the current step's alternatives. This ensures that the input prompt, including the system message and the current game state, doesn't exceed the LLM's maximum input token limit.
+ * Currently, we use simple truncation (e.g., removing the oldest messages or earlier parts of messages that probably don't contain the LLM's final thoughts about its decisions). More sophisticated context management techniques, such as summarization of earlier parts of the episode, could be implemented in the future to retain relevant historical information more effectively within the token constraints.
This step is crucial because, without it, the growing history would quickly make it impossible to generate new actions as the episode continues. It allows the trajectory "windows" in each `ScoredDataGroup` to extend further into long episodes and keeps the multi-turn nature of the training intact.
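As a rough illustration of the kind of think-block truncation described above (NOT the actual `truncate_thinking` implementation; whitespace-split words stand in for real tokens to keep the sketch self-contained):

```python
import re

# Illustrative sketch only -- the real truncate_thinking in
# atroposlib/utils/message_history_utils.py may behave differently.
def truncate_thinking_sketch(message: str, max_think_tokens: int) -> str:
    def shorten(match: re.Match) -> str:
        words = match.group(1).split()
        if len(words) <= max_think_tokens:
            return match.group(0)
        # Keep the tail of the block: the final reasoning right before the action
        kept = " ".join(words[-max_think_tokens:])
        return f"<think>... {kept}</think>"
    return re.sub(r"<think>(.*?)</think>", shorten, message, flags=re.DOTALL)
```

Keeping the *tail* of each block matches the idea above that the earlier parts of a message probably don't contain the model's final thoughts about its decision.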
## Data Structure for the GRPO Trainer
- The critical insight is how data is packaged for the GRPO trainer:
+ How data is packaged for the GRPO trainer:
* At each step `t` of the actual trajectory taken by the agent, the `_collect_trajectory` method compiles a `BlackjackScoredDataGroup`.
* This single `BlackjackScoredDataGroup` contains the full text (including thoughts), tokenized representations (`tokens`), attention `masks`, and `scores` (which are the \(A(s_t, a_i)\) values) for **all `G` alternatives** that were considered at state \(s_t\).
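A minimal sketch of this per-step packaging (hypothetical field layout; the actual `BlackjackScoredDataGroup` in the repo may carry additional fields):

```python
from dataclasses import dataclass

# Hypothetical shape of the per-step group described above; illustrative only.
@dataclass
class ScoredDataGroupSketch:
    messages: list  # full text (including thoughts) for each of the G alternatives
    tokens: list    # tokenized representation per alternative
    masks: list     # attention masks per alternative
    scores: list    # A(s_t, a_i) advantage values per alternative

    def __post_init__(self):
        # All G alternatives originate from the same state s_t,
        # so the four lists must be aligned one-to-one.
        g = len(self.scores)
        assert len(self.messages) == len(self.tokens) == len(self.masks) == g
```

The invariant in `__post_init__` is the key property: every entry in the group is a directly comparable alternative from the same decision point.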
@@ -62,7 +62,7 @@ The GRPO trainer typically computes a loss using these advantages. For example,
\[ L = -\sum_{j=1}^{M} \sum_{k=1}^{K_j} \left( \frac{\pi_{\theta}(a_{jk} | s_j)}{\pi_{\theta_{\text{old}}}(a_{jk} | s_j)} A_{jk}^{\text{GRPO}} \right) \]
(often with a KL divergence penalty for stability, ensuring the new policy \(\pi_{\theta}\) doesn't deviate too drastically from the old policy \(\pi_{\theta_{\text{old}}}\)). The `ratio = torch.exp(logp - logp.detach())` and `loss = -reward * ratio` (where `reward` is the \(A_{jk}^{\text{GRPO}}\) advantage) in a typical trainer snippet would align with this principle.
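A minimal scalar sketch of the surrogate objective above (KL penalty omitted, plain Python standing in for the tensor version in a real trainer):

```python
import math

# Minimal sketch of L = -sum ( pi_new/pi_old * A ); illustrative only.
# logp_new / logp_old are per-alternative log-probabilities under the
# current and old policies; advantages are the A^GRPO values.
def grpo_surrogate_loss(logp_new, logp_old, advantages):
    return -sum(math.exp(ln - lo) * adv
                for ln, lo, adv in zip(logp_new, logp_old, advantages))
```

When the policies coincide (the first gradient step), every importance ratio is 1 and the loss is just the negated sum of advantages.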
- The `blackjack_env_thinking` environment's design is compatible with GRPO's core requirements for input data BUT allowing it to be used across long trajectories. We don't get a nice, well defined reward at every step of every environments - but we want to keep that nice, objective, outcome-oriented reward structure, even in reward-sparse environments.
+ The `blackjack_env_thinking` environment's design is compatible with GRPO's core requirements for input data while also allowing it to be used across long trajectories. We don't get a nice, well-defined reward at every step of every environment, but we want to keep that objective, outcome-oriented RLVR-style reward structure, even in reward-sparse environments.
1. **Alternative Generation**: From a state \(s_t\), the environment generates `G` alternative continuations (thoughts and actions \(a_1, ..., a_G\)).
2. **Value-Informed Scoring (within the environment)**: For each alternative \(a_i\), the environment itself calculates a "score" using its internal value estimation: \(S_i = R_i + \gamma V_{\text{env}}(s'_{i}) - V_{\text{env}}(s_t)\). This score, \(S_i\), represents a local, value-informed assessment of that alternative's quality.
@@ -77,7 +77,7 @@ This two-step process is vital:
* The environment provides high-quality, comparable, and value-informed *reward signals* (our scores \(S_i\)) for a set of diverse actions originating from the exact same state. The subtraction of \(V_{\text{env}}(s_t)\) in our score calculation helps to center these reward signals, potentially aiding training stability.
* The GRPO algorithm then applies its group-relative normalization to these rewards to derive the advantages that drive policy learning. This ensures that the policy is updated based on how much better or worse an alternative is compared to the *average quality of alternatives considered at that specific step*.
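Group-relative normalization of the environment's scores could look like the following (a common GRPO-style mean/std normalization over the group; the exact trainer behavior, e.g. epsilon handling, may differ):

```python
# Sketch of group-relative normalization; illustrative, not the trainer's code.
def group_relative_advantages(scores, eps=1e-8):
    mean = sum(scores) / len(scores)
    var = sum((s - mean) ** 2 for s in scores) / len(scores)
    std = var ** 0.5
    # Each alternative's advantage measures how much better or worse it is
    # than the average alternative considered at that same step.
    return [(s - mean) / (std + eps) for s in scores]
```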
- This structure ensures that the GRPO trainer receives a rich dataset where each group of `G` items represents directly comparable choices from a common decision point (state \(s_t\)), each with a meaningful reward signal attached. The environment's scoring and selection mechanism ensures that the data generated is not only diverse but also reflects locally optimal decision-making. This is particularly handy for long sequences because it allows the model to learn from nuanced differences in multi-step "thinking" paths that all originate from the same immediate context. Greedy selection is faster for the training loop, and probably sufficient for a simple environment like Blackjack, but as mentioned previously you could (somewhat) easily use MCTS or beam search instead to explore more complex environments thoroughly.
+ This structure ensures that the GRPO trainer receives a rich dataset where each group of `G` items represents directly comparable choices from a common decision point (state \(s_t\)), each with a meaningful reward signal attached. The environment's scoring and selection mechanism ensures that the data generated is not only diverse but also reflects locally optimal decision-making. This is particularly handy for long sequences because it allows the model to learn from nuanced differences in multi-step "thinking" paths that all originate from the same immediate context. Greedy selection is faster for the training loop, and probably sufficient for a simple environment like Blackjack, but as mentioned previously you could (somewhat) easily upgrade to beam search or MCTS instead to explore more complex environments thoroughly.
**Why a naive approach of chunking independent long rollouts fails for GRPO:**
@@ -86,20 +86,20 @@ Imagine having `G` complete, very long, independent rollouts. If we simply chopp
* `rollout_2_chunk_k` starts at a completely different state \(S_{2,k}\).
The "advantages" or scores calculated for these chunks would not be comparable in the way GRPO intends, as they don't represent alternative choices from a common decision point. No Bueno! The `blackjack_env_thinking` design ensures this common-state origin for each group of `G` alternatives it sends to the trainer.
- ## Contrast with Simpler, Shorter Episodes (`blackjack_env_no_thinking`)
-
- In the `blackjack_env_no_thinking` environment:
- * Episodes are very short (a few turns, no long thinking blocks).
- * The entire sequence of (observation, action, LLM response) usually fits within the model's `seq_len`.
- * `collect_trajectory` returns a single `ScoredDataItem` representing the full episode. The "score" is simply the final game outcome (e.g., +1 for a win).
- * The trainer can then process these entire episodes, for example, by calculating Monte Carlo returns for each step or using other standard RL techniques. The complexity of per-step alternative generation for windowing and local value estimation is not primarily needed for fitting within `seq_len`.
- ## Local Exhaustive Exploration and Value Estimation
+ ## Value Estimation via Local Exhaustive Exploration
The `blackjack_env_thinking` strategy of generating `G` alternatives, simulating each for one step, and then using `_estimate_value` to predict future outcomes allows for a *local exhaustive search* around the current state. This is feasible here because:
1. Blackjack step simulation is computationally cheap.
2. `G` is a manageable number (e.g., 16 to 32).
3. An accurate value function \(V(s)\) can be reasonably implemented for a deterministic, small-state-space game like Blackjack (the `_get_v_star_recursive` in `_estimate_value` gives an exact calculation).
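In the same spirit (though NOT the actual `_get_v_star_recursive`, which also has to average over Blackjack's random card draws), exhaustive value estimation on a toy deterministic game tree can be sketched as:

```python
# Toy illustration of exhaustive value estimation; NOT the repo's Blackjack code.
# transitions[state][action] = (reward, next_state), with next_state None
# marking a terminal transition.
def v_star(state, transitions):
    if state not in transitions:  # terminal / unknown state has value 0
        return 0.0
    # V*(s) = max_a [ R(s, a) + V*(s') ]  (gamma = 1, as in the text)
    return max(reward + (v_star(nxt, transitions) if nxt is not None else 0.0)
               for reward, nxt in transitions[state].values())
```

Because every reachable state and action is enumerated, this only works when the state and action spaces are small, which is exactly the condition listed above.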
- If an accurate value function were not available (e.g., in a more complex environment or if it needed to be learned by a separate model), the quality of the greedy rollouts and the calculated advantages would depend on the accuracy of this learned value estimate or something like VinePPO's Monte Carlo rollouts to get a similar read on future rewards.
+ If an accurate value function were not available (e.g., in a more complex environment, or if it needed to be learned by a separate model), the quality of the greedy rollouts and the calculated advantages would depend on the accuracy of this learned value estimate, or on something like VinePPO's Monte Carlo rollouts to get a similar estimate of future rewards.
+ ## Contrast with Simpler, Shorter Episodes (`blackjack_env_no_thinking`)
+
+ In the `blackjack_env_no_thinking` environment:
+ * Episodes are very short (no long thinking blocks!).
+ * The entire sequence of (observation, action, LLM response) usually fits within the model's `seq_len`. Blackjack is at most a few turns, so this is fine if you just want to train on actions rather than on additional long chains of thought.
+ * `collect_trajectory` returns a single `ScoredDataItem` representing the full episode. The "score" is simply the final game outcome (e.g., +1 for a win) plus some bonuses for formatting and correct tool calling.
+ * The trainer can then process these entire episodes using the normal GRPO method (i.e., we're just sending the full alternative trajectories and their scores to be compared, similar to the single-step bandit problems people commonly use for RLVR). The complexity of per-step alternative generation for windowing and local value estimation isn't needed to fit within `seq_len`.