mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
basic changes
This commit is contained in:
parent
14ebf7a492
commit
80d2608c4e
2 changed files with 105 additions and 36 deletions
|
|
@ -195,8 +195,9 @@ def register_trainer(config: TrainingConfig):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
"http://localhost:8000/register",
|
"http://localhost:8000/register",
|
||||||
json={
|
json={
|
||||||
"wandb_group": config.wandb_group,
|
# wandb fields are required strings - use empty string if None
|
||||||
"wandb_project": config.wandb_project,
|
"wandb_group": config.wandb_group or "",
|
||||||
|
"wandb_project": config.wandb_project or "",
|
||||||
"batch_size": config.batch_size * config.gradient_accumulation_steps,
|
"batch_size": config.batch_size * config.gradient_accumulation_steps,
|
||||||
"max_token_len": config.seq_len,
|
"max_token_len": config.seq_len,
|
||||||
"starting_step": 0,
|
"starting_step": 0,
|
||||||
|
|
@ -1103,9 +1104,9 @@ def train_shared_vllm(config: TrainingConfig):
|
||||||
model, tokenizer = load_model_and_tokenizer(config, bridge=bridge)
|
model, tokenizer = load_model_and_tokenizer(config, bridge=bridge)
|
||||||
optimizer = AdamW(model.parameters(), lr=config.lr)
|
optimizer = AdamW(model.parameters(), lr=config.lr)
|
||||||
|
|
||||||
# For NCCL mode, set param list from trainer's model
|
# For NCCL mode, build mapping between trainer's and vLLM's param names
|
||||||
if config.use_shared_memory:
|
if config.use_shared_memory:
|
||||||
bridge.set_param_list_from_model(model)
|
bridge.build_param_mapping(model)
|
||||||
|
|
||||||
print(f"[3/3] Starting training for {config.training_steps} steps")
|
print(f"[3/3] Starting training for {config.training_steps} steps")
|
||||||
print("NOTE: vLLM sees weight updates immediately after each step!")
|
print("NOTE: vLLM sees weight updates immediately after each step!")
|
||||||
|
|
|
||||||
|
|
@ -366,36 +366,42 @@ class VLLMWeightBridge:
|
||||||
log_dir = self.config.log_dir or os.environ.get("LOGDIR", ".")
|
log_dir = self.config.log_dir or os.environ.get("LOGDIR", ".")
|
||||||
json_path = Path(log_dir) / "vllm_bridge_config.json"
|
json_path = Path(log_dir) / "vllm_bridge_config.json"
|
||||||
|
|
||||||
# Wait for file (vLLM needs time to load model and export params)
|
# Wait for file WITH param_names populated (not just file existence)
|
||||||
|
# vllm_api_server creates empty file first, patched_gpu_runner fills it later
|
||||||
wait_time = 0
|
wait_time = 0
|
||||||
max_wait = min(self.config.timeout_seconds, 120) # Max 2 minutes
|
max_wait = min(self.config.timeout_seconds, 120) # Max 2 minutes
|
||||||
while not json_path.exists() and wait_time < max_wait:
|
|
||||||
if wait_time % 10 == 0:
|
while wait_time < max_wait:
|
||||||
print(f"[Bridge] Waiting for {json_path}... ({wait_time}s)")
|
if json_path.exists():
|
||||||
|
try:
|
||||||
|
with open(json_path, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
# Check if param_names is populated (not just file exists)
|
||||||
|
param_names = data.get("param_names", [])
|
||||||
|
if len(param_names) > 0:
|
||||||
|
self.param_mappings = data.get("param_mappings", {})
|
||||||
|
self.param_name_list = param_names
|
||||||
|
print(f"[Bridge] Loaded {len(self.param_name_list)} vLLM parameter names")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# File exists but param_names not yet populated
|
||||||
|
if wait_time % 10 == 0:
|
||||||
|
print(f"[Bridge] Waiting for vLLM to export params... ({wait_time}s)")
|
||||||
|
except (json.JSONDecodeError, IOError):
|
||||||
|
# File being written, wait and retry
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
if wait_time % 10 == 0:
|
||||||
|
print(f"[Bridge] Waiting for {json_path}... ({wait_time}s)")
|
||||||
|
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
wait_time += 1
|
wait_time += 1
|
||||||
|
|
||||||
if not json_path.exists():
|
print(f"[Bridge] Warning: Config file not populated after {wait_time}s")
|
||||||
print(f"[Bridge] Warning: Config file not found after {wait_time}s")
|
print("[Bridge] Will use trainer's model params directly")
|
||||||
print("[Bridge] Will use trainer's model params directly")
|
self.param_mappings = {}
|
||||||
self.param_mappings = {}
|
self.param_name_list = []
|
||||||
self.param_name_list = []
|
|
||||||
return
|
|
||||||
|
|
||||||
time.sleep(1.0) # Wait for file to finish writing
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(json_path, "r") as f:
|
|
||||||
data = json.load(f)
|
|
||||||
|
|
||||||
self.param_mappings = data.get("param_mappings", {})
|
|
||||||
self.param_name_list = data.get("param_names", sorted(self.param_mappings.keys()))
|
|
||||||
|
|
||||||
print(f"[Bridge] Loaded {len(self.param_name_list)} vLLM parameter names")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[Bridge] Warning: Failed to load config: {e}")
|
|
||||||
self.param_mappings = {}
|
|
||||||
self.param_name_list = []
|
|
||||||
|
|
||||||
def set_param_list_from_model(self, model: nn.Module) -> None:
|
def set_param_list_from_model(self, model: nn.Module) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -404,8 +410,59 @@ class VLLMWeightBridge:
|
||||||
Call this if vLLM's param names don't match the trainer's.
|
Call this if vLLM's param names don't match the trainer's.
|
||||||
"""
|
"""
|
||||||
self.param_name_list = sorted(name for name, _ in model.named_parameters())
|
self.param_name_list = sorted(name for name, _ in model.named_parameters())
|
||||||
|
self._trainer_to_vllm_map = {} # 1:1 mapping
|
||||||
print(f"[Bridge] Using trainer's {len(self.param_name_list)} parameter names")
|
print(f"[Bridge] Using trainer's {len(self.param_name_list)} parameter names")
|
||||||
|
|
||||||
|
def build_param_mapping(self, model: nn.Module) -> None:
    """
    Build mapping between trainer's HuggingFace params and vLLM's params.

    HuggingFace models often have a "model." prefix that vLLM strips.
    This builds a mapping to translate between the two naming conventions.

    For every name in ``self.param_name_list`` the trainer-side name is
    looked up by trying, in order: the exact name, the name with a
    "model." prefix, and then a few other common wrapper prefixes. The
    result is stored in ``self._vllm_to_trainer_map`` (vLLM name ->
    trainer name).

    Side effects on ``self.param_name_list``:
      * no name could be mapped (including an empty list): the list is
        replaced by the trainer's sorted parameter names and the mapping
        becomes the identity;
      * some or all names mapped: the list is left untouched.

    Args:
        model: Trainer-side model; only ``named_parameters()`` is used.
    """
    # Only the names are needed here, not the tensors.
    trainer_names = {name for name, _ in model.named_parameters()}

    # "" is the exact-match case; the order fixes which candidate wins.
    candidate_prefixes = (
        "",
        "model.",
        "transformer.",
        "gpt.",
        "bert.",
        "encoder.",
        "decoder.",
    )

    # Build mapping: vLLM name -> trainer name
    self._vllm_to_trainer_map: Dict[str, str] = {}
    for vllm_name in self.param_name_list:
        for prefix in candidate_prefixes:
            candidate = f"{prefix}{vllm_name}"
            if candidate in trainer_names:
                self._vllm_to_trainer_map[vllm_name] = candidate
                break

    mapped = len(self._vllm_to_trainer_map)
    total = len(self.param_name_list)

    if mapped == 0:
        if total == 0:
            # Nothing was exported by vLLM; this is not a naming mismatch,
            # so skip the mismatch warning and its empty samples.
            print("[Bridge] No vLLM param names available to map")
        else:
            print("[Bridge] ⚠ Warning: No params matched between trainer and vLLM!")
            # sorted() keeps the diagnostic sample stable across runs.
            print(f"[Bridge] Trainer params (sample): {sorted(trainer_names)[:3]}")
            print(f"[Bridge] vLLM params (sample): {self.param_name_list[:3]}")
        # Fall back to trainer's param list
        self.param_name_list = sorted(trainer_names)
        self._vllm_to_trainer_map = {n: n for n in self.param_name_list}
        print(f"[Bridge] Falling back to trainer's {len(self.param_name_list)} params")
    elif mapped < total:
        print(f"[Bridge] Mapped {mapped}/{total} params from vLLM to trainer")
        # Deliberately keep param_name_list as-is: positions in this list
        # are what gets broadcast as the parameter index, so dropping or
        # re-sorting entries would shift every later index.
        # NOTE(review): presumably the vLLM side resolves indices against
        # the same exported list - confirm. Unmapped names have no trainer
        # tensor and are expected to be skipped by the broadcaster.
    else:
        print(f"[Bridge] ✓ All {mapped} vLLM params mapped to trainer")
|
||||||
|
|
||||||
def broadcast_weights(self, model: nn.Module) -> None:
|
def broadcast_weights(self, model: nn.Module) -> None:
|
||||||
"""
|
"""
|
||||||
Broadcast all model weights to vLLM inference workers.
|
Broadcast all model weights to vLLM inference workers.
|
||||||
|
|
@ -426,18 +483,26 @@ class VLLMWeightBridge:
|
||||||
self._update_count += 1
|
self._update_count += 1
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
state_dict = dict(model.named_parameters())
|
trainer_state_dict = dict(model.named_parameters())
|
||||||
|
|
||||||
|
# Get mapping (vLLM name -> trainer name)
|
||||||
|
vllm_to_trainer = getattr(self, '_vllm_to_trainer_map', {})
|
||||||
|
|
||||||
num_params = 0
|
num_params = 0
|
||||||
|
skipped = 0
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for idx, param_name in enumerate(self.param_name_list):
|
for idx, vllm_name in enumerate(self.param_name_list):
|
||||||
# Get tensor for this parameter
|
# Get trainer's parameter name for this vLLM param
|
||||||
if param_name not in state_dict:
|
trainer_name = vllm_to_trainer.get(vllm_name, vllm_name)
|
||||||
|
|
||||||
|
if trainer_name not in trainer_state_dict:
|
||||||
|
skipped += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
tensor = state_dict[param_name].data
|
tensor = trainer_state_dict[trainer_name].data
|
||||||
|
|
||||||
# Step 1: Broadcast parameter index
|
# Step 1: Broadcast parameter index (vLLM's index)
|
||||||
idx_tensor = torch.tensor([idx], dtype=torch.long, device=self.device)
|
idx_tensor = torch.tensor([idx], dtype=torch.long, device=self.device)
|
||||||
dist.broadcast(idx_tensor, src=0, group=self.nccl_group)
|
dist.broadcast(idx_tensor, src=0, group=self.nccl_group)
|
||||||
|
|
||||||
|
|
@ -447,7 +512,10 @@ class VLLMWeightBridge:
|
||||||
num_params += 1
|
num_params += 1
|
||||||
|
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
print(f"[Bridge] Broadcast {num_params} params, update #{self._update_count} ({elapsed:.2f}s)")
|
if skipped > 0:
|
||||||
|
print(f"[Bridge] Broadcast {num_params} params (skipped {skipped}), update #{self._update_count} ({elapsed:.2f}s)")
|
||||||
|
else:
|
||||||
|
print(f"[Bridge] Broadcast {num_params} params, update #{self._update_count} ({elapsed:.2f}s)")
|
||||||
|
|
||||||
def broadcast_single_param(
|
def broadcast_single_param(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue