mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
python versioning problems
This commit is contained in:
parent
bab3d85d85
commit
d0b097974b
4 changed files with 5 additions and 105 deletions
|
|
@ -23,89 +23,6 @@ from .config import TrainingConfig
|
|||
# Global storage for logprob alignment stats
|
||||
_logprob_alignment_stats: Dict[str, float] = {}
|
||||
|
||||
# Global storage for weight verification
|
||||
_weight_snapshot: Dict[str, float] = {}
|
||||
|
||||
|
||||
def verify_vllm_sees_updates(model: torch.nn.Module, vllm_port: int, step: int) -> bool:
|
||||
"""
|
||||
Verify that vLLM actually sees weight updates by corrupting a weight
|
||||
and checking if vLLM's output changes.
|
||||
|
||||
Returns True if vLLM sees updates, False otherwise.
|
||||
"""
|
||||
import requests
|
||||
|
||||
try:
|
||||
# Find embedding layer
|
||||
embed_param = None
|
||||
for name, param in model.named_parameters():
|
||||
if "embed_tokens" in name:
|
||||
embed_param = param
|
||||
break
|
||||
|
||||
if embed_param is None:
|
||||
return True # Can't verify, assume OK
|
||||
|
||||
test_prompt = "Hello"
|
||||
vllm_url = f"http://localhost:{vllm_port}"
|
||||
|
||||
# Get baseline
|
||||
r1 = requests.post(
|
||||
f"{vllm_url}/generate",
|
||||
json={"prompt": test_prompt, "max_tokens": 3, "temperature": 0.0},
|
||||
timeout=10,
|
||||
)
|
||||
baseline = r1.json().get("text", [""])[0] if r1.status_code == 200 else None
|
||||
|
||||
if baseline is None:
|
||||
return True # Can't verify
|
||||
|
||||
# Corrupt weight
|
||||
original = embed_param.data[0, 0].clone()
|
||||
embed_param.data[0, 0] = 9999.0
|
||||
|
||||
# Query vLLM
|
||||
r2 = requests.post(
|
||||
f"{vllm_url}/generate",
|
||||
json={"prompt": test_prompt, "max_tokens": 3, "temperature": 0.0},
|
||||
timeout=10,
|
||||
)
|
||||
corrupted = r2.json().get("text", [""])[0] if r2.status_code == 200 else baseline
|
||||
|
||||
# Restore
|
||||
embed_param.data[0, 0] = original
|
||||
|
||||
# Check if output changed
|
||||
sharing_works = (corrupted != baseline)
|
||||
|
||||
if not sharing_works and step > 0:
|
||||
print(f" [WARN] Step {step}: vLLM may not see weight updates!")
|
||||
|
||||
return sharing_works
|
||||
|
||||
except Exception:
|
||||
return True # Can't verify, assume OK
|
||||
|
||||
|
||||
def snapshot_weights(model: torch.nn.Module) -> Dict[str, float]:
|
||||
"""Take a snapshot of sample weight values for comparison."""
|
||||
snapshot = {}
|
||||
for name, param in model.named_parameters():
|
||||
if any(x in name for x in ["layers.0.", "layers.10.", "embed_tokens", "lm_head"]):
|
||||
snapshot[name] = param.data.flatten()[0].item()
|
||||
return snapshot
|
||||
|
||||
|
||||
def compare_weight_snapshots(old: Dict[str, float], new: Dict[str, float]) -> Dict[str, float]:
|
||||
"""Compare two weight snapshots and return differences."""
|
||||
diffs = {}
|
||||
for name in old:
|
||||
if name in new:
|
||||
diffs[name] = abs(new[name] - old[name])
|
||||
return diffs
|
||||
|
||||
|
||||
def setup_wandb(config: TrainingConfig) -> bool:
|
||||
"""
|
||||
Initialize Weights & Biases logging if enabled.
|
||||
|
|
@ -161,7 +78,7 @@ def compute_grpo_loss(
|
|||
Compute GRPO (Group Relative Policy Optimization) loss for a single micro-batch.
|
||||
|
||||
This implements proper GRPO/PPO with:
|
||||
- Importance sampling ratio: π(a|s) / π_old(a|s)
|
||||
- Importance sampling ratio: policy(a|s) / policy_old(a|s)
|
||||
- PPO-style clipping to prevent large updates
|
||||
- KL penalty to prevent reward hacking/policy collapse
|
||||
|
||||
|
|
@ -189,7 +106,7 @@ def compute_grpo_loss(
|
|||
outputs = model(tokens)
|
||||
logits = outputs.logits
|
||||
|
||||
# Temperature scaling for training
|
||||
# Temperature scaling for training otherwise likely ratio is off
|
||||
t = temperatures.to(logits.device, logits.dtype)
|
||||
t = torch.where(t <= 0, torch.ones_like(t), t)
|
||||
scaled_logits = logits / t
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue