mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
readme updates
This commit is contained in:
parent
998c062709
commit
41cab9d52a
2 changed files with 47 additions and 15 deletions
|
|
@ -291,17 +291,19 @@ def train_shared_vllm(config: TrainingConfig):
|
|||
if trainer_logprobs and vllm_logprobs:
|
||||
for i, (vlp, tlp) in enumerate(zip(vllm_logprobs, trainer_logprobs)):
|
||||
diff = abs(vlp - tlp)
|
||||
status = "✓" if diff < 0.1 else "✗"
|
||||
status = "✓" if diff < 0.25 else "⚠" # 0.25 threshold accounts for impl differences
|
||||
print(f" Token {i}: vLLM={vlp:.4f}, Trainer={tlp:.4f}, diff={diff:.4f} {status}")
|
||||
|
||||
mean_diff = sum(abs(v-t) for v,t in zip(vllm_logprobs, trainer_logprobs)) / len(trainer_logprobs)
|
||||
print(f" Mean diff: {mean_diff:.4f}")
|
||||
|
||||
if mean_diff < 0.1:
|
||||
print(f" ✓ WEIGHTS ARE SHARED CORRECTLY!")
|
||||
if mean_diff < 0.05:
|
||||
print(f" ✓ PERFECT ALIGNMENT - weights shared and same compute path")
|
||||
elif mean_diff < 0.25:
|
||||
print(f" ✓ WEIGHTS ARE SHARED (diff {mean_diff:.2f} is due to different forward pass implementations)")
|
||||
print(f" vLLM uses Flash Attention, trainer uses HuggingFace - small diff is expected!")
|
||||
else:
|
||||
print(f" ✗ WEIGHTS ARE NOT SHARED - IPC attachment may have failed!")
|
||||
print(f" The trainer may have copies of weights, not shared memory.")
|
||||
print(f" ⚠ Large diff ({mean_diff:.2f}) - may indicate issue with weight sharing")
|
||||
else:
|
||||
print(f" vLLM request failed: {response.status_code}")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue