mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00

README changes

parent 91afc9e46e
commit 4a7da8049f
2 changed files with 12 additions and 13 deletions
@@ -48,7 +48,7 @@ After `pip install -e .` from the repository root, you can launch with either:

 |---------|---------------|
 | **Advantage** | How much better/worse than average a response was |
 | **Importance Sampling** | Corrects for policy drift during training |
-| **Rollout Logprobs** | Provide `pi_old` needed for importance sampling ratios |
+| **Rollout Logprobs** | Token-level `inference_logprobs` captured during rollout and used in ratio computation |
 | **Clipping** | Limits update magnitude for stability |
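The **Advantage** row in the hunk above ("how much better/worse than average a response was") can be illustrated with a group-relative normalization in the GRPO style. This is a hedged sketch, not the trainer's actual code; the names `group_advantages` and `rewards` are hypothetical, and dividing by the group standard deviation is an assumption about the exact normalization used.

```python
from statistics import mean, pstdev

def group_advantages(rewards, eps=1e-6):
    """Score each response by how much better/worse than the group average it was.

    Illustrative sketch only: rewards for a group of responses to the same
    prompt are centered on the group mean and scaled by the group std.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]
```

Responses that beat the group average get positive advantage; below-average responses get negative advantage, so the group itself serves as the baseline.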
@@ -285,28 +285,24 @@ VLLM_ENABLE_SHARED_WEIGHTS=1 python -m example_trainer.run \

 ### 1. Use `--openai.server_type vllm` for Training

-**CRITICAL:** The atropos environment MUST use `server_type=vllm` to get logprobs for proper GRPO training.
-
-Only `server_type=vllm` calls the `/generate` endpoint which returns token-level logprobs. These logprobs serve as the rollout policy (`pi_old`) for importance sampling in GRPO.
+For this example trainer implementation, set `--openai.server_type vllm` so the
+environment uses the `/generate` path and includes token-level
+`inference_logprobs` in the trajectory payload consumed by the trainer.

 ```bash
 # CORRECT - gets logprobs for training (REQUIRED!)
 --openai.server_type vllm

-# WRONG for training - no logprobs, training will FAIL
+# WRONG for this trainer path - missing rollout inference_logprobs
 --openai.server_type openai
 ```

 **What happens without logprobs:**
-- The trainer will raise an error: "GRPO requires inference_logprobs for importance sampling!"
+- Without rollout logprobs, GRPO updates are unsafe and training is aborted.

 **How logprobs flow through the system:**
 1. Environment calls vLLM `/generate` with `logprobs=true`
 2. vLLM returns token-level logprobs for each generated token
 3. Environment embeds these in trajectory data sent to API
 4. Trainer extracts and aligns logprobs with training labels
-5. GRPO loss uses logprobs as π_old for importance sampling ratio
+5. GRPO loss uses these rollout logprobs in importance-ratio terms

 ### 2. Clipping Is Essential
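Step 5 of the flow above, together with the clipping it feeds into, can be sketched per token in the PPO/GRPO style. This is a minimal illustration under assumed names (`clipped_token_loss`, `current_lp`, `rollout_lp` are hypothetical, not the trainer's actual API): the rollout logprob plays the role of the old policy in the importance ratio, and the ratio is clipped before multiplying the advantage.

```python
import math

def clipped_token_loss(current_lp, rollout_lp, advantage, eps=0.2):
    """Illustrative clipped importance-ratio loss for one token (PPO/GRPO style).

    current_lp:  logprob of the token under the policy being trained
    rollout_lp:  logprob captured at generation time (the rollout policy)
    """
    ratio = math.exp(current_lp - rollout_lp)        # pi_new / pi_old
    clipped = max(1.0 - eps, min(1.0 + eps, ratio))  # clip(ratio, 1-eps, 1+eps)
    # Pessimistic (min) surrogate, negated so that lower loss = better policy
    return -min(ratio * advantage, clipped * advantage)
```

When the policy has not drifted (`current_lp == rollout_lp`), the ratio is 1 and the loss reduces to minus the advantage; large drifts are capped at `1 ± eps`.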
@@ -908,12 +904,12 @@ If your model has `N` layers:
 - Apply temperature scaling (from data)
 - Compute log probabilities per token

-2. Reference Policy (π_old):
+2. Rollout Logprobs:
 - Extract from inference_logprobs (from vLLM at generation time)
 - Already aligned with labels by data.py

 3. Importance Sampling:
-- log_ratio = log π_new(a|s) - log π_old(a|s)
+- log_ratio = current_logprob - rollout_inference_logprob
 - ratio = exp(log_ratio)
 - Clipped ratio = clip(ratio, 1-ε, 1+ε)
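The "apply temperature scaling, then compute log probabilities per token" step in the hunk above can be sketched as a temperature-scaled log-softmax evaluated at the sampled token. A minimal illustration under assumed names (`token_logprob`, `logits`, `token_id` are hypothetical):

```python
import math

def token_logprob(logits, token_id, temperature=1.0):
    """Illustrative per-token log probability under temperature scaling.

    Scales logits by 1/temperature, then takes log-softmax at the sampled token.
    """
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max before exponentiating, for stability
    log_z = m + math.log(sum(math.exp(s - m) for s in scaled))
    return scaled[token_id] - log_z
```

The rollout side of the ratio needs no such computation: those logprobs were already captured by vLLM at generation time and aligned with the labels.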
@@ -931,3 +927,6 @@ If your model has `N` layers:
 - clipped_fraction: % of tokens clipped
 - alignment/* : Token-level logprob alignment (verifies weight sharing)
 ```
+
+For algorithm background and design tradeoffs, see:
+- https://fengyao.notion.site/off-policy-rl
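The `clipped_fraction` metric listed in the hunk above (the share of tokens whose importance ratio was clipped) can be computed as in this sketch; the function name and signature are hypothetical, not the trainer's actual code:

```python
def clipped_fraction(ratios, eps=0.2):
    """Illustrative: share of tokens whose importance ratio fell outside [1-eps, 1+eps]."""
    clipped = sum(1 for r in ratios if r < 1.0 - eps or r > 1.0 + eps)
    return clipped / max(len(ratios), 1)
```

A rising clipped fraction is a useful signal that the training policy is drifting far from the rollout policy.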