python versioning problems

This commit is contained in:
Jai Suphavadeeprasit 2026-02-03 11:23:53 -05:00
parent bab3d85d85
commit d0b097974b
4 changed files with 5 additions and 105 deletions

View file

@ -123,7 +123,7 @@ def pad_data_to_good_offset(
# IMPORTANT: inference_logprobs is ALREADY ALIGNED with tokens/masks:
# - 1.0 for prompt tokens (masked positions)
# - actual negative logprobs for generated tokens
# We just need to pad to match the sequence length, no realignment needed!
# We just need to pad to match the sequence length
if extract_inference_logprobs and "inference_logprobs" in item:
if i < len(item["inference_logprobs"]):
raw_logprobs = np.array(item["inference_logprobs"][i], dtype=np.float32)
@ -140,10 +140,10 @@ def pad_data_to_good_offset(
# Shift by 1 to match causal label shift
inference_logprobs_padded.append(padded_logprobs[1:])
else:
# No logprobs for this sample, use 1.0 (masked placeholder)
# No logprobs for this sample, use 1.0
inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32))
elif extract_inference_logprobs:
# No inference_logprobs in item, use 1.0 (masked placeholder)
# No inference_logprobs in item, use 1.0
inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32))
# Extract temperature (priority: override > generation_params > group_overrides > 1.0)

View file

@ -1,19 +1,11 @@
#!/usr/bin/env python3
"""
GRPO Trainer - Multi-Mode Entry Point
Supports three training modes:
- none (legacy): Periodic checkpoint saves + vLLM restarts
- shared_vllm: Single-copy mode with CUDA IPC weight sharing
- lora_only: LoRA adapter training
For the unified single-command experience (shared_vllm + auto vLLM launch),
use run.py instead:
python example_trainer/run.py --model Qwen/Qwen3-4B --training-steps 20
This script requires vLLM to be running separately (except for legacy mode
which manages vLLM internally).
Usage:
# Legacy mode (manages vLLM internally)
python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct

View file

@ -23,89 +23,6 @@ from .config import TrainingConfig
# Global storage for logprob alignment stats
_logprob_alignment_stats: Dict[str, float] = {}
# Global storage for weight verification
_weight_snapshot: Dict[str, float] = {}
def verify_vllm_sees_updates(model: torch.nn.Module, vllm_port: int, step: int) -> bool:
"""
Verify that vLLM actually sees weight updates by corrupting a weight
and checking if vLLM's output changes.
Returns True if vLLM sees updates, False otherwise.
"""
import requests
try:
# Find embedding layer
embed_param = None
for name, param in model.named_parameters():
if "embed_tokens" in name:
embed_param = param
break
if embed_param is None:
return True # Can't verify, assume OK
test_prompt = "Hello"
vllm_url = f"http://localhost:{vllm_port}"
# Get baseline
r1 = requests.post(
f"{vllm_url}/generate",
json={"prompt": test_prompt, "max_tokens": 3, "temperature": 0.0},
timeout=10,
)
baseline = r1.json().get("text", [""])[0] if r1.status_code == 200 else None
if baseline is None:
return True # Can't verify
# Corrupt weight
original = embed_param.data[0, 0].clone()
embed_param.data[0, 0] = 9999.0
# Query vLLM
r2 = requests.post(
f"{vllm_url}/generate",
json={"prompt": test_prompt, "max_tokens": 3, "temperature": 0.0},
timeout=10,
)
corrupted = r2.json().get("text", [""])[0] if r2.status_code == 200 else baseline
# Restore
embed_param.data[0, 0] = original
# Check if output changed
sharing_works = (corrupted != baseline)
if not sharing_works and step > 0:
print(f" [WARN] Step {step}: vLLM may not see weight updates!")
return sharing_works
except Exception:
return True # Can't verify, assume OK
def snapshot_weights(model: torch.nn.Module) -> Dict[str, float]:
"""Take a snapshot of sample weight values for comparison."""
snapshot = {}
for name, param in model.named_parameters():
if any(x in name for x in ["layers.0.", "layers.10.", "embed_tokens", "lm_head"]):
snapshot[name] = param.data.flatten()[0].item()
return snapshot
def compare_weight_snapshots(old: Dict[str, float], new: Dict[str, float]) -> Dict[str, float]:
"""Compare two weight snapshots and return differences."""
diffs = {}
for name in old:
if name in new:
diffs[name] = abs(new[name] - old[name])
return diffs
def setup_wandb(config: TrainingConfig) -> bool:
"""
Initialize Weights & Biases logging if enabled.
@ -161,7 +78,7 @@ def compute_grpo_loss(
Compute GRPO (Group Relative Policy Optimization) loss for a single micro-batch.
This implements proper GRPO/PPO with:
- Importance sampling ratio: π(a|s) / π_old(a|s)
- Importance sampling ratio: policy(a|s) / policy_old(a|s)
- PPO-style clipping to prevent large updates
- KL penalty to prevent reward hacking/policy collapse
@ -189,7 +106,7 @@ def compute_grpo_loss(
outputs = model(tokens)
logits = outputs.logits
# Temperature scaling for training
# Temperature scaling for training otherwise likely ratio is off
t = temperatures.to(logits.device, logits.dtype)
t = torch.where(t <= 0, torch.ones_like(t), t)
scaled_logits = logits / t

View file

@ -47,10 +47,6 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, List, Optional
# =============================================================================
# CRITICAL: Set up multiprocessing and vLLM engine BEFORE any CUDA imports
# =============================================================================
# Default to v0 engine to avoid CUDA fork issues with v1 engine
# Users can override with VLLM_USE_V1=1 if needed
os.environ.setdefault("VLLM_USE_V1", "0")
@ -168,11 +164,6 @@ except ImportError:
logger = init_logger("vllm.entrypoints.api_server")
# =============================================================================
# Global State
# =============================================================================
app = FastAPI()
engine: Optional[AsyncLLM] = None