diff --git a/environments/community/mcp_tool_calling/GRPO_README.md b/environments/community/mcp_tool_calling/GRPO_README.md
index 911856b5..f176c75f 100644
--- a/environments/community/mcp_tool_calling/GRPO_README.md
+++ b/environments/community/mcp_tool_calling/GRPO_README.md
@@ -86,7 +86,7 @@ atropos-grpo \
 
 ## Objective Notes
 
-- GRPO uses rollout/inference logprobs (`pi_old`) for importance-ratio computation.
+- GRPO uses rollout `inference_logprobs` for importance-ratio computation.
 - The trainer currently uses clipped importance-ratio updates without a separate frozen-reference-model KL term.
 
 ## Outputs
diff --git a/example_trainer/README.md b/example_trainer/README.md
index 29820d66..1ad3944b 100644
--- a/example_trainer/README.md
+++ b/example_trainer/README.md
@@ -48,7 +48,7 @@ After `pip install -e .` from the repository root, you can launch with either:
 |---------|---------------|
 | **Advantage** | How much better/worse than average a response was |
 | **Importance Sampling** | Corrects for policy drift during training |
-| **Rollout Logprobs** | Provide `pi_old` needed for importance sampling ratios |
+| **Rollout Logprobs** | Token-level `inference_logprobs` captured during rollout and used in ratio computation |
 | **Clipping** | Limits update magnitude for stability |
 
@@ -285,28 +285,24 @@ VLLM_ENABLE_SHARED_WEIGHTS=1 python -m example_trainer.run \
 
 ### 1. Use `--openai.server_type vllm` for Training
 
-**CRITICAL:** The atropos environment MUST use `server_type=vllm` to get logprobs for proper GRPO training.
-
-Only `server_type=vllm` calls the `/generate` endpoint which returns token-level logprobs. These logprobs serve as the rollout policy (`pi_old`) for importance sampling in GRPO.
+For this example trainer implementation, set `--openai.server_type vllm` so the
+environment uses the `/generate` path and includes token-level
+`inference_logprobs` in the trajectory payload consumed by the trainer.
 
 ```bash
 # CORRECT - gets logprobs for training (REQUIRED!)
 --openai.server_type vllm
 
-# WRONG for training - no logprobs, training will FAIL
+# WRONG for this trainer path - missing rollout inference_logprobs
 --openai.server_type openai
 ```
 
-**What happens without logprobs:**
-- The trainer will raise an error: "GRPO requires inference_logprobs for importance sampling!"
-- Without rollout logprobs, GRPO updates are unsafe and training is aborted.
-
 **How logprobs flow through the system:**
 1. Environment calls vLLM `/generate` with `logprobs=true`
 2. vLLM returns token-level logprobs for each generated token
 3. Environment embeds these in trajectory data sent to API
 4. Trainer extracts and aligns logprobs with training labels
-5. GRPO loss uses logprobs as π_old for importance sampling ratio
+5. GRPO loss uses these rollout logprobs in importance-ratio terms
 
 ### 2. Clipping Is Essential
 
@@ -908,12 +904,12 @@ If your model has `N` layers:
    - Apply temperature scaling (from data)
    - Compute log probabilities per token
 
-2. Reference Policy (π_old):
+2. Rollout Logprobs:
   - Extract from inference_logprobs (from vLLM at generation time)
   - Already aligned with labels by data.py
 
 3. Importance Sampling:
-   - log_ratio = log π_new(a|s) - log π_old(a|s)
+   - log_ratio = current_logprob - rollout_inference_logprob
   - ratio = exp(log_ratio)
   - Clipped ratio = clip(ratio, 1-ε, 1+ε)
 
@@ -931,3 +927,6 @@ If your model has `N` layers:
 - clipped_fraction: % of tokens clipped
 - alignment/* : Token-level logprob alignment (verifies weight sharing)
 ```
+
+For algorithm background and design tradeoffs, see:
+- https://fengyao.notion.site/off-policy-rl
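
For reference, here is a minimal, hypothetical sketch of the clipped importance-ratio update the patched text describes: the per-token ratio is formed from current-policy logprobs and the rollout `inference_logprobs`, then clipped, with no separate frozen-reference-model KL term. The function and argument names (`grpo_token_loss`, `current_logprobs`, `rollout_logprobs`, `advantages`) are illustrative assumptions, not the trainer's actual API:

```python
import torch


def grpo_token_loss(
    current_logprobs: torch.Tensor,  # log-probs from the policy being trained, shape [T]
    rollout_logprobs: torch.Tensor,  # token-level inference_logprobs captured at rollout, shape [T]
    advantages: torch.Tensor,        # group-relative advantages, broadcast per token, shape [T]
    clip_eps: float = 0.2,
):
    # log_ratio = current_logprob - rollout_inference_logprob
    log_ratio = current_logprobs - rollout_logprobs
    ratio = torch.exp(log_ratio)

    # Clipped surrogate: take the pessimistic minimum of the unclipped
    # and clipped terms, as in PPO-style objectives.
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    loss = -torch.min(unclipped, clipped).mean()

    # clipped_fraction: share of tokens where the clip was active
    # (the metric reported in the docs above).
    clipped_fraction = (
        (ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)
    ).float().mean()
    return loss, clipped_fraction
```

The clip bounds each token's contribution when the trained policy drifts from the rollout policy, which is what stands in for a frozen-reference KL penalty in this trainer.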