diff --git a/environments/eval_environments/gsm8k_eval.py b/environments/eval_environments/gsm8k_eval.py index 8dec68a5..85faedd9 100644 --- a/environments/eval_environments/gsm8k_eval.py +++ b/environments/eval_environments/gsm8k_eval.py @@ -133,10 +133,10 @@ class GSM8KEvalEnv(BaseEnv): self, config: GSM8KEvalConfig, server_configs: List[APIServerConfig], - slurm_job_id: Optional[str] = None, + slurm=False, testing: bool = False, ): - super().__init__(config, server_configs, slurm_job_id, testing) + super().__init__(config, server_configs, slurm, testing) self.config: GSM8KEvalConfig = config self.eval_items: List[Dict] = [] self._dataset_loaded = False diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py index 73201600..9cd935ad 100644 --- a/example_trainer/grpo.py +++ b/example_trainer/grpo.py @@ -679,12 +679,15 @@ def _attach_to_vllm_shared_tensors( # This catches obvious mapping failures before we try to load # ========================================================================= hf_param_count = len(list(model.named_parameters())) + # Note: attached_count may include fused tensors that map to multiple HF params + # So coverage can exceed 100% - that's OK mapping_coverage = attached_count / hf_param_count if hf_param_count > 0 else 0 - print(f"[Setup] Mapping coverage: {attached_count}/{hf_param_count} ({mapping_coverage:.1%})") + print(f"[Setup] Mapping coverage: {attached_count} tensors for {hf_param_count} parameters") # Expect at least 90% coverage for a valid mapping - # Some params like inv_freq buffers won't be in vLLM + # Note: with fused tensors, we may have MORE mappings than params + # So we check if we have at least 90% of params covered if mapping_coverage < 0.90: unmapped_params = set(model.state_dict().keys()) - set(hf_state_dict.keys()) warning_msg = f"[Setup] WARNING: Low mapping coverage ({mapping_coverage:.1%})\n" @@ -701,6 +704,8 @@ def _attach_to_vllm_shared_tensors( " 2. tensor-parallel-size=1 for single-copy mode\n" " 3. vllm_bridge_config.json contains valid ipc_handles" ) + else: + print(f"[Setup] ✓ Good mapping coverage ({mapping_coverage:.1%})") print(f"[Setup] ✓ Attached {attached_count} tensors to vLLM's shared memory")