basic changes

Jai Suphavadeeprasit 2025-12-30 10:01:27 -05:00
parent 14ebf7a492
commit 80d2608c4e
2 changed files with 105 additions and 36 deletions


@@ -195,8 +195,9 @@ def register_trainer(config: TrainingConfig):
     response = requests.post(
         "http://localhost:8000/register",
         json={
-            "wandb_group": config.wandb_group,
-            "wandb_project": config.wandb_project,
+            # wandb fields are required strings - use empty string if None
+            "wandb_group": config.wandb_group or "",
+            "wandb_project": config.wandb_project or "",
             "batch_size": config.batch_size * config.gradient_accumulation_steps,
             "max_token_len": config.seq_len,
             "starting_step": 0,
@@ -1103,9 +1104,9 @@ def train_shared_vllm(config: TrainingConfig):
     model, tokenizer = load_model_and_tokenizer(config, bridge=bridge)
     optimizer = AdamW(model.parameters(), lr=config.lr)

-    # For NCCL mode, set param list from trainer's model
+    # For NCCL mode, build mapping between trainer's and vLLM's param names
     if config.use_shared_memory:
-        bridge.set_param_list_from_model(model)
+        bridge.build_param_mapping(model)

     print(f"[3/3] Starting training for {config.training_steps} steps")
     print("NOTE: vLLM sees weight updates immediately after each step!")


@@ -366,36 +366,42 @@ class VLLMWeightBridge:
         log_dir = self.config.log_dir or os.environ.get("LOGDIR", ".")
         json_path = Path(log_dir) / "vllm_bridge_config.json"

-        # Wait for file (vLLM needs time to load model and export params)
+        # Wait for file WITH param_names populated (not just file existence)
+        # vllm_api_server creates empty file first, patched_gpu_runner fills it later
         wait_time = 0
         max_wait = min(self.config.timeout_seconds, 120)  # Max 2 minutes
-        while not json_path.exists() and wait_time < max_wait:
-            if wait_time % 10 == 0:
-                print(f"[Bridge] Waiting for {json_path}... ({wait_time}s)")
+        while wait_time < max_wait:
+            if json_path.exists():
+                try:
+                    with open(json_path, "r") as f:
+                        data = json.load(f)
+                    # Check if param_names is populated (not just file exists)
+                    param_names = data.get("param_names", [])
+                    if len(param_names) > 0:
+                        self.param_mappings = data.get("param_mappings", {})
+                        self.param_name_list = param_names
+                        print(f"[Bridge] Loaded {len(self.param_name_list)} vLLM parameter names")
+                        return
+                    else:
+                        # File exists but param_names not yet populated
+                        if wait_time % 10 == 0:
+                            print(f"[Bridge] Waiting for vLLM to export params... ({wait_time}s)")
+                except (json.JSONDecodeError, IOError):
+                    # File being written, wait and retry
+                    pass
+            else:
+                if wait_time % 10 == 0:
+                    print(f"[Bridge] Waiting for {json_path}... ({wait_time}s)")
             time.sleep(1)
             wait_time += 1

-        if not json_path.exists():
-            print(f"[Bridge] Warning: Config file not found after {wait_time}s")
-            print("[Bridge] Will use trainer's model params directly")
-            self.param_mappings = {}
-            self.param_name_list = []
-            return
-        time.sleep(1.0)  # Wait for file to finish writing
-        try:
-            with open(json_path, "r") as f:
-                data = json.load(f)
-            self.param_mappings = data.get("param_mappings", {})
-            self.param_name_list = data.get("param_names", sorted(self.param_mappings.keys()))
-            print(f"[Bridge] Loaded {len(self.param_name_list)} vLLM parameter names")
-        except Exception as e:
-            print(f"[Bridge] Warning: Failed to load config: {e}")
-            self.param_mappings = {}
-            self.param_name_list = []
+        print(f"[Bridge] Warning: Config file not populated after {wait_time}s")
+        print("[Bridge] Will use trainer's model params directly")
+        self.param_mappings = {}
+        self.param_name_list = []

     def set_param_list_from_model(self, model: nn.Module) -> None:
         """
@@ -404,8 +410,59 @@ class VLLMWeightBridge:
         Call this if vLLM's param names don't match the trainer's.
         """
         self.param_name_list = sorted(name for name, _ in model.named_parameters())
+        self._trainer_to_vllm_map = {}  # 1:1 mapping
         print(f"[Bridge] Using trainer's {len(self.param_name_list)} parameter names")

+    def build_param_mapping(self, model: nn.Module) -> None:
+        """
+        Build mapping between trainer's HuggingFace params and vLLM's params.
+
+        HuggingFace models often have a "model." prefix that vLLM strips.
+        This builds a mapping to translate between the two naming conventions.
+        """
+        trainer_params = dict(model.named_parameters())
+        trainer_names = set(trainer_params.keys())
+
+        # Build mapping: vLLM name -> trainer name
+        self._vllm_to_trainer_map: Dict[str, str] = {}
+        for vllm_name in self.param_name_list:
+            # Try exact match first
+            if vllm_name in trainer_names:
+                self._vllm_to_trainer_map[vllm_name] = vllm_name
+                continue
+            # Try adding "model." prefix (common for HuggingFace models)
+            hf_name = f"model.{vllm_name}"
+            if hf_name in trainer_names:
+                self._vllm_to_trainer_map[vllm_name] = hf_name
+                continue
+            # Try other common prefixes
+            for prefix in ["transformer.", "gpt.", "bert.", "encoder.", "decoder."]:
+                prefixed = f"{prefix}{vllm_name}"
+                if prefixed in trainer_names:
+                    self._vllm_to_trainer_map[vllm_name] = prefixed
+                    break
+
+        mapped = len(self._vllm_to_trainer_map)
+        total = len(self.param_name_list)
+        if mapped == 0:
+            print(f"[Bridge] ⚠ Warning: No params matched between trainer and vLLM!")
+            print(f"[Bridge] Trainer params (sample): {list(trainer_names)[:3]}")
+            print(f"[Bridge] vLLM params (sample): {self.param_name_list[:3]}")
+            # Fall back to trainer's param list
+            self.param_name_list = sorted(trainer_names)
+            self._vllm_to_trainer_map = {n: n for n in self.param_name_list}
+            print(f"[Bridge] Falling back to trainer's {len(self.param_name_list)} params")
+        elif mapped < total:
+            print(f"[Bridge] Mapped {mapped}/{total} params from vLLM to trainer")
+            # Only keep mapped params
+            self.param_name_list = sorted(self._vllm_to_trainer_map.keys())
+        else:
+            print(f"[Bridge] ✓ All {mapped} vLLM params mapped to trainer")
+
     def broadcast_weights(self, model: nn.Module) -> None:
         """
         Broadcast all model weights to vLLM inference workers.
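
To see what `build_param_mapping` computes, here is its prefix-matching logic condensed into a standalone function and run on made-up names (illustration only; the real method walks `model.named_parameters()` and also applies the zero-match fallback shown above):

```python
# Condensed illustration of the vLLM-name -> trainer-name matching; names are made up.
from typing import Dict, List, Set

def map_names(vllm_names: List[str], trainer_names: Set[str]) -> Dict[str, str]:
    mapping: Dict[str, str] = {}
    for vllm_name in vllm_names:
        if vllm_name in trainer_names:          # exact match first
            mapping[vllm_name] = vllm_name
            continue
        for prefix in ["model.", "transformer.", "gpt.", "bert.", "encoder.", "decoder."]:
            candidate = f"{prefix}{vllm_name}"  # try common HF prefixes
            if candidate in trainer_names:
                mapping[vllm_name] = candidate
                break
    return mapping

vllm = ["layers.0.mlp.up_proj.weight", "lm_head.weight"]
trainer = {"model.layers.0.mlp.up_proj.weight", "lm_head.weight"}
print(map_names(vllm, trainer))
# {'layers.0.mlp.up_proj.weight': 'model.layers.0.mlp.up_proj.weight',
#  'lm_head.weight': 'lm_head.weight'}
```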
@@ -426,18 +483,26 @@
         self._update_count += 1
         start_time = time.time()

-        state_dict = dict(model.named_parameters())
+        trainer_state_dict = dict(model.named_parameters())
+        # Get mapping (vLLM name -> trainer name)
+        vllm_to_trainer = getattr(self, '_vllm_to_trainer_map', {})

         num_params = 0
+        skipped = 0
         with torch.no_grad():
-            for idx, param_name in enumerate(self.param_name_list):
-                # Get tensor for this parameter
-                if param_name not in state_dict:
+            for idx, vllm_name in enumerate(self.param_name_list):
+                # Get trainer's parameter name for this vLLM param
+                trainer_name = vllm_to_trainer.get(vllm_name, vllm_name)
+                if trainer_name not in trainer_state_dict:
+                    skipped += 1
                     continue
-                tensor = state_dict[param_name].data
+                tensor = trainer_state_dict[trainer_name].data

-                # Step 1: Broadcast parameter index
+                # Step 1: Broadcast parameter index (vLLM's index)
                 idx_tensor = torch.tensor([idx], dtype=torch.long, device=self.device)
                 dist.broadcast(idx_tensor, src=0, group=self.nccl_group)
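
The hunk above shows only the sending side, and only step 1 of it (broadcasting the parameter index); the tensor broadcast itself falls in lines this diff does not display. The sketch below is a guess at what the receiving worker (`patched_gpu_runner`) has to do to stay in lockstep; none of it is taken from the actual repo, and the function, its argument names, and the assumed step 2 are illustrative only.

```python
# Hypothetical receiver-side counterpart; NOT actual patched_gpu_runner code.
# Assumes a torch.distributed process group shared with the trainer (rank 0).
import torch
import torch.distributed as dist

def receive_one_param(param_name_list, vllm_named_params, device, group):
    # Step 1: receive the parameter index broadcast by the trainer
    idx_tensor = torch.zeros(1, dtype=torch.long, device=device)
    dist.broadcast(idx_tensor, src=0, group=group)
    name = param_name_list[int(idx_tensor.item())]

    # Step 2 (assumed): receive the weight data into a matching buffer and copy it in
    target = vllm_named_params[name]
    buf = torch.empty_like(target)
    dist.broadcast(buf, src=0, group=group)
    target.data.copy_(buf)
```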
@@ -447,7 +512,10 @@
                 num_params += 1

         elapsed = time.time() - start_time
-        print(f"[Bridge] Broadcast {num_params} params, update #{self._update_count} ({elapsed:.2f}s)")
+        if skipped > 0:
+            print(f"[Bridge] Broadcast {num_params} params (skipped {skipped}), update #{self._update_count} ({elapsed:.2f}s)")
+        else:
+            print(f"[Bridge] Broadcast {num_params} params, update #{self._update_count} ({elapsed:.2f}s)")

     def broadcast_single_param(
         self,