From 80d2608c4eb03fbbe530ecc864406f9d8a85e791 Mon Sep 17 00:00:00 2001
From: Jai Suphavadeeprasit
Date: Tue, 30 Dec 2025 10:01:27 -0500
Subject: [PATCH] Map trainer param names to vLLM names; handle unset wandb
 fields

---
 example_trainer/grpo.py               |   9 +-
 example_trainer/vllm_weight_bridge.py | 132 +++++++++++++++++++-------
 2 files changed, 105 insertions(+), 36 deletions(-)

diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py
index 7f83fc90..3c7942bb 100644
--- a/example_trainer/grpo.py
+++ b/example_trainer/grpo.py
@@ -195,8 +195,9 @@ def register_trainer(config: TrainingConfig):
     response = requests.post(
         "http://localhost:8000/register",
         json={
-            "wandb_group": config.wandb_group,
-            "wandb_project": config.wandb_project,
+            # wandb fields are required strings - use empty string if None
+            "wandb_group": config.wandb_group or "",
+            "wandb_project": config.wandb_project or "",
             "batch_size": config.batch_size * config.gradient_accumulation_steps,
             "max_token_len": config.seq_len,
             "starting_step": 0,
@@ -1103,9 +1104,9 @@ def train_shared_vllm(config: TrainingConfig):
     model, tokenizer = load_model_and_tokenizer(config, bridge=bridge)
     optimizer = AdamW(model.parameters(), lr=config.lr)
 
-    # For NCCL mode, set param list from trainer's model
+    # For NCCL mode, build mapping between trainer's and vLLM's param names
     if config.use_shared_memory:
-        bridge.set_param_list_from_model(model)
+        bridge.build_param_mapping(model)
 
     print(f"[3/3] Starting training for {config.training_steps} steps")
     print("NOTE: vLLM sees weight updates immediately after each step!")
diff --git a/example_trainer/vllm_weight_bridge.py b/example_trainer/vllm_weight_bridge.py
index 5307bd54..b24ef9b2 100644
--- a/example_trainer/vllm_weight_bridge.py
+++ b/example_trainer/vllm_weight_bridge.py
@@ -366,36 +366,42 @@ class VLLMWeightBridge:
         log_dir = self.config.log_dir or os.environ.get("LOGDIR", ".")
         json_path = Path(log_dir) / "vllm_bridge_config.json"
 
-        # Wait for file (vLLM needs time to load model and export params)
+        # Wait for the file WITH param_names populated, not just file existence:
+        # vllm_api_server creates an empty file first; patched_gpu_runner fills it later
         wait_time = 0
         max_wait = min(self.config.timeout_seconds, 120)  # Max 2 minutes
-        while not json_path.exists() and wait_time < max_wait:
-            if wait_time % 10 == 0:
-                print(f"[Bridge] Waiting for {json_path}... ({wait_time}s)")
+
+        while wait_time < max_wait:
+            if json_path.exists():
+                try:
+                    with open(json_path, "r") as f:
+                        data = json.load(f)
+
+                    # Check whether param_names is populated (not just that the file exists)
+                    param_names = data.get("param_names", [])
+                    if len(param_names) > 0:
+                        self.param_mappings = data.get("param_mappings", {})
+                        self.param_name_list = param_names
+                        print(f"[Bridge] Loaded {len(self.param_name_list)} vLLM parameter names")
+                        return
+                    else:
+                        # File exists but param_names is not yet populated
+                        if wait_time % 10 == 0:
+                            print(f"[Bridge] Waiting for vLLM to export params... ({wait_time}s)")
+                except (json.JSONDecodeError, IOError):
+                    # File is still being written; wait and retry
+                    pass
+            else:
+                if wait_time % 10 == 0:
+                    print(f"[Bridge] Waiting for {json_path}... ({wait_time}s)")
+
             time.sleep(1)
             wait_time += 1
 
-        if not json_path.exists():
-            print(f"[Bridge] Warning: Config file not found after {wait_time}s")
-            print("[Bridge] Will use trainer's model params directly")
-            self.param_mappings = {}
-            self.param_name_list = []
-            return
-
-        time.sleep(1.0)  # Wait for file to finish writing
-
-        try:
-            with open(json_path, "r") as f:
-                data = json.load(f)
-
-            self.param_mappings = data.get("param_mappings", {})
-            self.param_name_list = data.get("param_names", sorted(self.param_mappings.keys()))
-
-            print(f"[Bridge] Loaded {len(self.param_name_list)} vLLM parameter names")
-        except Exception as e:
-            print(f"[Bridge] Warning: Failed to load config: {e}")
-            self.param_mappings = {}
-            self.param_name_list = []
+        print(f"[Bridge] Warning: Config file not populated after {wait_time}s")
+        print("[Bridge] Will use trainer's model params directly")
+        self.param_mappings = {}
+        self.param_name_list = []
 
     def set_param_list_from_model(self, model: nn.Module) -> None:
         """
@@ -404,8 +410,59 @@ class VLLMWeightBridge:
         Call this if vLLM's param names don't match the trainer's.
         """
         self.param_name_list = sorted(name for name, _ in model.named_parameters())
+        self._trainer_to_vllm_map = {}  # 1:1 mapping
         print(f"[Bridge] Using trainer's {len(self.param_name_list)} parameter names")
 
+    def build_param_mapping(self, model: nn.Module) -> None:
+        """
+        Build a mapping between the trainer's HuggingFace params and vLLM's params.
+
+        HuggingFace models often have a "model." prefix that vLLM strips.
+        This builds a mapping to translate between the two naming conventions.
+        """
+        trainer_params = dict(model.named_parameters())
+        trainer_names = set(trainer_params.keys())
+
+        # Build mapping: vLLM name -> trainer name
+        self._vllm_to_trainer_map: Dict[str, str] = {}
+
+        for vllm_name in self.param_name_list:
+            # Try exact match first
+            if vllm_name in trainer_names:
+                self._vllm_to_trainer_map[vllm_name] = vllm_name
+                continue
+
+            # Try adding a "model." prefix (common for HuggingFace models)
+            hf_name = f"model.{vllm_name}"
+            if hf_name in trainer_names:
+                self._vllm_to_trainer_map[vllm_name] = hf_name
+                continue
+
+            # Try other common prefixes
+            for prefix in ["transformer.", "gpt.", "bert.", "encoder.", "decoder."]:
+                prefixed = f"{prefix}{vllm_name}"
+                if prefixed in trainer_names:
+                    self._vllm_to_trainer_map[vllm_name] = prefixed
+                    break
+
+        mapped = len(self._vllm_to_trainer_map)
+        total = len(self.param_name_list)
+
+        if mapped == 0:
+            print("[Bridge] ⚠ Warning: No params matched between trainer and vLLM!")
+            print(f"[Bridge] Trainer params (sample): {list(trainer_names)[:3]}")
+            print(f"[Bridge] vLLM params (sample): {self.param_name_list[:3]}")
+            # Fall back to the trainer's param list
+            self.param_name_list = sorted(trainer_names)
+            self._vllm_to_trainer_map = {n: n for n in self.param_name_list}
+            print(f"[Bridge] Falling back to trainer's {len(self.param_name_list)} params")
+        elif mapped < total:
+            print(f"[Bridge] Mapped {mapped}/{total} params from vLLM to trainer")
+            # Keep only the mapped params
+            self.param_name_list = sorted(self._vllm_to_trainer_map.keys())
+        else:
+            print(f"[Bridge] ✓ All {mapped} vLLM params mapped to trainer")
+
     def broadcast_weights(self, model: nn.Module) -> None:
         """
         Broadcast all model weights to vLLM inference workers.
@@ -426,18 +483,26 @@ class VLLMWeightBridge:
         self._update_count += 1
         start_time = time.time()
 
-        state_dict = dict(model.named_parameters())
+        trainer_state_dict = dict(model.named_parameters())
+
+        # Get the mapping (vLLM name -> trainer name)
+        vllm_to_trainer = getattr(self, '_vllm_to_trainer_map', {})
+
         num_params = 0
+        skipped = 0
 
         with torch.no_grad():
-            for idx, param_name in enumerate(self.param_name_list):
-                # Get tensor for this parameter
-                if param_name not in state_dict:
+            for idx, vllm_name in enumerate(self.param_name_list):
+                # Resolve the trainer's parameter name for this vLLM param
+                trainer_name = vllm_to_trainer.get(vllm_name, vllm_name)
+
+                if trainer_name not in trainer_state_dict:
+                    skipped += 1
                     continue
 
-                tensor = state_dict[param_name].data
+                tensor = trainer_state_dict[trainer_name].data
 
-                # Step 1: Broadcast parameter index
+                # Step 1: Broadcast parameter index (vLLM's index)
                 idx_tensor = torch.tensor([idx], dtype=torch.long, device=self.device)
                 dist.broadcast(idx_tensor, src=0, group=self.nccl_group)
 
@@ -447,7 +512,10 @@ class VLLMWeightBridge:
                 num_params += 1
 
         elapsed = time.time() - start_time
-        print(f"[Bridge] Broadcast {num_params} params, update #{self._update_count} ({elapsed:.2f}s)")
+        if skipped > 0:
+            print(f"[Bridge] Broadcast {num_params} params (skipped {skipped}), update #{self._update_count} ({elapsed:.2f}s)")
+        else:
+            print(f"[Bridge] Broadcast {num_params} params, update #{self._update_count} ({elapsed:.2f}s)")
 
     def broadcast_single_param(
         self,
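
Note: the new wait loop above is an instance of "poll a JSON file until a key
is populated, tolerating partial writes". A minimal standalone sketch of the
same pattern; wait_for_json_key is a hypothetical helper, not part of this
patch:

    import json
    import time
    from pathlib import Path

    def wait_for_json_key(path: Path, key: str, max_wait: int = 120):
        """Return the parsed JSON once `key` is non-empty, else None."""
        for waited in range(max_wait):
            try:
                data = json.loads(path.read_text())
                if data.get(key):  # populated, not merely present
                    return data
            except (FileNotFoundError, json.JSONDecodeError):
                pass  # file not created yet, or caught mid-write
            if waited % 10 == 0:
                print(f"Waiting for {key} in {path}... ({waited}s)")
            time.sleep(1)
        return None

The bridge's version is this loop with key="param_names", plus the fallback
to the trainer's own parameter names on timeout.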
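
Note: a minimal sketch of the prefix resolution that build_param_mapping
performs, assuming Llama-style parameter names; the names below are
illustrative, not taken from a real checkpoint:

    # vLLM reports names without the "model." prefix that HF checkpoints use.
    vllm_names = ["layers.0.self_attn.q_proj.weight", "embed_tokens.weight"]
    trainer_names = {
        "model.layers.0.self_attn.q_proj.weight",
        "model.embed_tokens.weight",
    }

    mapping = {}
    for vllm_name in vllm_names:
        # Same resolution order as the patch: exact match first, then prefixes.
        for candidate in (vllm_name, f"model.{vllm_name}"):
            if candidate in trainer_names:
                mapping[vllm_name] = candidate
                break

    assert mapping["embed_tokens.weight"] == "model.embed_tokens.weight"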
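
Note: broadcast_weights sends each parameter as an (index, tensor) pair, so
sender and receivers must agree on param_name_list order and tensor shapes;
the real receiver lives in the vLLM worker and is not part of this patch. A
toy two-rank demo of the same index-then-tensor protocol, using the gloo
backend so it runs on CPU:

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    PARAM_NAMES = ["w1", "w2"]  # both ranks must hold the exact same order

    def worker(rank: int):
        dist.init_process_group(
            "gloo", init_method="tcp://127.0.0.1:29501", rank=rank, world_size=2
        )
        # Receivers pre-allocate matching shapes (the bridge gets these from
        # the exported param metadata).
        params = {n: torch.zeros(4) for n in PARAM_NAMES}
        if rank == 0:
            for n in PARAM_NAMES:
                params[n].normal_()  # stand-in for trained weights
        for i, _name in enumerate(PARAM_NAMES):
            idx = torch.tensor([i], dtype=torch.long)
            dist.broadcast(idx, src=0)  # Step 1: which param
            dist.broadcast(params[PARAM_NAMES[idx.item()]], src=0)  # Step 2: data
        if rank == 1:
            print("receiver got:", {n: p.tolist() for n, p in params.items()})
        dist.destroy_process_group()

    if __name__ == "__main__":
        mp.spawn(worker, nprocs=2)

One observable consequence of the patch: unmapped params are skipped on the
sender side, so the number of (index, tensor) pairs sent can be smaller than
len(param_name_list).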