Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-30 17:40:36 +00:00
Commit d0b097974b (parent bab3d85d85): python versioning problems
4 changed files with 5 additions and 105 deletions
@@ -123,7 +123,7 @@ def pad_data_to_good_offset(
        # IMPORTANT: inference_logprobs is ALREADY ALIGNED with tokens/masks:
        # - 1.0 for prompt tokens (masked positions)
        # - actual negative logprobs for generated tokens
        # We just need to pad to match the sequence length, no realignment needed!
        # We just need to pad to match the sequence length
        if extract_inference_logprobs and "inference_logprobs" in item:
            if i < len(item["inference_logprobs"]):
                raw_logprobs = np.array(item["inference_logprobs"][i], dtype=np.float32)
@@ -140,10 +140,10 @@ def pad_data_to_good_offset(
                # Shift by 1 to match causal label shift
                inference_logprobs_padded.append(padded_logprobs[1:])
            else:
                # No logprobs for this sample, use 1.0 (masked placeholder)
                # No logprobs for this sample, use 1.0
                inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32))
        elif extract_inference_logprobs:
            # No inference_logprobs in item, use 1.0 (masked placeholder)
            # No inference_logprobs in item, use 1.0
            inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32))

        # Extract temperature (priority: override > generation_params > group_overrides > 1.0)
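For context, a minimal sketch of the pad-and-shift step these comments describe; the helper name and the use of token_setup_len and the 1.0 placeholder are assumptions taken from the surrounding diff, not the repo's exact code:

    import numpy as np

    def pad_inference_logprobs(raw_logprobs: np.ndarray, token_setup_len: int) -> np.ndarray:
        # Pad with 1.0 (the masked-position placeholder) out to the padded
        # sequence length, then drop the first entry to mirror the causal
        # label shift applied to tokens/masks.
        padded = np.full(token_setup_len, 1.0, dtype=np.float32)
        padded[: len(raw_logprobs)] = raw_logprobs
        return padded[1:]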
@@ -1,19 +1,11 @@
#!/usr/bin/env python3
"""
GRPO Trainer - Multi-Mode Entry Point

Supports three training modes:
- none (legacy): Periodic checkpoint saves + vLLM restarts
- shared_vllm: Single-copy mode with CUDA IPC weight sharing
- lora_only: LoRA adapter training

For the unified single-command experience (shared_vllm + auto vLLM launch),
use run.py instead:
    python example_trainer/run.py --model Qwen/Qwen3-4B --training-steps 20

This script requires vLLM to be running separately (except for legacy mode
which manages vLLM internally).

Usage:
    # Legacy mode (manages vLLM internally)
    python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct
@@ -23,89 +23,6 @@ from .config import TrainingConfig
# Global storage for logprob alignment stats
_logprob_alignment_stats: Dict[str, float] = {}

# Global storage for weight verification
_weight_snapshot: Dict[str, float] = {}


def verify_vllm_sees_updates(model: torch.nn.Module, vllm_port: int, step: int) -> bool:
    """
    Verify that vLLM actually sees weight updates by corrupting a weight
    and checking if vLLM's output changes.

    Returns True if vLLM sees updates, False otherwise.
    """
    import requests

    try:
        # Find embedding layer
        embed_param = None
        for name, param in model.named_parameters():
            if "embed_tokens" in name:
                embed_param = param
                break

        if embed_param is None:
            return True  # Can't verify, assume OK

        test_prompt = "Hello"
        vllm_url = f"http://localhost:{vllm_port}"

        # Get baseline
        r1 = requests.post(
            f"{vllm_url}/generate",
            json={"prompt": test_prompt, "max_tokens": 3, "temperature": 0.0},
            timeout=10,
        )
        baseline = r1.json().get("text", [""])[0] if r1.status_code == 200 else None

        if baseline is None:
            return True  # Can't verify

        # Corrupt weight
        original = embed_param.data[0, 0].clone()
        embed_param.data[0, 0] = 9999.0

        # Query vLLM
        r2 = requests.post(
            f"{vllm_url}/generate",
            json={"prompt": test_prompt, "max_tokens": 3, "temperature": 0.0},
            timeout=10,
        )
        corrupted = r2.json().get("text", [""])[0] if r2.status_code == 200 else baseline

        # Restore
        embed_param.data[0, 0] = original

        # Check if output changed
        sharing_works = (corrupted != baseline)

        if not sharing_works and step > 0:
            print(f" [WARN] Step {step}: vLLM may not see weight updates!")

        return sharing_works

    except Exception:
        return True  # Can't verify, assume OK


def snapshot_weights(model: torch.nn.Module) -> Dict[str, float]:
    """Take a snapshot of sample weight values for comparison."""
    snapshot = {}
    for name, param in model.named_parameters():
        if any(x in name for x in ["layers.0.", "layers.10.", "embed_tokens", "lm_head"]):
            snapshot[name] = param.data.flatten()[0].item()
    return snapshot


def compare_weight_snapshots(old: Dict[str, float], new: Dict[str, float]) -> Dict[str, float]:
    """Compare two weight snapshots and return differences."""
    diffs = {}
    for name in old:
        if name in new:
            diffs[name] = abs(new[name] - old[name])
    return diffs


def setup_wandb(config: TrainingConfig) -> bool:
    """
    Initialize Weights & Biases logging if enabled.
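For reference, a hypothetical use of the two removed snapshot helpers around an optimizer step (the model and optimizer variables here are illustrative, not from the repo):

    before = snapshot_weights(model)
    optimizer.step()
    after = snapshot_weights(model)

    diffs = compare_weight_snapshots(before, after)
    if diffs and all(d == 0.0 for d in diffs.values()):
        print("[WARN] sampled weights did not change after optimizer step")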
@@ -161,7 +78,7 @@ def compute_grpo_loss(
    Compute GRPO (Group Relative Policy Optimization) loss for a single micro-batch.

    This implements proper GRPO/PPO with:
    - Importance sampling ratio: π(a|s) / π_old(a|s)
    - Importance sampling ratio: policy(a|s) / policy_old(a|s)
    - PPO-style clipping to prevent large updates
    - KL penalty to prevent reward hacking/policy collapse
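As a rough illustration of the three bullets above (a sketch only; the tensor names, clip range, KL coefficient, and the k3-style KL estimate are assumptions, not the repo's exact compute_grpo_loss):

    import torch

    def grpo_loss_sketch(logprobs, old_logprobs, ref_logprobs, advantages,
                         clip_eps: float = 0.2, kl_coef: float = 0.01):
        # Importance sampling ratio: policy(a|s) / policy_old(a|s), in log space
        ratio = torch.exp(logprobs - old_logprobs)
        # PPO-style clipping to prevent large updates
        unclipped = ratio * advantages
        clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
        policy_loss = -torch.min(unclipped, clipped).mean()
        # KL penalty toward a reference policy to discourage reward hacking
        log_ratio_ref = ref_logprobs - logprobs
        kl = (torch.exp(log_ratio_ref) - log_ratio_ref - 1.0).mean()
        return policy_loss + kl_coef * kl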
@@ -189,7 +106,7 @@ def compute_grpo_loss(
    outputs = model(tokens)
    logits = outputs.logits

    # Temperature scaling for training
    # Temperature scaling for training otherwise likely ratio is off
    t = temperatures.to(logits.device, logits.dtype)
    t = torch.where(t <= 0, torch.ones_like(t), t)
    scaled_logits = logits / t
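A short sketch of why the new comment ties temperature scaling to the ratio: the per-token logprobs entering the importance ratio should come from temperature-scaled logits so they match the sampling distribution used at inference time. The helper below is illustrative only; names beyond scaled_logits and tokens are assumptions:

    import torch
    import torch.nn.functional as F

    def gather_token_logprobs(scaled_logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
        # scaled_logits: [batch, seq, vocab], already divided by temperature;
        # tokens: [batch, seq] of token ids (long dtype).
        # Returns logprobs of each next token under the temperature-scaled
        # policy, aligned with the causal label shift.
        log_probs = F.log_softmax(scaled_logits[:, :-1], dim=-1)
        return log_probs.gather(-1, tokens[:, 1:].unsqueeze(-1)).squeeze(-1)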
@@ -47,10 +47,6 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, List, Optional

# =============================================================================
# CRITICAL: Set up multiprocessing and vLLM engine BEFORE any CUDA imports
# =============================================================================

# Default to v0 engine to avoid CUDA fork issues with v1 engine
# Users can override with VLLM_USE_V1=1 if needed
os.environ.setdefault("VLLM_USE_V1", "0")
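The ordering constraint in the banner above, sketched out (illustrative only; the exact environment variables and imports in the repo differ):

    import os

    # Environment knobs must be in place before torch / vLLM are imported,
    # since they are consulted while the engine is being set up.
    os.environ.setdefault("VLLM_USE_V1", "0")

    import torch  # noqa: E402
    import vllm   # noqa: E402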
@@ -168,11 +164,6 @@ except ImportError:

logger = init_logger("vllm.entrypoints.api_server")


# =============================================================================
# Global State
# =============================================================================

app = FastAPI()
engine: Optional[AsyncLLM] = None