mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
basic changes
This commit is contained in:
parent
14ebf7a492
commit
80d2608c4e
2 changed files with 105 additions and 36 deletions
|
|
@ -195,8 +195,9 @@ def register_trainer(config: TrainingConfig):
|
|||
response = requests.post(
|
||||
"http://localhost:8000/register",
|
||||
json={
|
||||
"wandb_group": config.wandb_group,
|
||||
"wandb_project": config.wandb_project,
|
||||
# wandb fields are required strings - use empty string if None
|
||||
"wandb_group": config.wandb_group or "",
|
||||
"wandb_project": config.wandb_project or "",
|
||||
"batch_size": config.batch_size * config.gradient_accumulation_steps,
|
||||
"max_token_len": config.seq_len,
|
||||
"starting_step": 0,
|
||||
|
|
@ -1103,9 +1104,9 @@ def train_shared_vllm(config: TrainingConfig):
|
|||
model, tokenizer = load_model_and_tokenizer(config, bridge=bridge)
|
||||
optimizer = AdamW(model.parameters(), lr=config.lr)
|
||||
|
||||
# For NCCL mode, set param list from trainer's model
|
||||
# For NCCL mode, build mapping between trainer's and vLLM's param names
|
||||
if config.use_shared_memory:
|
||||
bridge.set_param_list_from_model(model)
|
||||
bridge.build_param_mapping(model)
|
||||
|
||||
print(f"[3/3] Starting training for {config.training_steps} steps")
|
||||
print("NOTE: vLLM sees weight updates immediately after each step!")
|
||||
|
|
|
|||
|
|
@ -366,36 +366,42 @@ class VLLMWeightBridge:
|
|||
log_dir = self.config.log_dir or os.environ.get("LOGDIR", ".")
|
||||
json_path = Path(log_dir) / "vllm_bridge_config.json"
|
||||
|
||||
# Wait for file (vLLM needs time to load model and export params)
|
||||
# Wait for file WITH param_names populated (not just file existence)
|
||||
# vllm_api_server creates empty file first, patched_gpu_runner fills it later
|
||||
wait_time = 0
|
||||
max_wait = min(self.config.timeout_seconds, 120) # Max 2 minutes
|
||||
while not json_path.exists() and wait_time < max_wait:
|
||||
if wait_time % 10 == 0:
|
||||
print(f"[Bridge] Waiting for {json_path}... ({wait_time}s)")
|
||||
|
||||
while wait_time < max_wait:
|
||||
if json_path.exists():
|
||||
try:
|
||||
with open(json_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Check if param_names is populated (not just file exists)
|
||||
param_names = data.get("param_names", [])
|
||||
if len(param_names) > 0:
|
||||
self.param_mappings = data.get("param_mappings", {})
|
||||
self.param_name_list = param_names
|
||||
print(f"[Bridge] Loaded {len(self.param_name_list)} vLLM parameter names")
|
||||
return
|
||||
else:
|
||||
# File exists but param_names not yet populated
|
||||
if wait_time % 10 == 0:
|
||||
print(f"[Bridge] Waiting for vLLM to export params... ({wait_time}s)")
|
||||
except (json.JSONDecodeError, IOError):
|
||||
# File being written, wait and retry
|
||||
pass
|
||||
else:
|
||||
if wait_time % 10 == 0:
|
||||
print(f"[Bridge] Waiting for {json_path}... ({wait_time}s)")
|
||||
|
||||
time.sleep(1)
|
||||
wait_time += 1
|
||||
|
||||
if not json_path.exists():
|
||||
print(f"[Bridge] Warning: Config file not found after {wait_time}s")
|
||||
print("[Bridge] Will use trainer's model params directly")
|
||||
self.param_mappings = {}
|
||||
self.param_name_list = []
|
||||
return
|
||||
|
||||
time.sleep(1.0) # Wait for file to finish writing
|
||||
|
||||
try:
|
||||
with open(json_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
self.param_mappings = data.get("param_mappings", {})
|
||||
self.param_name_list = data.get("param_names", sorted(self.param_mappings.keys()))
|
||||
|
||||
print(f"[Bridge] Loaded {len(self.param_name_list)} vLLM parameter names")
|
||||
except Exception as e:
|
||||
print(f"[Bridge] Warning: Failed to load config: {e}")
|
||||
self.param_mappings = {}
|
||||
self.param_name_list = []
|
||||
print(f"[Bridge] Warning: Config file not populated after {wait_time}s")
|
||||
print("[Bridge] Will use trainer's model params directly")
|
||||
self.param_mappings = {}
|
||||
self.param_name_list = []
|
||||
|
||||
def set_param_list_from_model(self, model: nn.Module) -> None:
|
||||
"""
|
||||
|
|
@ -404,8 +410,59 @@ class VLLMWeightBridge:
|
|||
Call this if vLLM's param names don't match the trainer's.
|
||||
"""
|
||||
self.param_name_list = sorted(name for name, _ in model.named_parameters())
|
||||
self._trainer_to_vllm_map = {} # 1:1 mapping
|
||||
print(f"[Bridge] Using trainer's {len(self.param_name_list)} parameter names")
|
||||
|
||||
def build_param_mapping(self, model: nn.Module) -> None:
|
||||
"""
|
||||
Build mapping between trainer's HuggingFace params and vLLM's params.
|
||||
|
||||
HuggingFace models often have a "model." prefix that vLLM strips.
|
||||
This builds a mapping to translate between the two naming conventions.
|
||||
"""
|
||||
trainer_params = dict(model.named_parameters())
|
||||
trainer_names = set(trainer_params.keys())
|
||||
|
||||
# Build mapping: vLLM name -> trainer name
|
||||
self._vllm_to_trainer_map: Dict[str, str] = {}
|
||||
|
||||
for vllm_name in self.param_name_list:
|
||||
# Try exact match first
|
||||
if vllm_name in trainer_names:
|
||||
self._vllm_to_trainer_map[vllm_name] = vllm_name
|
||||
continue
|
||||
|
||||
# Try adding "model." prefix (common for HuggingFace models)
|
||||
hf_name = f"model.{vllm_name}"
|
||||
if hf_name in trainer_names:
|
||||
self._vllm_to_trainer_map[vllm_name] = hf_name
|
||||
continue
|
||||
|
||||
# Try other common prefixes
|
||||
for prefix in ["transformer.", "gpt.", "bert.", "encoder.", "decoder."]:
|
||||
prefixed = f"{prefix}{vllm_name}"
|
||||
if prefixed in trainer_names:
|
||||
self._vllm_to_trainer_map[vllm_name] = prefixed
|
||||
break
|
||||
|
||||
mapped = len(self._vllm_to_trainer_map)
|
||||
total = len(self.param_name_list)
|
||||
|
||||
if mapped == 0:
|
||||
print(f"[Bridge] ⚠ Warning: No params matched between trainer and vLLM!")
|
||||
print(f"[Bridge] Trainer params (sample): {list(trainer_names)[:3]}")
|
||||
print(f"[Bridge] vLLM params (sample): {self.param_name_list[:3]}")
|
||||
# Fall back to trainer's param list
|
||||
self.param_name_list = sorted(trainer_names)
|
||||
self._vllm_to_trainer_map = {n: n for n in self.param_name_list}
|
||||
print(f"[Bridge] Falling back to trainer's {len(self.param_name_list)} params")
|
||||
elif mapped < total:
|
||||
print(f"[Bridge] Mapped {mapped}/{total} params from vLLM to trainer")
|
||||
# Only keep mapped params
|
||||
self.param_name_list = sorted(self._vllm_to_trainer_map.keys())
|
||||
else:
|
||||
print(f"[Bridge] ✓ All {mapped} vLLM params mapped to trainer")
|
||||
|
||||
def broadcast_weights(self, model: nn.Module) -> None:
|
||||
"""
|
||||
Broadcast all model weights to vLLM inference workers.
|
||||
|
|
@ -426,18 +483,26 @@ class VLLMWeightBridge:
|
|||
self._update_count += 1
|
||||
start_time = time.time()
|
||||
|
||||
state_dict = dict(model.named_parameters())
|
||||
trainer_state_dict = dict(model.named_parameters())
|
||||
|
||||
# Get mapping (vLLM name -> trainer name)
|
||||
vllm_to_trainer = getattr(self, '_vllm_to_trainer_map', {})
|
||||
|
||||
num_params = 0
|
||||
skipped = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for idx, param_name in enumerate(self.param_name_list):
|
||||
# Get tensor for this parameter
|
||||
if param_name not in state_dict:
|
||||
for idx, vllm_name in enumerate(self.param_name_list):
|
||||
# Get trainer's parameter name for this vLLM param
|
||||
trainer_name = vllm_to_trainer.get(vllm_name, vllm_name)
|
||||
|
||||
if trainer_name not in trainer_state_dict:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
tensor = state_dict[param_name].data
|
||||
tensor = trainer_state_dict[trainer_name].data
|
||||
|
||||
# Step 1: Broadcast parameter index
|
||||
# Step 1: Broadcast parameter index (vLLM's index)
|
||||
idx_tensor = torch.tensor([idx], dtype=torch.long, device=self.device)
|
||||
dist.broadcast(idx_tensor, src=0, group=self.nccl_group)
|
||||
|
||||
|
|
@ -447,7 +512,10 @@ class VLLMWeightBridge:
|
|||
num_params += 1
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
print(f"[Bridge] Broadcast {num_params} params, update #{self._update_count} ({elapsed:.2f}s)")
|
||||
if skipped > 0:
|
||||
print(f"[Bridge] Broadcast {num_params} params (skipped {skipped}), update #{self._update_count} ({elapsed:.2f}s)")
|
||||
else:
|
||||
print(f"[Bridge] Broadcast {num_params} params, update #{self._update_count} ({elapsed:.2f}s)")
|
||||
|
||||
def broadcast_single_param(
|
||||
self,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue