mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
param locations update
This commit is contained in:
parent
e2c99f7f97
commit
ff8eaf9e3c
4 changed files with 117 additions and 284 deletions
|
|
@ -111,19 +111,47 @@ def _create_patched_runner(BaseRunner: type) -> type:
|
|||
traceback.print_exc()
|
||||
|
||||
def _setup_shared_memory(self) -> None:
    """Move model tensors to shared memory and export param info.

    Calls ``share_memory()`` on ``self.model`` and then writes a JSON
    description of every ``state_dict`` entry (name, shape, dtype) to
    ``<LOGDIR>/vllm_bridge_config.json``, where ``LOGDIR`` is read from the
    environment and defaults to the current directory.

    The export is best-effort: any failure while creating the directory or
    writing the file is printed as a warning and is not raised.
    """
    import json
    import os
    from pathlib import Path

    # Make entire model shareable
    self.model.share_memory()

    # Also call share_memory_() on each CUDA tensor individually.
    # NOTE(review): the PyTorch docs describe Tensor.share_memory_() as a
    # no-op for CUDA tensors, so this loop is probably redundant; it is kept
    # to preserve the original call sequence -- confirm before removing.
    state_dict = self.model.state_dict()
    for tensor in state_dict.values():
        if tensor.is_cuda:
            tensor.share_memory_()

    print(f"[vLLM Patch] Shared {len(state_dict)} tensors in model")

    # Export parameter info to JSON for trainer
    log_dir = Path(os.environ.get("LOGDIR", "."))
    json_path = log_dir / "vllm_bridge_config.json"

    param_mappings = {
        name: {
            "vllm_name": name,
            "shape": list(tensor.shape),
            "dtype": str(tensor.dtype),
        }
        for name, tensor in state_dict.items()
    }

    info = {
        "model": str(self.model_config.model),
        "tp_degree": self.parallel_config.tensor_parallel_size,
        "dp_shard_degree": 1,
        "param_mappings": param_mappings,
        "param_names": sorted(state_dict),
    }

    try:
        # LOGDIR may point at a directory that has not been created yet.
        log_dir.mkdir(parents=True, exist_ok=True)
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(info, f, indent=2)
        print(f"[vLLM Patch] Exported {len(param_mappings)} params to {json_path}")
    except Exception as e:
        # Deliberately broad: a failed export must not abort model setup.
        print(f"[vLLM Patch] Warning: Failed to export params: {e}")
|
||||
|
||||
def _spawn_weight_updater(self) -> None:
|
||||
"""Spawn the daemon process for receiving weight updates."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue