README changes

This commit is contained in:
Jai Suphavadeeprasit 2026-03-02 09:44:06 -05:00
parent 91afc9e46e
commit 4a7da8049f
2 changed files with 12 additions and 13 deletions

@@ -48,7 +48,7 @@ After `pip install -e .` from the repository root, you can launch with either:
|---------|---------------|
| **Advantage** | How much better or worse a response was than the group average |
| **Importance Sampling** | Corrects for policy drift during training |
| **Rollout Logprobs** | Provide `pi_old` needed for importance sampling ratios |
| **Rollout Logprobs** | Token-level `inference_logprobs` captured during rollout and used in ratio computation |
| **Clipping** | Limits update magnitude for stability |
@@ -285,28 +285,24 @@ VLLM_ENABLE_SHARED_WEIGHTS=1 python -m example_trainer.run \
### 1. Use `--openai.server_type vllm` for Training
**CRITICAL:** The atropos environment MUST use `server_type=vllm` to get logprobs for proper GRPO training.
Only `server_type=vllm` calls the `/generate` endpoint, which returns token-level logprobs. These logprobs serve as the rollout policy (`pi_old`) for importance-sampling ratios in GRPO.
For this example trainer implementation, set `--openai.server_type vllm` so the
environment uses the `/generate` path and includes token-level
`inference_logprobs` in the trajectory payload consumed by the trainer.
```bash
# CORRECT - gets logprobs for training (REQUIRED!)
--openai.server_type vllm
# WRONG for training - no logprobs, training will FAIL
# WRONG for this trainer path - missing rollout inference_logprobs
--openai.server_type openai
```
**What happens without logprobs:**
- The trainer will raise an error: "GRPO requires inference_logprobs for importance sampling!"
- Without rollout logprobs, GRPO updates are unsafe and training is aborted.
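The failure mode above can be sketched as a guard in the trainer. This is a minimal illustration, not the actual implementation; the function name and payload shape are assumptions, but the error message matches the one quoted above.

```python
def require_rollout_logprobs(trajectory):
    # Hypothetical guard mirroring the documented behavior: abort the
    # GRPO update when rollout logprobs are missing from the payload.
    logprobs = trajectory.get("inference_logprobs")
    if not logprobs:
        raise ValueError(
            "GRPO requires inference_logprobs for importance sampling!"
        )
    return logprobs
```

Failing fast here is deliberate: silently falling back to on-policy assumptions would make the importance-sampling correction wrong without any visible symptom.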
**How logprobs flow through the system:**
1. Environment calls vLLM `/generate` with `logprobs=true`
2. vLLM returns token-level logprobs for each generated token
3. Environment embeds these in trajectory data sent to API
4. Trainer extracts and aligns logprobs with training labels
5. GRPO loss uses logprobs as π_old for importance sampling ratio
5. GRPO loss uses these rollout logprobs in importance-ratio terms
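Steps 3–4 above (embedding rollout logprobs and aligning them with training labels) can be sketched as follows. The helper name, the `-100` ignore index, and the flat-list payload are assumptions for illustration, not the trainer's actual API.

```python
def align_rollout_logprobs(inference_logprobs, labels, ignore_index=-100):
    # Hypothetical alignment helper: keep rollout logprobs only at
    # positions whose label is actually trained on, so they line up
    # one-to-one with the tokens that enter the GRPO loss.
    assert len(inference_logprobs) == len(labels)
    return [lp for lp, y in zip(inference_logprobs, labels) if y != ignore_index]
```

After alignment, each surviving rollout logprob pairs with exactly one current-policy logprob, which is what the ratio computation in step 5 requires.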
### 2. Clipping Is Essential
@@ -908,12 +904,12 @@ If your model has `N` layers:
- Apply temperature scaling (from data)
- Compute log probabilities per token
2. Reference Policy (π_old):
2. Rollout Logprobs:
- Extract from inference_logprobs (from vLLM at generation time)
- Already aligned with labels by data.py
3. Importance Sampling:
- log_ratio = log π_new(a|s) - log π_old(a|s)
- log_ratio = current_logprob - rollout_inference_logprob
- ratio = exp(log_ratio)
- Clipped ratio = clip(ratio, 1-ε, 1+ε)
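The importance-sampling step above can be written out as a per-token sketch. This assumes flat lists of logprobs and advantages and a pessimistic `min` over the unclipped and clipped terms (as in PPO-style surrogates); the function name and signature are illustrative, not the trainer's API.

```python
import math

def clipped_ratio_terms(current_logprobs, rollout_logprobs, advantages, eps=0.2):
    # Per-token surrogate terms: ratio = exp(current - rollout),
    # clipped to [1 - eps, 1 + eps]; each term is min(ratio * A, clipped * A).
    terms = []
    for cur, old, adv in zip(current_logprobs, rollout_logprobs, advantages):
        ratio = math.exp(cur - old)
        clipped = max(1.0 - eps, min(1.0 + eps, ratio))
        terms.append(min(ratio * adv, clipped * adv))
    return terms
```

When the current and rollout logprobs agree, the ratio is 1 and clipping is inactive; the clip only binds once the policy has drifted more than `eps` from the rollout distribution.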
@@ -931,3 +927,6 @@ If your model has `N` layers:
- clipped_fraction: % of tokens clipped
- alignment/* : Token-level logprob alignment (verifies weight sharing)
```
For algorithm background and design tradeoffs, see:
- https://fengyao.notion.site/off-policy-rl