Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-30 17:40:36 +00:00
readme updates
This commit is contained in:
parent 41cab9d52a
commit 011eb42aa3
1 changed file with 68 additions and 0 deletions
@@ -217,6 +217,74 @@ def train_shared_vllm(config: TrainingConfig):
    else:
        print(f" ✗ Modification didn't stick - tensor may be a copy!")

    # === CRITICAL TEST: Does vLLM SEE weight modifications? ===
    print(f"\n [CRITICAL] Testing if vLLM sees weight modifications...")
    try:
        import requests

        test_prompt = "2+2="
        vllm_url = f"http://localhost:{config.vllm_port}"

        # Get baseline output from vLLM
        response1 = requests.post(
            f"{vllm_url}/generate",
            json={"prompt": test_prompt, "max_tokens": 3, "temperature": 0.0},
            timeout=30,
        )
        baseline_output = response1.json().get("text", [""])[0] if response1.status_code == 200 else "ERROR"
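        # NOTE (assumption): this targets vLLM's plain demo API server, which
        # serves POST /generate and replies with {"text": [...]}; an OpenAI-
        # compatible vLLM server exposes different routes (e.g. /v1/completions).
        # Greedy decoding (temperature=0.0) keeps the three probes comparable:
        # any output change must come from the weights, not sampling noise.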

        # CORRUPT a weight dramatically (this should break the model)
        embed_param = None
        for name, param in model.named_parameters():
            if "embed_tokens" in name:
                embed_param = param
                break

        if embed_param is not None:
            original_embed = embed_param.data[0, :10].clone()

            # Corrupt the embedding with extreme values
            embed_param.data[0, :10] = 1000.0
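            # (Assumption: with tied input/output embeddings, corrupting row 0
            # also skews token id 0's logit at every decode step, so the probe
            # can react even if token 0 never appears in the prompt itself.)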

            # Query vLLM again - if sharing works, output should be GARBAGE
            response2 = requests.post(
                f"{vllm_url}/generate",
                json={"prompt": test_prompt, "max_tokens": 3, "temperature": 0.0},
                timeout=30,
            )
            corrupted_output = response2.json().get("text", [""])[0] if response2.status_code == 200 else "ERROR"

            # Restore the embedding
            embed_param.data[0, :10] = original_embed

            # Query vLLM again - should be back to normal
            response3 = requests.post(
                f"{vllm_url}/generate",
                json={"prompt": test_prompt, "max_tokens": 3, "temperature": 0.0},
                timeout=30,
            )
            restored_output = response3.json().get("text", [""])[0] if response3.status_code == 200 else "ERROR"

            print(f" Baseline vLLM output: '{baseline_output}'")
            print(f" Corrupted vLLM output: '{corrupted_output}'")
            print(f" Restored vLLM output: '{restored_output}'")

            # Check if vLLM saw the corruption
            if corrupted_output != baseline_output:
                print(f" ✓✓✓ vLLM SEES WEIGHT UPDATES! Output changed when weights corrupted.")
                if restored_output == baseline_output:
                    print(f" ✓✓✓ Output restored after weight restoration. SHARING IS WORKING!")
                else:
                    print(f" ⚠ Output didn't fully restore - may need vLLM cache clear")
            else:
                print(f" ✗✗✗ vLLM DID NOT SEE CORRUPTION - SHARING IS BROKEN!")
                print(f" vLLM may have internal weight copies/cache.")
                print(f" The IPC attachment gives write access but vLLM doesn't read from it.")
    except Exception as e:
        import traceback
        print(f" Critical test failed: {e}")
        traceback.print_exc()

    # Now test vLLM logprobs vs trainer logprobs
    print(f"\n Testing logprob alignment with vLLM...")
    try:
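
The hunk ends just as the logprob-alignment check begins, so the body of that test is not part of this diff. For orientation only, here is a minimal sketch of what the trainer-side half of such a comparison could look like, assuming a Hugging Face causal LM `model` and a matching `tokenizer` are in scope (the helper below is hypothetical, not code from this commit):

# Hypothetical sketch -- NOT part of this commit.
# Assumes `model` is a Hugging Face causal LM and `tokenizer` matches it.
import torch


def trainer_completion_logprob(model, tokenizer, prompt: str, completion: str) -> float:
    """Sum of logprobs the trainer model assigns to `completion` given `prompt`."""
    prompt_len = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]
    # Caveat: joint tokenization can merge tokens across the prompt/completion boundary.
    full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids.to(model.device)
    with torch.no_grad():
        logits = model(full_ids).logits  # [1, seq_len, vocab_size]
    logprobs = torch.log_softmax(logits[0, :-1], dim=-1)  # position i predicts token i + 1
    targets = full_ids[0, 1:]
    # Score only the completion tokens, not the prompt tokens.
    completion_logprobs = logprobs[prompt_len - 1 :].gather(1, targets[prompt_len - 1 :].unsqueeze(1))
    return completion_logprobs.sum().item()

Comparing these sums against per-token logprobs reported by vLLM for the same completion is a direct consistency check on the shared weights. Note that the bare /generate demo endpoint returns only text, so the vLLM side of the comparison would typically need an endpoint that can report logprobs (e.g. the OpenAI-compatible completions API).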