diff --git a/example_trainer/trainers.py b/example_trainer/trainers.py
index e4b297fb..3e30657f 100644
--- a/example_trainer/trainers.py
+++ b/example_trainer/trainers.py
@@ -217,6 +217,74 @@ def train_shared_vllm(config: TrainingConfig):
     else:
         print(f" ✗ Modification didn't stick - tensor may be a copy!")
 
+    # === CRITICAL TEST: Does vLLM SEE weight modifications? ===
+    print(f"\n [CRITICAL] Testing if vLLM sees weight modifications...")
+    try:
+        import requests
+
+        test_prompt = "2+2="
+        vllm_url = f"http://localhost:{config.vllm_port}"
+
+        # Get baseline output from vLLM
+        response1 = requests.post(
+            f"{vllm_url}/generate",
+            json={"prompt": test_prompt, "max_tokens": 3, "temperature": 0.0},
+            timeout=30,
+        )
+        baseline_output = response1.json().get("text", [""])[0] if response1.status_code == 200 else "ERROR"
+
+        # CORRUPT a weight dramatically (this should break the model)
+        embed_param = None
+        for name, param in model.named_parameters():
+            if "embed_tokens" in name:
+                embed_param = param
+                break
+
+        if embed_param is not None:
+            original_embed = embed_param.data[0, :10].clone()
+
+            # Corrupt the embedding with extreme values
+            embed_param.data[0, :10] = 1000.0
+
+            # Query vLLM again - if sharing works, output should be GARBAGE
+            response2 = requests.post(
+                f"{vllm_url}/generate",
+                json={"prompt": test_prompt, "max_tokens": 3, "temperature": 0.0},
+                timeout=30,
+            )
+            corrupted_output = response2.json().get("text", [""])[0] if response2.status_code == 200 else "ERROR"
+
+            # Restore the embedding
+            embed_param.data[0, :10] = original_embed
+
+            # Query vLLM again - should be back to normal
+            response3 = requests.post(
+                f"{vllm_url}/generate",
+                json={"prompt": test_prompt, "max_tokens": 3, "temperature": 0.0},
+                timeout=30,
+            )
+            restored_output = response3.json().get("text", [""])[0] if response3.status_code == 200 else "ERROR"
+
+            print(f" Baseline vLLM output: '{baseline_output}'")
+            print(f" Corrupted vLLM output: '{corrupted_output}'")
+            print(f" Restored vLLM output: '{restored_output}'")
+
+            # Check if vLLM saw the corruption
+            if corrupted_output != baseline_output:
+                print(f" ✓✓✓ vLLM SEES WEIGHT UPDATES! Output changed when weights corrupted.")
+                if restored_output == baseline_output:
+                    print(f" ✓✓✓ Output restored after weight restoration. SHARING IS WORKING!")
+                else:
+                    print(f" ⚠ Output didn't fully restore - may need vLLM cache clear")
+            else:
+                print(f" ✗✗✗ vLLM DID NOT SEE CORRUPTION - SHARING IS BROKEN!")
+                print(f" vLLM may have internal weight copies/cache.")
+                print(f" The IPC attachment gives write access but vLLM doesn't read from it.")
+    except Exception as e:
+        import traceback
+        print(f" Critical test failed: {e}")
+        traceback.print_exc()
+
     # Now test vLLM logprobs vs trainer logprobs
     print(f"\n Testing logprob alignment with vLLM...")
     try:
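
Note: below is a minimal standalone sketch of the same corruption round-trip probe, for running outside the trainer. It assumes a vLLM server exposing the simple /generate endpoint and response shape used in the diff above, and a `model` whose parameters are IPC-shared with that server; the helper names (`query`, `probe_vllm_sharing`) are hypothetical, not part of this PR.

import requests
import torch

def query(url: str, prompt: str) -> str:
    # Greedy decoding (temperature=0.0) so outputs are deterministic and
    # comparable across the three probes. Endpoint and payload mirror the
    # diff above; adjust if your server uses a different API.
    resp = requests.post(
        f"{url}/generate",
        json={"prompt": prompt, "max_tokens": 3, "temperature": 0.0},
        timeout=30,
    )
    resp.raise_for_status()
    return resp.json().get("text", [""])[0]

def probe_vllm_sharing(model: torch.nn.Module, url: str, prompt: str = "2+2=") -> bool:
    # Find an embedding weight to corrupt (same heuristic as the diff).
    embed = next((p for n, p in model.named_parameters() if "embed_tokens" in n), None)
    if embed is None:
        raise RuntimeError("no embed_tokens parameter found")

    baseline = query(url, prompt)
    saved = embed.data[0, :10].clone()
    try:
        embed.data[0, :10] = 1000.0      # corrupt in place
        corrupted = query(url, prompt)   # should change if memory is truly shared
    finally:
        embed.data[0, :10] = saved       # always restore, even if a query fails
    restored = query(url, prompt)

    # Sharing works iff the corruption was visible AND the restore took effect.
    return corrupted != baseline and restored == baseline

The try/finally guarantees the weight is restored even if the corrupted-state query raises, so a failed probe cannot leave the shared model broken.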