python versioning problems

This commit is contained in:
Jai Suphavadeeprasit 2026-02-03 11:23:53 -05:00
parent bab3d85d85
commit d0b097974b
4 changed files with 5 additions and 105 deletions

View file

@ -23,89 +23,6 @@ from .config import TrainingConfig
# Global storage for logprob alignment stats (string key -> float value).
# NOTE(review): nothing in the visible part of this file reads or writes
# this dict — confirm where it is populated.
_logprob_alignment_stats: Dict[str, float] = {}
# Global storage for weight verification (string key -> float value).
# NOTE(review): presumably holds a result of snapshot_weights(), which has
# the same Dict[str, float] type — verify against callers.
_weight_snapshot: Dict[str, float] = {}
def verify_vllm_sees_updates(model: torch.nn.Module, vllm_port: int, step: int) -> bool:
    """
    Verify that vLLM actually sees weight updates by corrupting a weight
    and checking if vLLM's output changes.

    The probe temporarily overwrites element ``[0, 0]`` of the first parameter
    whose name contains ``"embed_tokens"`` and compares two greedy
    generations requested from ``http://localhost:{vllm_port}/generate``.
    The original value is always written back, even when the second request
    fails.

    Args:
        model: Model whose parameters are searched for the probe weight.
        vllm_port: Port of the local vLLM HTTP server.
        step: Current training step; the warning is only printed when
            ``step > 0``.

    Returns:
        True if vLLM sees updates, False otherwise. Also True whenever the
        check cannot be carried out (no embedding parameter found, baseline
        request not answered with HTTP 200, or any exception during the probe).
    """
    # Find embedding layer (first matching parameter wins).
    try:
        embed_param = next(
            (param for name, param in model.named_parameters() if "embed_tokens" in name),
            None,
        )
    except Exception:
        return True  # Can't verify, assume OK
    if embed_param is None:
        return True  # Can't verify, assume OK

    # Imported lazily, and outside the best-effort ``try`` below, so that a
    # missing ``requests`` install is reported instead of being read as "OK".
    import requests

    generate_url = f"http://localhost:{vllm_port}/generate"
    # temperature 0.0 -> greedy decoding, so any difference between the two
    # generations comes from the weight change and not from sampling.
    payload = {"prompt": "Hello", "max_tokens": 3, "temperature": 0.0}

    try:
        # Get baseline
        r1 = requests.post(generate_url, json=payload, timeout=10)
        baseline = r1.json().get("text", [""])[0] if r1.status_code == 200 else None
        if baseline is None:
            return True  # Can't verify

        # Corrupt weight; clone first so the saved value is not a view of
        # the element that is about to be overwritten.
        original = embed_param.data[0, 0].clone()
        try:
            embed_param.data[0, 0] = 9999.0
            # Query vLLM
            r2 = requests.post(generate_url, json=payload, timeout=10)
            # NOTE: a non-200 reply falls back to the baseline text and is
            # therefore reported as "output unchanged".
            corrupted = r2.json().get("text", [""])[0] if r2.status_code == 200 else baseline
        finally:
            # Restore unconditionally: a timeout or malformed reply must not
            # leave the training weights corrupted.
            embed_param.data[0, 0] = original

        # Check if output changed
        sharing_works = corrupted != baseline
        if not sharing_works and step > 0:
            print(f" [WARN] Step {step}: vLLM may not see weight updates!")
        return sharing_works
    except Exception:
        return True  # Can't verify, assume OK
def snapshot_weights(model: torch.nn.Module) -> Dict[str, float]:
    """Record one sample value per tracked parameter for later comparison.

    A parameter is tracked when its name contains any of the substrings
    ``"layers.0."``, ``"layers.10."``, ``"embed_tokens"`` or ``"lm_head"``.
    The recorded value is the first element of the flattened parameter,
    converted to a Python float.

    Args:
        model: Model whose named parameters are sampled.

    Returns:
        Mapping from parameter name to its first flattened element.
    """
    tracked_markers = ("layers.0.", "layers.10.", "embed_tokens", "lm_head")
    return {
        param_name: tensor.data.flatten()[0].item()
        for param_name, tensor in model.named_parameters()
        if any(marker in param_name for marker in tracked_markers)
    }
def compare_weight_snapshots(old: Dict[str, float], new: Dict[str, float]) -> Dict[str, float]:
    """Return the absolute change of every entry present in both snapshots.

    Args:
        old: Earlier snapshot (name -> value).
        new: Later snapshot (name -> value).

    Returns:
        Mapping from each name found in both ``old`` and ``new`` to
        ``abs(new[name] - old[name])``, in the iteration order of ``old``.
        Names missing from ``new`` are skipped.
    """
    return {
        key: abs(new[key] - previous)
        for key, previous in old.items()
        if key in new
    }
def setup_wandb(config: TrainingConfig) -> bool:
"""
Initialize Weights & Biases logging if enabled.
@ -161,7 +78,7 @@ def compute_grpo_loss(
Compute GRPO (Group Relative Policy Optimization) loss for a single micro-batch.
This implements proper GRPO/PPO with:
- Importance sampling ratio: π(a|s) / π_old(a|s)
- Importance sampling ratio: policy(a|s) / policy_old(a|s)
- PPO-style clipping to prevent large updates
- KL penalty to prevent reward hacking/policy collapse
@ -189,7 +106,7 @@ def compute_grpo_loss(
outputs = model(tokens)
logits = outputs.logits
# Temperature scaling for training
# Temperature scaling for training; without it the importance ratio is likely off
t = temperatures.to(logits.device, logits.dtype)
t = torch.where(t <= 0, torch.ones_like(t), t)
scaled_logits = logits / t