readme updates

This commit is contained in:
Jai Suphavadeeprasit 2026-01-27 17:21:43 -05:00
parent 41cab9d52a
commit 011eb42aa3

View file

@ -217,6 +217,74 @@ def train_shared_vllm(config: TrainingConfig):
else:
print(f" ✗ Modification didn't stick - tensor may be a copy!")
# === CRITICAL TEST: Does vLLM SEE weight modifications? ===
# Diagnostic only: temporarily corrupt a slice of the embedding matrix from the
# trainer side, ask the vLLM server to generate, and compare the output with a
# baseline taken before the corruption. A changed output means the vLLM server
# is reading the same weights this process writes to.
print(f"\n [CRITICAL] Testing if vLLM sees weight modifications...")
try:
    import requests

    test_prompt = "2+2="
    vllm_url = f"http://localhost:{config.vllm_port}"

    def _query_vllm():
        """Return the first generated text from vLLM, or "ERROR" on a non-200 reply."""
        response = requests.post(
            f"{vllm_url}/generate",
            json={"prompt": test_prompt, "max_tokens": 3, "temperature": 0.0},
            timeout=30,
        )
        if response.status_code != 200:
            return "ERROR"
        # Guard against a missing or empty "text" list so indexing cannot fail.
        texts = response.json().get("text") or [""]
        return texts[0]

    # Get baseline output from vLLM
    baseline_output = _query_vllm()

    # Pick the first parameter whose name mentions the token embedding.
    embed_param = None
    for name, param in model.named_parameters():
        if "embed_tokens" in name:
            embed_param = param
            break

    if baseline_output == "ERROR":
        # Without a valid baseline the comparison below would always report
        # "sharing is broken", so do not touch the weights at all.
        print(" Baseline vLLM request failed - weight sharing test is inconclusive.")
    elif embed_param is not None:
        original_embed = embed_param.data[0, :10].clone()
        try:
            # CORRUPT a weight dramatically (this should break the model)
            # NOTE(review): if the parameter lives on a CUDA device the write is
            # asynchronous; presumably the HTTP round trip is long enough for it
            # to land, but confirm whether an explicit synchronize is needed.
            embed_param.data[0, :10] = 1000.0
            # Query vLLM again - if sharing works, output should be GARBAGE
            corrupted_output = _query_vllm()
        finally:
            # Always restore the embedding, even if the request above raised;
            # otherwise training would continue with corrupted weights.
            embed_param.data[0, :10] = original_embed

        # Query vLLM again - should be back to normal
        restored_output = _query_vllm()

        print(f" Baseline vLLM output: '{baseline_output}'")
        print(f" Corrupted vLLM output: '{corrupted_output}'")
        print(f" Restored vLLM output: '{restored_output}'")

        # Check if vLLM saw the corruption
        if corrupted_output != baseline_output:
            print(" ✓✓✓ vLLM SEES WEIGHT UPDATES! Output changed when weights corrupted.")
            if restored_output == baseline_output:
                print(" ✓✓✓ Output restored after weight restoration. SHARING IS WORKING!")
            else:
                print(" ⚠ Output didn't fully restore - may need vLLM cache clear")
        else:
            print(" ✗✗✗ vLLM DID NOT SEE CORRUPTION - SHARING IS BROKEN!")
            print(" vLLM may have internal weight copies/cache.")
            print(" The IPC attachment gives write access but vLLM doesn't read from it.")
except Exception as e:
    # Best-effort diagnostic: report the failure and let training continue.
    import traceback

    print(f" Critical test failed: {e}")
    traceback.print_exc()
# Now test vLLM logprobs vs trainer logprobs
print(f"\n Testing logprob alignment with vLLM...")
try: