mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
parent d07ab3e3ce
commit 5cfd1929f1
19 changed files with 708 additions and 452 deletions
@@ -18,21 +18,21 @@ example_trainer/
├── vllm_manager.py # vLLM process management
├── trainers.py # Training mode implementations
├── vllm_api_server.py # Custom vLLM server (streamlined for training)
-├── vllm_patching/ # CUDA IPC patches for weight sharing basically overriding standard vllm for this
-│ └── patched_gpu_runner.py
-└── scripts/ # Helper scripts
+├── vllm_patching/ # CUDA IPC patches for weight sharing basically overriding standard vllm for this
+│ └── patched_gpu_runner.py
+└── scripts/ # Helper scripts
├── test_lora_mode.sh
└── test_single_copy_mode.sh
```


-GRPO Training Loop
+GRPO Training Loop

-1. Generate multiple responses to the same prompt
-2. Score each response (reward)
-3. Compute ADVANTAGE = reward - mean(rewards)
-4. Train: increase probability of above-average responses
-decrease probability of below-average responses
+1. Generate multiple responses to the same prompt
+2. Score each response (reward)
+3. Compute ADVANTAGE = reward - mean(rewards)
+4. Train: increase probability of above-average responses
+decrease probability of below-average responses
```
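Step 3 of the loop above can be sketched in a few lines of Python. This is a minimal illustration and not the trainer's actual code: `group_advantages` is a hypothetical helper name, and the rewards are made-up values for one prompt with four sampled responses.

```python
from statistics import mean


def group_advantages(rewards):
    """Return reward - mean(rewards) for each response in one prompt group.

    Hypothetical helper for illustration; not part of example_trainer.
    """
    baseline = mean(rewards)  # the group's average reward
    return [r - baseline for r in rewards]


# Made-up rewards for four responses to the same prompt.
rewards = [1.0, 0.0, 0.5, 0.5]
advantages = group_advantages(rewards)

for reward, adv in zip(rewards, advantages):
    # Positive advantage: raise the response's probability; negative: lower it.
    direction = "increase" if adv > 0 else "decrease" if adv < 0 else "leave"
    print(f"reward={reward:.1f} advantage={adv:+.1f} -> {direction}")
```

Because the baseline is the group mean, the advantages within one group sum to zero, so a response is only reinforced relative to the other responses to the same prompt.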

### Key Concepts

@@ -330,7 +330,7 @@ The trainer creates **views** into vLLM's fused tensors:

# Get sizes from model config
q_size = num_heads * head_dim # e.g., 4096
-k_size = num_kv_heads * head_dim # e.g., 1024
+k_size = num_kv_heads * head_dim # e.g., 1024
v_size = num_kv_heads * head_dim # e.g., 1024

# Create views (no copy!)
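The offsets behind such views follow directly from the three sizes. The sketch below is a stand-in that uses only the Python standard library (`array` plus `memoryview`) rather than real vLLM tensors. `qkv_slices` is a hypothetical helper, the 32 / 8 / 128 head configuration is an assumed example chosen to reproduce the 4096 / 1024 / 1024 sizes in the comments, and the Q, then K, then V ordering of the fused buffer is also an assumption.

```python
from array import array


def qkv_slices(num_heads, num_kv_heads, head_dim):
    """Return slices for Q, K and V inside a fused QKV dimension.

    Hypothetical helper; assumes the fused layout is [Q | K | V].
    """
    q_size = num_heads * head_dim
    k_size = num_kv_heads * head_dim
    v_size = num_kv_heads * head_dim
    q = slice(0, q_size)
    k = slice(q_size, q_size + k_size)
    v = slice(q_size + k_size, q_size + k_size + v_size)
    return q, k, v


# Assumed config: 32 query heads, 8 KV heads, head_dim 128.
q_sl, k_sl, v_sl = qkv_slices(num_heads=32, num_kv_heads=8, head_dim=128)

# A flat float buffer stands in for one row of the fused tensor.
fused = array("f", [0.0] * v_sl.stop)
buf = memoryview(fused)

# Slicing a memoryview creates views over the same memory, not copies.
q_view, k_view, v_view = buf[q_sl], buf[k_sl], buf[v_sl]

# Writing through the K view changes the fused buffer in place.
k_view[0] = 1.5
print(len(q_view), len(k_view), len(v_view), fused[k_sl.start])
```

The same offset arithmetic applies when the buffer is a framework tensor and the slices are taken along the fused output dimension; only the container type differs.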
@@ -542,4 +542,3 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands
| `vllm_api_server.py` | Streamlined vLLM server for training |
| `vllm_manager.py` | vLLM process lifecycle management |
| `checkpointing.py` | Save/load checkpoints and adapters |
-