param locations update

This commit is contained in:
Jai Suphavadeeprasit 2025-12-29 20:11:23 -05:00
parent e2c99f7f97
commit ff8eaf9e3c
4 changed files with 117 additions and 284 deletions

View file

@ -111,19 +111,47 @@ def _create_patched_runner(BaseRunner: type) -> type:
traceback.print_exc()
def _setup_shared_memory(self) -> None:
"""Move model tensors to shared memory."""
"""Move model tensors to shared memory and export param info."""
import json
from pathlib import Path
# Make entire model shareable
self.model.share_memory()
# Also share_memory_() on each parameter individually
# (some implementations may need this)
state_dict = self.model.state_dict()
for key, val in state_dict.items():
if val.is_cuda or val.device.type == 'cuda':
# For CUDA tensors, we need to ensure they're in shared memory
val.share_memory_()
print(f"[vLLM Patch] Shared {len(state_dict)} tensors in model")
# Export parameter info to JSON for trainer
log_dir = os.environ.get("LOGDIR", ".")
json_path = Path(log_dir) / "vllm_bridge_config.json"
param_mappings = {}
for name, tensor in state_dict.items():
param_mappings[name] = {
"vllm_name": name,
"shape": list(tensor.shape),
"dtype": str(tensor.dtype),
}
info = {
"model": str(self.model_config.model),
"tp_degree": self.parallel_config.tensor_parallel_size,
"dp_shard_degree": 1,
"param_mappings": param_mappings,
"param_names": sorted(state_dict.keys()),
}
try:
with open(json_path, "w") as f:
json.dump(info, f, indent=2)
print(f"[vLLM Patch] Exported {len(param_mappings)} params to {json_path}")
except Exception as e:
print(f"[vLLM Patch] Warning: Failed to export params: {e}")
def _spawn_weight_updater(self) -> None:
"""Spawn the daemon process for receiving weight updates."""