single copy

Jai Suphavadeeprasit 2026-01-13 12:00:35 -05:00
parent 5ba06c7d4a
commit 3de03d6db3
3 changed files with 603 additions and 810 deletions

File diff suppressed because it is too large.


@@ -155,6 +155,18 @@ class TrainingConfig(BaseModel):
            "Weight updates are broadcast to vLLM's daemon process."
        ),
    )
    # Single-copy mode (TRUE shared memory - no extra model copy)
    single_copy: bool = Field(
        False,
        description=(
            "Enable TRUE single-copy mode via CUDA IPC. "
            "The trainer attaches to vLLM's model tensors directly, "
            "meaning only ONE copy of the model exists in GPU memory. "
            "Requires trainer and vLLM to be on the SAME GPU(s). "
            "vLLM must be started with VLLM_ENABLE_SHARED_WEIGHTS=1."
        ),
    )
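For orientation, a minimal sketch of turning the flag on programmatically, assuming only the fields visible in this diff; the model name is illustrative and any other required TrainingConfig fields are omitted:

# Hypothetical sketch: enable single-copy on top of shared_vllm mode.
# Only fields that appear in this diff are shown; the model name is illustrative.
config = TrainingConfig(
    model_name="Qwen/Qwen2.5-7B-Instruct",  # illustrative
    weight_bridge_mode="shared_vllm",
    single_copy=True,
)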
def check_atropos_api(timeout: float = 30.0) -> bool:
@@ -414,9 +426,143 @@ def setup_wandb(config: TrainingConfig) -> bool:
        return False


def _attach_to_vllm_shared_tensors(
    config: TrainingConfig,
    bridge_config_path: str,
) -> Optional[torch.nn.Module]:
    """
    Attach to vLLM's shared tensors via CUDA IPC (true single-copy mode).

    This creates a model whose parameters point to the SAME GPU memory as vLLM,
    meaning only ONE copy of the model exists in GPU memory.

    Args:
        config: Training configuration
        bridge_config_path: Path to vllm_bridge_config.json

    Returns:
        Model with parameters pointing to vLLM's tensors, or None if not possible
    """
    try:
        with open(bridge_config_path, 'r') as f:
            bridge_config = json.load(f)
    except Exception as e:
        print(f"[Setup] Could not read bridge config: {e}")
        return None

    if not bridge_config.get("single_copy_enabled", False):
        print("[Setup] Single-copy mode not available (no IPC handles exported)")
        return None

    ipc_handles = bridge_config.get("ipc_handles", {})
    if not ipc_handles:
        print("[Setup] No IPC handles found in bridge config")
        return None

    print(f"[Setup] Attaching to vLLM's shared tensors ({len(ipc_handles)} tensors)...")
    print("[Setup] TRUE SINGLE-COPY MODE - No additional model memory!")

    # Create model architecture (meta device - no memory allocation)
    with torch.device('meta'):
        model = AutoModelForCausalLM.from_pretrained(
            config.model_name,
            torch_dtype=torch.bfloat16,
        )
    # Map vLLM tensor names to HuggingFace model parameter names
    hf_state_dict = {}
    vllm_to_hf_mapping = _create_vllm_to_hf_mapping(model, ipc_handles)

    attached_count = 0
    for hf_name, vllm_name in vllm_to_hf_mapping.items():
        if vllm_name not in ipc_handles:
            continue

        ipc_info = ipc_handles[vllm_name]
        try:
            # Reconstruct tensor from IPC handle
            handle_bytes = bytes.fromhex(ipc_info["handle"])
            storage_size = ipc_info["storage_size"]
            device_index = ipc_info["device_index"]

            # Create storage from IPC handle
            storage = torch.cuda.UntypedStorage._new_shared_cuda(
                device_index,
                handle_bytes,
                storage_size,
            )

            # Reconstruct tensor
            dtype = getattr(torch, ipc_info["dtype"].replace("torch.", ""))
            tensor = torch.tensor([], dtype=dtype, device=f"cuda:{device_index}")
            tensor.set_(
                storage,
                storage_offset=ipc_info["storage_offset"],
                size=ipc_info["shape"],
                stride=ipc_info["stride"],
            )

            # Make tensor require gradients for training
            tensor.requires_grad_(True)

            hf_state_dict[hf_name] = tensor
            attached_count += 1
        except Exception as e:
            print(f"[Setup] Failed to attach {hf_name}: {e}")
            continue

    if attached_count == 0:
        print("[Setup] Could not attach any tensors, falling back to regular loading")
        return None

    print(f"[Setup] ✓ Attached {attached_count} tensors to vLLM's shared memory")

    # Load state dict into model
    model.load_state_dict(hf_state_dict, strict=False, assign=True)

    return model
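Two notes on the attach path above. First, load_state_dict(..., assign=True) makes the meta-initialized module adopt the IPC-backed tensors directly instead of copying into freshly allocated parameters, which is what keeps the trainer at zero extra weight memory. Second, for reference, here is the shape of vllm_bridge_config.json that this function expects, reconstructed from the keys it reads; the tensor name and all values are illustrative placeholders, and other fields written by the vLLM patch (param_mappings, param_names, and so on) are omitted:

# Illustrative shape of vllm_bridge_config.json as consumed by
# _attach_to_vllm_shared_tensors. All values are made-up placeholders;
# the real file is written by the patched vLLM runner (second diff below).
example_bridge_config = {
    "single_copy_enabled": True,
    "ipc_handles": {
        "model.embed_tokens.weight": {      # tensor name: illustrative
            "handle": "c0ffee...",          # hex-encoded CUDA IPC handle
            "storage_size": 155582464,      # illustrative storage size in bytes
            "storage_offset": 0,
            "shape": [151936, 1024],        # illustrative
            "stride": [1024, 1],
            "dtype": "torch.bfloat16",
            "device_index": 0,
        },
    },
}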
def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dict:
    """
    Create a mapping from HuggingFace parameter names to vLLM tensor names.

    vLLM uses slightly different naming conventions than HuggingFace, so this
    function tries a direct match first, then the common prefix transformations.
    """
    hf_params = set(model.state_dict().keys())
    vllm_params = set(ipc_handles.keys())

    mapping = {}
    for hf_name in hf_params:
        # Try direct match first
        if hf_name in vllm_params:
            mapping[hf_name] = hf_name
            continue

        # Try common transformations
        # vLLM often uses a 'model.' prefix
        vllm_name = f"model.{hf_name}" if not hf_name.startswith("model.") else hf_name
        if vllm_name in vllm_params:
            mapping[hf_name] = vllm_name
            continue

        # Remove the 'model.' prefix if present
        if hf_name.startswith("model."):
            vllm_name = hf_name[6:]
            if vllm_name in vllm_params:
                mapping[hf_name] = vllm_name

    return mapping
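A quick sanity check of the prefix handling, using a toy module in place of the real model; the handle payloads are irrelevant to the name matching, so empty dicts suffice:

# Hypothetical check of the "model." prefix handling above.
toy = torch.nn.Linear(4, 4)                             # HF-style names: "weight", "bias"
fake_handles = {"model.weight": {}, "model.bias": {}}   # vLLM-style names
assert _create_vllm_to_hf_mapping(toy, fake_handles) == {
    "weight": "model.weight",
    "bias": "model.bias",
}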
def load_model_and_tokenizer(
    config: TrainingConfig,
    bridge: Optional["VLLMWeightBridge"] = None,
    single_copy: bool = False,
) -> Tuple[torch.nn.Module, "AutoTokenizer"]:
    """
    Load or attach to model based on weight_bridge_mode.
@@ -424,6 +570,7 @@ def load_model_and_tokenizer(
    Args:
        config: Training configuration
        bridge: Optional weight bridge for shared_vllm mode
        single_copy: If True, try to attach to vLLM's shared tensors (no extra memory)

    Returns:
        Tuple of (model, tokenizer)
@@ -431,8 +578,21 @@ def load_model_and_tokenizer(
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    if config.weight_bridge_mode == "shared_vllm" and bridge is not None:
        # Shared vLLM mode: load model, weights will be broadcast via NCCL
        print("[Setup] Loading model for shared vLLM mode...")
        # Try single-copy mode first if enabled
        if single_copy or os.environ.get("VLLM_SINGLE_COPY", "0") == "1":
            log_dir = os.environ.get("LOGDIR", ".")
            bridge_config_path = os.path.join(log_dir, "vllm_bridge_config.json")
            model = _attach_to_vllm_shared_tensors(config, bridge_config_path)
            if model is not None:
                print("[Setup] ✓ Single-copy mode active - using vLLM's tensors directly!")
                model.train()
                return model, tokenizer
            else:
                print("[Setup] Single-copy failed, falling back to broadcast mode...")

        # Fallback: Load separate model, broadcast updates via NCCL
        print("[Setup] Loading model for shared vLLM mode (broadcast)...")
        if config.use_shared_memory:
            print("[Setup] NCCL shared memory mode - updates broadcast to vLLM daemon")
        else:
else:
@@ -1101,7 +1261,11 @@ def train_shared_vllm(config: TrainingConfig):
    # Load model with bridge attachment
    print("[2/3] Loading model with shared weights...")
    model, tokenizer = load_model_and_tokenizer(config, bridge=bridge)
    model, tokenizer = load_model_and_tokenizer(
        config,
        bridge=bridge,
        single_copy=config.single_copy
    )

    # maybe we can actually pick optimizer
@@ -1560,6 +1724,18 @@ def parse_args() -> argparse.Namespace:
            "vLLM must be started with VLLM_ENABLE_SHARED_WEIGHTS=1."
        ),
    )
    parser.add_argument(
        "--single-copy",
        action="store_true",
        help=(
            "Enable TRUE single-copy mode (shared_vllm mode only). "
            "Trainer attaches to vLLM's model tensors via CUDA IPC. "
            "Only ONE copy of the model exists in GPU memory! "
            "Requires trainer and vLLM to be on the SAME GPU(s). "
            "vLLM must be started with VLLM_ENABLE_SHARED_WEIGHTS=1."
        ),
    )

    return parser.parse_args()
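One detail the config_from_args change below relies on: argparse stores the value of --single-copy under the underscore name single_copy, which is why getattr(args, 'single_copy', False) picks it up. A minimal standalone check, using nothing beyond the standard library:

# Hypothetical standalone check: argparse maps "--single-copy" to args.single_copy.
import argparse
p = argparse.ArgumentParser()
p.add_argument("--single-copy", action="store_true")
assert p.parse_args(["--single-copy"]).single_copy is True
assert p.parse_args([]).single_copy is False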
@@ -1591,6 +1767,7 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
        lora_dropout=args.lora_dropout,
        lora_target_modules=args.lora_target_modules,
        use_shared_memory=getattr(args, 'use_shared_memory', False),
        single_copy=getattr(args, 'single_copy', False),
    )


@@ -230,13 +230,36 @@ def _create_patched_runner(BaseRunner: type) -> type:
            param_mappings = {}
            param_names = []
            ipc_handles = {}

            for name, tensor in state_dict.items():
                param_mappings[name] = {
                    "vllm_name": name,
                    "shape": list(tensor.shape),
                    "dtype": str(tensor.dtype),
                    "device": str(tensor.device),
                }
                param_names.append(name)

                # Export CUDA IPC handles for true single-copy mode
                if tensor.is_cuda:
                    try:
                        # Get the storage's IPC handle
                        storage = tensor.untyped_storage()
                        ipc_handle = storage._share_cuda_()
                        ipc_handles[name] = {
                            "handle": ipc_handle[0].hex() if isinstance(ipc_handle[0], bytes) else str(ipc_handle[0]),
                            "storage_size": ipc_handle[1],
                            "storage_offset": tensor.storage_offset(),
                            "shape": list(tensor.shape),
                            "stride": list(tensor.stride()),
                            "dtype": str(tensor.dtype),
                            "device_index": tensor.device.index,
                        }
                    except Exception as e:
                        print(f"[vLLM Patch] Could not get IPC handle for {name}: {e}", flush=True)

            print(f"[vLLM Patch] Exported {len(ipc_handles)} IPC handles for single-copy mode", flush=True)

            # Get model info
            model_name = "unknown"
@@ -253,8 +276,10 @@ def _create_patched_runner(BaseRunner: type) -> type:
                "dp_shard_degree": 1,
                "param_mappings": param_mappings,
                "param_names": sorted(param_names),
                "ipc_handles": ipc_handles,
                "shared_weights_enabled": True,
                "num_params": len(param_names),
                "single_copy_enabled": len(ipc_handles) > 0,
            }

            try: