basic changes

This commit is contained in:
Jai Suphavadeeprasit 2025-12-30 10:01:27 -05:00
parent 14ebf7a492
commit 80d2608c4e
2 changed files with 105 additions and 36 deletions

View file

@ -195,8 +195,9 @@ def register_trainer(config: TrainingConfig):
response = requests.post( response = requests.post(
"http://localhost:8000/register", "http://localhost:8000/register",
json={ json={
"wandb_group": config.wandb_group, # wandb fields are required strings - use empty string if None
"wandb_project": config.wandb_project, "wandb_group": config.wandb_group or "",
"wandb_project": config.wandb_project or "",
"batch_size": config.batch_size * config.gradient_accumulation_steps, "batch_size": config.batch_size * config.gradient_accumulation_steps,
"max_token_len": config.seq_len, "max_token_len": config.seq_len,
"starting_step": 0, "starting_step": 0,
@ -1103,9 +1104,9 @@ def train_shared_vllm(config: TrainingConfig):
model, tokenizer = load_model_and_tokenizer(config, bridge=bridge) model, tokenizer = load_model_and_tokenizer(config, bridge=bridge)
optimizer = AdamW(model.parameters(), lr=config.lr) optimizer = AdamW(model.parameters(), lr=config.lr)
# For NCCL mode, set param list from trainer's model # For NCCL mode, build mapping between trainer's and vLLM's param names
if config.use_shared_memory: if config.use_shared_memory:
bridge.set_param_list_from_model(model) bridge.build_param_mapping(model)
print(f"[3/3] Starting training for {config.training_steps} steps") print(f"[3/3] Starting training for {config.training_steps} steps")
print("NOTE: vLLM sees weight updates immediately after each step!") print("NOTE: vLLM sees weight updates immediately after each step!")

View file

@ -366,36 +366,42 @@ class VLLMWeightBridge:
log_dir = self.config.log_dir or os.environ.get("LOGDIR", ".") log_dir = self.config.log_dir or os.environ.get("LOGDIR", ".")
json_path = Path(log_dir) / "vllm_bridge_config.json" json_path = Path(log_dir) / "vllm_bridge_config.json"
# Wait for file (vLLM needs time to load model and export params) # Wait for file WITH param_names populated (not just file existence)
# vllm_api_server creates empty file first, patched_gpu_runner fills it later
wait_time = 0 wait_time = 0
max_wait = min(self.config.timeout_seconds, 120) # Max 2 minutes max_wait = min(self.config.timeout_seconds, 120) # Max 2 minutes
while not json_path.exists() and wait_time < max_wait:
if wait_time % 10 == 0: while wait_time < max_wait:
print(f"[Bridge] Waiting for {json_path}... ({wait_time}s)") if json_path.exists():
try:
with open(json_path, "r") as f:
data = json.load(f)
# Check if param_names is populated (not just file exists)
param_names = data.get("param_names", [])
if len(param_names) > 0:
self.param_mappings = data.get("param_mappings", {})
self.param_name_list = param_names
print(f"[Bridge] Loaded {len(self.param_name_list)} vLLM parameter names")
return
else:
# File exists but param_names not yet populated
if wait_time % 10 == 0:
print(f"[Bridge] Waiting for vLLM to export params... ({wait_time}s)")
except (json.JSONDecodeError, IOError):
# File being written, wait and retry
pass
else:
if wait_time % 10 == 0:
print(f"[Bridge] Waiting for {json_path}... ({wait_time}s)")
time.sleep(1) time.sleep(1)
wait_time += 1 wait_time += 1
if not json_path.exists(): print(f"[Bridge] Warning: Config file not populated after {wait_time}s")
print(f"[Bridge] Warning: Config file not found after {wait_time}s") print("[Bridge] Will use trainer's model params directly")
print("[Bridge] Will use trainer's model params directly") self.param_mappings = {}
self.param_mappings = {} self.param_name_list = []
self.param_name_list = []
return
time.sleep(1.0) # Wait for file to finish writing
try:
with open(json_path, "r") as f:
data = json.load(f)
self.param_mappings = data.get("param_mappings", {})
self.param_name_list = data.get("param_names", sorted(self.param_mappings.keys()))
print(f"[Bridge] Loaded {len(self.param_name_list)} vLLM parameter names")
except Exception as e:
print(f"[Bridge] Warning: Failed to load config: {e}")
self.param_mappings = {}
self.param_name_list = []
def set_param_list_from_model(self, model: nn.Module) -> None: def set_param_list_from_model(self, model: nn.Module) -> None:
""" """
@ -404,8 +410,59 @@ class VLLMWeightBridge:
Call this if vLLM's param names don't match the trainer's. Call this if vLLM's param names don't match the trainer's.
""" """
self.param_name_list = sorted(name for name, _ in model.named_parameters()) self.param_name_list = sorted(name for name, _ in model.named_parameters())
self._trainer_to_vllm_map = {} # 1:1 mapping
print(f"[Bridge] Using trainer's {len(self.param_name_list)} parameter names") print(f"[Bridge] Using trainer's {len(self.param_name_list)} parameter names")
def build_param_mapping(self, model: nn.Module) -> None:
"""
Build mapping between trainer's HuggingFace params and vLLM's params.
HuggingFace models often have a "model." prefix that vLLM strips.
This builds a mapping to translate between the two naming conventions.
"""
trainer_params = dict(model.named_parameters())
trainer_names = set(trainer_params.keys())
# Build mapping: vLLM name -> trainer name
self._vllm_to_trainer_map: Dict[str, str] = {}
for vllm_name in self.param_name_list:
# Try exact match first
if vllm_name in trainer_names:
self._vllm_to_trainer_map[vllm_name] = vllm_name
continue
# Try adding "model." prefix (common for HuggingFace models)
hf_name = f"model.{vllm_name}"
if hf_name in trainer_names:
self._vllm_to_trainer_map[vllm_name] = hf_name
continue
# Try other common prefixes
for prefix in ["transformer.", "gpt.", "bert.", "encoder.", "decoder."]:
prefixed = f"{prefix}{vllm_name}"
if prefixed in trainer_names:
self._vllm_to_trainer_map[vllm_name] = prefixed
break
mapped = len(self._vllm_to_trainer_map)
total = len(self.param_name_list)
if mapped == 0:
print(f"[Bridge] ⚠ Warning: No params matched between trainer and vLLM!")
print(f"[Bridge] Trainer params (sample): {list(trainer_names)[:3]}")
print(f"[Bridge] vLLM params (sample): {self.param_name_list[:3]}")
# Fall back to trainer's param list
self.param_name_list = sorted(trainer_names)
self._vllm_to_trainer_map = {n: n for n in self.param_name_list}
print(f"[Bridge] Falling back to trainer's {len(self.param_name_list)} params")
elif mapped < total:
print(f"[Bridge] Mapped {mapped}/{total} params from vLLM to trainer")
# Only keep mapped params
self.param_name_list = sorted(self._vllm_to_trainer_map.keys())
else:
print(f"[Bridge] ✓ All {mapped} vLLM params mapped to trainer")
def broadcast_weights(self, model: nn.Module) -> None: def broadcast_weights(self, model: nn.Module) -> None:
""" """
Broadcast all model weights to vLLM inference workers. Broadcast all model weights to vLLM inference workers.
@ -426,18 +483,26 @@ class VLLMWeightBridge:
self._update_count += 1 self._update_count += 1
start_time = time.time() start_time = time.time()
state_dict = dict(model.named_parameters()) trainer_state_dict = dict(model.named_parameters())
# Get mapping (vLLM name -> trainer name)
vllm_to_trainer = getattr(self, '_vllm_to_trainer_map', {})
num_params = 0 num_params = 0
skipped = 0
with torch.no_grad(): with torch.no_grad():
for idx, param_name in enumerate(self.param_name_list): for idx, vllm_name in enumerate(self.param_name_list):
# Get tensor for this parameter # Get trainer's parameter name for this vLLM param
if param_name not in state_dict: trainer_name = vllm_to_trainer.get(vllm_name, vllm_name)
if trainer_name not in trainer_state_dict:
skipped += 1
continue continue
tensor = state_dict[param_name].data tensor = trainer_state_dict[trainer_name].data
# Step 1: Broadcast parameter index # Step 1: Broadcast parameter index (vLLM's index)
idx_tensor = torch.tensor([idx], dtype=torch.long, device=self.device) idx_tensor = torch.tensor([idx], dtype=torch.long, device=self.device)
dist.broadcast(idx_tensor, src=0, group=self.nccl_group) dist.broadcast(idx_tensor, src=0, group=self.nccl_group)
@ -447,7 +512,10 @@ class VLLMWeightBridge:
num_params += 1 num_params += 1
elapsed = time.time() - start_time elapsed = time.time() - start_time
print(f"[Bridge] Broadcast {num_params} params, update #{self._update_count} ({elapsed:.2f}s)") if skipped > 0:
print(f"[Bridge] Broadcast {num_params} params (skipped {skipped}), update #{self._update_count} ({elapsed:.2f}s)")
else:
print(f"[Bridge] Broadcast {num_params} params, update #{self._update_count} ({elapsed:.2f}s)")
def broadcast_single_param( def broadcast_single_param(
self, self,