mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
d07ab3e3ce
commit
5cfd1929f1
19 changed files with 708 additions and 452 deletions
|
|
@ -373,7 +373,9 @@ class GSM8KEvalEnv(BaseEnv):
|
|||
|
||||
# Create evaluation tasks
|
||||
async def eval_task(item):
|
||||
return await self.rollout_and_score_eval(item, self.server.servers[0].config)
|
||||
return await self.rollout_and_score_eval(
|
||||
item, self.server.servers[0].config
|
||||
)
|
||||
|
||||
tasks = [eval_task(item) for item in self.eval_items]
|
||||
|
||||
|
|
|
|||
|
|
@ -299,6 +299,7 @@ class MathEnv(BaseEnv):
|
|||
if not self.config.run_evaluation:
|
||||
return
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
eval_tasks = []
|
||||
|
|
@ -320,9 +321,7 @@ class MathEnv(BaseEnv):
|
|||
metrics[f"{subset}_accuracy"] = accuracy
|
||||
metrics[f"{subset}_total"] = len(scores)
|
||||
metrics[f"{subset}_correct"] = sum(scores)
|
||||
self.eval_metrics.append(
|
||||
(f"eval/{subset}_percent_correct", accuracy)
|
||||
)
|
||||
self.eval_metrics.append((f"eval/{subset}_percent_correct", accuracy))
|
||||
|
||||
# overall score
|
||||
all_scores = []
|
||||
|
|
@ -332,9 +331,7 @@ class MathEnv(BaseEnv):
|
|||
metrics["overall_accuracy"] = overall_accuracy
|
||||
metrics["overall_total"] = len(all_scores)
|
||||
metrics["overall_correct"] = sum(all_scores)
|
||||
self.eval_metrics.append(
|
||||
("eval/overall_percent_correct", overall_accuracy)
|
||||
)
|
||||
self.eval_metrics.append(("eval/overall_percent_correct", overall_accuracy))
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
|
|
@ -342,7 +339,9 @@ class MathEnv(BaseEnv):
|
|||
print("\n" + "=" * 60)
|
||||
print("Math Zero Evaluation Results")
|
||||
print("=" * 60)
|
||||
print(f"Overall Accuracy: {overall_accuracy:.2%} ({sum(all_scores)}/{len(all_scores)})")
|
||||
print(
|
||||
f"Overall Accuracy: {overall_accuracy:.2%} ({sum(all_scores)}/{len(all_scores)})"
|
||||
)
|
||||
print("\nPer-subset breakdown:")
|
||||
for subset, scores in sorted(task_lists.items()):
|
||||
acc = sum(scores) / len(scores)
|
||||
|
|
|
|||
|
|
@ -542,4 +542,3 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands
|
|||
| `vllm_api_server.py` | Streamlined vLLM server for training |
|
||||
| `vllm_manager.py` | vLLM process lifecycle management |
|
||||
| `checkpointing.py` | Save/load checkpoints and adapters |
|
||||
|
||||
|
|
|
|||
|
|
@ -20,9 +20,9 @@ Usage:
|
|||
train_legacy(config)
|
||||
"""
|
||||
|
||||
from .cli import config_from_args, parse_args
|
||||
from .config import TrainingConfig
|
||||
from .trainers import train_legacy, train_shared_vllm, train_lora
|
||||
from .cli import parse_args, config_from_args
|
||||
from .trainers import train_legacy, train_lora, train_shared_vllm
|
||||
|
||||
__all__ = [
|
||||
"TrainingConfig",
|
||||
|
|
|
|||
|
|
@ -15,7 +15,9 @@ from tenacity import retry, stop_after_attempt, wait_exponential
|
|||
from .config import TrainingConfig
|
||||
|
||||
|
||||
def check_atropos_api(url: str = "http://localhost:8000", timeout: float = 30.0) -> bool:
|
||||
def check_atropos_api(
|
||||
url: str = "http://localhost:8000", timeout: float = 30.0
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the Atropos API server is reachable.
|
||||
|
||||
|
|
@ -99,4 +101,3 @@ def get_batch(url: str = "http://localhost:8000"):
|
|||
raise RuntimeError(f"Atropos API error: {data.get('message', 'Unknown error')}")
|
||||
|
||||
return data
|
||||
|
||||
|
|
|
|||
|
|
@ -89,11 +89,14 @@ def save_checkpoint(
|
|||
|
||||
# Count how many were non-contiguous (views into fused tensors)
|
||||
view_count = sum(
|
||||
1 for name, param in model.named_parameters()
|
||||
1
|
||||
for name, param in model.named_parameters()
|
||||
if not param.is_contiguous() or param.storage_offset() != 0
|
||||
)
|
||||
if view_count > 0:
|
||||
print(f" [Checkpoint] Unfused {view_count} view tensors (qkv/gate_up fusions)")
|
||||
print(
|
||||
f" [Checkpoint] Unfused {view_count} view tensors (qkv/gate_up fusions)"
|
||||
)
|
||||
|
||||
# Save state dict manually, then save config separately
|
||||
torch.save(state_dict, os.path.join(checkpoint_path, "pytorch_model.bin"))
|
||||
|
|
@ -102,6 +105,7 @@ def save_checkpoint(
|
|||
# CRITICAL: Clean up the copied state_dict to free ~8GB GPU memory!
|
||||
del state_dict
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
|
|
@ -151,4 +155,3 @@ def save_lora_checkpoint(
|
|||
|
||||
print(" Adapter saved.")
|
||||
return adapter_path
|
||||
|
||||
|
|
|
|||
|
|
@ -11,16 +11,17 @@ import torch
|
|||
|
||||
from .config import TrainingConfig
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Argument Group Builders (modular, reusable)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def add_model_args(parser: argparse.ArgumentParser) -> None:
|
||||
"""Add model-related arguments."""
|
||||
group = parser.add_argument_group("Model")
|
||||
group.add_argument(
|
||||
"--model", "--model-name",
|
||||
"--model",
|
||||
"--model-name",
|
||||
type=str,
|
||||
required=True,
|
||||
dest="model_name",
|
||||
|
|
@ -67,7 +68,7 @@ def add_training_args(parser: argparse.ArgumentParser) -> None:
|
|||
choices=["adamw", "adamw_8bit", "adamw_cpu", "adafactor"],
|
||||
default="adamw_8bit",
|
||||
help="Optimizer: 'adamw' (full precision), 'adamw_8bit' (8-bit states), "
|
||||
"'adamw_cpu' (CPU offload), 'adafactor' (no momentum)",
|
||||
"'adamw_cpu' (CPU offload), 'adafactor' (no momentum)",
|
||||
)
|
||||
group.add_argument(
|
||||
"--device",
|
||||
|
|
@ -121,7 +122,8 @@ def add_vllm_args(parser: argparse.ArgumentParser) -> None:
|
|||
help="Port for the vLLM server",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gpu-memory-utilization", "--vllm-gpu-memory-utilization",
|
||||
"--gpu-memory-utilization",
|
||||
"--vllm-gpu-memory-utilization",
|
||||
type=float,
|
||||
default=0.45,
|
||||
dest="gpu_memory_utilization",
|
||||
|
|
@ -203,7 +205,9 @@ def add_lora_args(parser: argparse.ArgumentParser) -> None:
|
|||
"""Add LoRA-specific arguments."""
|
||||
group = parser.add_argument_group("LoRA Configuration")
|
||||
group.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
|
||||
group.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha (scaling factor)")
|
||||
group.add_argument(
|
||||
"--lora-alpha", type=int, default=32, help="LoRA alpha (scaling factor)"
|
||||
)
|
||||
group.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")
|
||||
group.add_argument(
|
||||
"--lora-target-modules",
|
||||
|
|
@ -219,8 +223,12 @@ def add_distributed_args(parser: argparse.ArgumentParser) -> None:
|
|||
group = parser.add_argument_group("Distributed Training")
|
||||
group.add_argument("--trainer-rank", type=int, default=0, help="Trainer rank")
|
||||
group.add_argument("--world-size", type=int, default=1, help="World size")
|
||||
group.add_argument("--init-method", type=str, default="env://", help="Distributed init method")
|
||||
group.add_argument("--num-inference-nodes", type=int, default=0, help="Number of inference nodes")
|
||||
group.add_argument(
|
||||
"--init-method", type=str, default="env://", help="Distributed init method"
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-inference-nodes", type=int, default=0, help="Number of inference nodes"
|
||||
)
|
||||
|
||||
|
||||
def add_debug_args(parser: argparse.ArgumentParser) -> None:
|
||||
|
|
@ -248,6 +256,7 @@ def add_debug_args(parser: argparse.ArgumentParser) -> None:
|
|||
# Parser Builders
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_base_parser(description: str) -> argparse.ArgumentParser:
|
||||
"""Create a base parser with common formatting."""
|
||||
return argparse.ArgumentParser(
|
||||
|
|
@ -299,6 +308,7 @@ def create_unified_parser() -> argparse.ArgumentParser:
|
|||
# Legacy API (backwards compatibility)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""
|
||||
Parse command-line arguments for the GRPO trainer (grpo.py).
|
||||
|
|
|
|||
|
|
@ -35,9 +35,9 @@ class TrainingConfig(BaseModel):
|
|||
optimizer: Literal["adamw", "adamw_8bit", "adamw_cpu", "adafactor"] = Field(
|
||||
"adamw_8bit",
|
||||
description="Optimizer to use: 'adamw' (full precision, ~32GB GPU), "
|
||||
"'adamw_8bit' (8-bit states, ~8GB GPU, requires bitsandbytes), "
|
||||
"'adamw_cpu' (CPU offload, ~0GB GPU, slower), "
|
||||
"'adafactor' (no momentum, ~8GB GPU)"
|
||||
"'adamw_8bit' (8-bit states, ~8GB GPU, requires bitsandbytes), "
|
||||
"'adamw_cpu' (CPU offload, ~0GB GPU, slower), "
|
||||
"'adafactor' (no momentum, ~8GB GPU)",
|
||||
)
|
||||
|
||||
# === GRPO/PPO Hyperparameters ===
|
||||
|
|
@ -69,12 +69,10 @@ class TrainingConfig(BaseModel):
|
|||
|
||||
# === Device & Storage ===
|
||||
device: str = Field(
|
||||
"cuda" if torch.cuda.is_available() else "cpu",
|
||||
description="Device to train on"
|
||||
"cuda" if torch.cuda.is_available() else "cpu", description="Device to train on"
|
||||
)
|
||||
save_path: str = Field(
|
||||
"trained_model_checkpoints",
|
||||
description="Base path to save model checkpoints"
|
||||
"trained_model_checkpoints", description="Base path to save model checkpoints"
|
||||
)
|
||||
checkpoint_interval: int = Field(
|
||||
3,
|
||||
|
|
@ -121,9 +119,7 @@ class TrainingConfig(BaseModel):
|
|||
trainer_rank: int = Field(
|
||||
0, description="Rank of this trainer in the distributed group"
|
||||
)
|
||||
world_size: int = Field(
|
||||
1, description="Total processes in the distributed group"
|
||||
)
|
||||
world_size: int = Field(1, description="Total processes in the distributed group")
|
||||
init_method: str = Field(
|
||||
"env://",
|
||||
description=(
|
||||
|
|
@ -198,4 +194,3 @@ class TrainingConfig(BaseModel):
|
|||
"Default is http://localhost:8000. Change for concurrent tests."
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -92,28 +92,30 @@ def pad_data_to_good_offset(
|
|||
# Process each sample in the item
|
||||
for i in range(len(item["tokens"])):
|
||||
seq_len = len(item["tokens"][i])
|
||||
lengths.append(
|
||||
math.ceil((seq_len - 1) / good_multiple) * good_multiple
|
||||
)
|
||||
lengths.append(math.ceil((seq_len - 1) / good_multiple) * good_multiple)
|
||||
|
||||
# Create labels with padding (-100 for masked positions)
|
||||
label_item = np.concatenate([
|
||||
np.array(item["masks"][i]),
|
||||
np.full(
|
||||
max(0, token_setup_len - seq_len),
|
||||
-100,
|
||||
dtype=np.int32,
|
||||
),
|
||||
])
|
||||
label_item = np.concatenate(
|
||||
[
|
||||
np.array(item["masks"][i]),
|
||||
np.full(
|
||||
max(0, token_setup_len - seq_len),
|
||||
-100,
|
||||
dtype=np.int32,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Pad tokens
|
||||
item["tokens"][i] = np.concatenate([
|
||||
np.array(item["tokens"][i]),
|
||||
np.zeros(
|
||||
max(0, token_setup_len - seq_len),
|
||||
dtype=np.int32,
|
||||
),
|
||||
])
|
||||
item["tokens"][i] = np.concatenate(
|
||||
[
|
||||
np.array(item["tokens"][i]),
|
||||
np.zeros(
|
||||
max(0, token_setup_len - seq_len),
|
||||
dtype=np.int32,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
input_ids.append(item["tokens"][i][:-1]) # Remove last for causal
|
||||
labels.append(label_item[1:]) # Shift by 1 for causal
|
||||
|
|
@ -126,7 +128,9 @@ def pad_data_to_good_offset(
|
|||
# We just need to pad to match the sequence length
|
||||
if extract_inference_logprobs and "inference_logprobs" in item:
|
||||
if i < len(item["inference_logprobs"]):
|
||||
raw_logprobs = np.array(item["inference_logprobs"][i], dtype=np.float32)
|
||||
raw_logprobs = np.array(
|
||||
item["inference_logprobs"][i], dtype=np.float32
|
||||
)
|
||||
has_any_logprobs = True
|
||||
|
||||
# Create padded logprobs array matching token_setup_len
|
||||
|
|
@ -141,10 +145,14 @@ def pad_data_to_good_offset(
|
|||
inference_logprobs_padded.append(padded_logprobs[1:])
|
||||
else:
|
||||
# No logprobs for this sample, use 1.0
|
||||
inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32))
|
||||
inference_logprobs_padded.append(
|
||||
np.full(token_setup_len - 1, 1.0, dtype=np.float32)
|
||||
)
|
||||
elif extract_inference_logprobs:
|
||||
# No inference_logprobs in item, use 1.0
|
||||
inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32))
|
||||
inference_logprobs_padded.append(
|
||||
np.full(token_setup_len - 1, 1.0, dtype=np.float32)
|
||||
)
|
||||
|
||||
# Extract temperature (priority: override > generation_params > group_overrides > 1.0)
|
||||
t = 1.0
|
||||
|
|
@ -155,9 +163,13 @@ def pad_data_to_good_offset(
|
|||
and ("temperature" in item["overrides"][i])
|
||||
):
|
||||
t = float(item["overrides"][i]["temperature"])
|
||||
elif item.get("generation_params") and ("temperature" in item["generation_params"]):
|
||||
elif item.get("generation_params") and (
|
||||
"temperature" in item["generation_params"]
|
||||
):
|
||||
t = float(item["generation_params"]["temperature"])
|
||||
elif item.get("group_overrides") and ("temperature" in item["group_overrides"]):
|
||||
elif item.get("group_overrides") and (
|
||||
"temperature" in item["group_overrides"]
|
||||
):
|
||||
t = float(item["group_overrides"]["temperature"])
|
||||
temperatures.append(t)
|
||||
|
||||
|
|
@ -172,19 +184,15 @@ def pad_data_to_good_offset(
|
|||
start = i * batch_size
|
||||
end = (i + 1) * batch_size
|
||||
|
||||
token_batches.append(
|
||||
torch.tensor(np.stack(input_ids[start:end], axis=0))
|
||||
)
|
||||
label_batches.append(
|
||||
torch.tensor(np.stack(labels[start:end], axis=0))
|
||||
)
|
||||
token_batches.append(torch.tensor(np.stack(input_ids[start:end], axis=0)))
|
||||
label_batches.append(torch.tensor(np.stack(labels[start:end], axis=0)))
|
||||
advantage_batches.append(
|
||||
torch.tensor(np.stack(advantages[start:end], axis=0)).view(-1, 1)
|
||||
)
|
||||
temperature_batches.append(
|
||||
torch.tensor(
|
||||
np.array(temperatures[start:end], dtype=np.float32)
|
||||
).view(-1, 1, 1)
|
||||
torch.tensor(np.array(temperatures[start:end], dtype=np.float32)).view(
|
||||
-1, 1, 1
|
||||
)
|
||||
)
|
||||
|
||||
# Batch inference logprobs (same shape as labels)
|
||||
|
|
@ -194,9 +202,19 @@ def pad_data_to_good_offset(
|
|||
)
|
||||
|
||||
# Return inference logprob batches if we have any real logprobs
|
||||
final_logprob_batches = inference_logprob_batches if (has_any_logprobs and inference_logprob_batches) else None
|
||||
final_logprob_batches = (
|
||||
inference_logprob_batches
|
||||
if (has_any_logprobs and inference_logprob_batches)
|
||||
else None
|
||||
)
|
||||
|
||||
return token_batches, label_batches, advantage_batches, temperature_batches, final_logprob_batches
|
||||
return (
|
||||
token_batches,
|
||||
label_batches,
|
||||
advantage_batches,
|
||||
temperature_batches,
|
||||
final_logprob_batches,
|
||||
)
|
||||
|
||||
|
||||
def get_data(
|
||||
|
|
@ -205,13 +223,15 @@ def get_data(
|
|||
atropos_url: str = "http://localhost:8000",
|
||||
extract_inference_logprobs: bool = True,
|
||||
) -> Tuple[
|
||||
List[Tuple[
|
||||
List[torch.Tensor], # token_batches
|
||||
List[torch.Tensor], # label_batches
|
||||
List[torch.Tensor], # advantage_batches
|
||||
List[torch.Tensor], # temperature_batches
|
||||
Optional[List[torch.Tensor]], # inference_logprob_batches
|
||||
]],
|
||||
List[
|
||||
Tuple[
|
||||
List[torch.Tensor], # token_batches
|
||||
List[torch.Tensor], # label_batches
|
||||
List[torch.Tensor], # advantage_batches
|
||||
List[torch.Tensor], # temperature_batches
|
||||
Optional[List[torch.Tensor]], # inference_logprob_batches
|
||||
]
|
||||
],
|
||||
None, # Legacy return (no longer used)
|
||||
]:
|
||||
"""
|
||||
|
|
@ -241,18 +261,37 @@ def get_data(
|
|||
if data["batch"] is not None:
|
||||
# DEBUG: Check if inference_logprobs exists in the data
|
||||
if not _logged_logprob_warning:
|
||||
has_logprobs = any("inference_logprobs" in item for item in data["batch"])
|
||||
has_logprobs = any(
|
||||
"inference_logprobs" in item for item in data["batch"]
|
||||
)
|
||||
if has_logprobs:
|
||||
# Check if they're non-empty
|
||||
sample_item = next((item for item in data["batch"] if "inference_logprobs" in item), None)
|
||||
sample_item = next(
|
||||
(
|
||||
item
|
||||
for item in data["batch"]
|
||||
if "inference_logprobs" in item
|
||||
),
|
||||
None,
|
||||
)
|
||||
if sample_item and sample_item.get("inference_logprobs"):
|
||||
sample_lp = sample_item["inference_logprobs"][0] if sample_item["inference_logprobs"] else []
|
||||
print(f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})")
|
||||
sample_lp = (
|
||||
sample_item["inference_logprobs"][0]
|
||||
if sample_item["inference_logprobs"]
|
||||
else []
|
||||
)
|
||||
print(
|
||||
f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})"
|
||||
)
|
||||
else:
|
||||
print(" [Data] ⚠ inference_logprobs key exists but is empty!")
|
||||
print(
|
||||
" [Data] ⚠ inference_logprobs key exists but is empty!"
|
||||
)
|
||||
else:
|
||||
print(" [Data] ⚠ NO inference_logprobs in batch data!")
|
||||
print(f" [Data] Keys in first item: {list(data['batch'][0].keys())}")
|
||||
print(
|
||||
f" [Data] Keys in first item: {list(data['batch'][0].keys())}"
|
||||
)
|
||||
_logged_logprob_warning = True
|
||||
|
||||
# Save batch for debugging
|
||||
|
|
@ -260,11 +299,24 @@ def get_data(
|
|||
json.dump(data, f)
|
||||
|
||||
# Process and accumulate batches (now includes batched inference logprobs)
|
||||
token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches = \
|
||||
pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
|
||||
(
|
||||
token_batches,
|
||||
label_batches,
|
||||
adv_batches,
|
||||
temp_batches,
|
||||
inf_logprob_batches,
|
||||
) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
|
||||
|
||||
# Include inference logprob batches in the tuple
|
||||
batches.append((token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches))
|
||||
batches.append(
|
||||
(
|
||||
token_batches,
|
||||
label_batches,
|
||||
adv_batches,
|
||||
temp_batches,
|
||||
inf_logprob_batches,
|
||||
)
|
||||
)
|
||||
|
||||
elif len(batches) > 0:
|
||||
# Return accumulated batches when no more data
|
||||
|
|
@ -272,4 +324,3 @@ def get_data(
|
|||
else:
|
||||
# Wait for data
|
||||
time.sleep(1)
|
||||
|
||||
|
|
|
|||
|
|
@ -19,8 +19,8 @@ Usage:
|
|||
--weight-bridge-mode lora_only --lora-r 16 --lora-alpha 32
|
||||
"""
|
||||
|
||||
from .cli import parse_args, config_from_args
|
||||
from .trainers import train_legacy, train_shared_vllm, train_lora
|
||||
from .cli import config_from_args, parse_args
|
||||
from .trainers import train_legacy, train_lora, train_shared_vllm
|
||||
|
||||
|
||||
def main():
|
||||
|
|
@ -28,9 +28,9 @@ def main():
|
|||
args = parse_args()
|
||||
config = config_from_args(args)
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("\n" + "=" * 60)
|
||||
print("GRPO TRAINER")
|
||||
print("="*60)
|
||||
print("=" * 60)
|
||||
print(f"Model: {config.model_name}")
|
||||
print(f"Mode: {config.weight_bridge_mode}")
|
||||
print(f"Training steps: {config.training_steps}")
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from .config import TrainingConfig
|
|||
# Import PEFT for LoRA training
|
||||
try:
|
||||
from peft import LoraConfig, TaskType, get_peft_model
|
||||
|
||||
PEFT_AVAILABLE = True
|
||||
except ImportError:
|
||||
PEFT_AVAILABLE = False
|
||||
|
|
@ -37,6 +38,7 @@ def _get_attention_implementation() -> str:
|
|||
"""
|
||||
try:
|
||||
import flash_attn # noqa: F401
|
||||
|
||||
return "flash_attention_2"
|
||||
except ImportError:
|
||||
return "sdpa"
|
||||
|
|
@ -61,12 +63,19 @@ def _load_model_with_attention(
|
|||
"""
|
||||
# Select the loader function based on mode
|
||||
# from_config: creates empty shell (meta device), from_pretrained: loads weights
|
||||
loader = AutoModelForCausalLM.from_config if from_config else AutoModelForCausalLM.from_pretrained
|
||||
loader = (
|
||||
AutoModelForCausalLM.from_config
|
||||
if from_config
|
||||
else AutoModelForCausalLM.from_pretrained
|
||||
)
|
||||
|
||||
# Try attention implementations in order of preference
|
||||
for attn_impl in ["flash_attention_2", "sdpa"]:
|
||||
# Skip flash_attention_2 if not available
|
||||
if attn_impl == "flash_attention_2" and _get_attention_implementation() != "flash_attention_2":
|
||||
if (
|
||||
attn_impl == "flash_attention_2"
|
||||
and _get_attention_implementation() != "flash_attention_2"
|
||||
):
|
||||
continue
|
||||
|
||||
try:
|
||||
|
|
@ -86,6 +95,7 @@ def _load_model_with_attention(
|
|||
# Should never reach here, but just in case
|
||||
raise RuntimeError("Failed to load model with any attention implementation")
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
config: TrainingConfig,
|
||||
single_copy: bool = False,
|
||||
|
|
@ -178,7 +188,7 @@ def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module:
|
|||
Returns:
|
||||
PEFT model with LoRA adapters applied
|
||||
"""
|
||||
if not PEFT_AVAILABLE: # Yeah no PEFT is needed no matter what bless huggingface
|
||||
if not PEFT_AVAILABLE: # Yeah no PEFT is needed no matter what bless huggingface
|
||||
raise RuntimeError("PEFT library not available. Install with: pip install peft")
|
||||
|
||||
print("[Setup] Loading base model for LoRA mode...")
|
||||
|
|
@ -208,7 +218,9 @@ def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module:
|
|||
return model
|
||||
|
||||
|
||||
def _setup_gradient_checkpointing(model: torch.nn.Module, config: TrainingConfig) -> None:
|
||||
def _setup_gradient_checkpointing(
|
||||
model: torch.nn.Module, config: TrainingConfig
|
||||
) -> None:
|
||||
"""Configure gradient checkpointing for the model."""
|
||||
# Disable KV cache - incompatible with gradient checkpointing
|
||||
model.config.use_cache = False
|
||||
|
|
@ -297,7 +309,9 @@ def _attach_to_vllm_shared_tensors(
|
|||
ipc_handles, vllm_to_hf_mapping, config
|
||||
)
|
||||
|
||||
print(f"[Setup] Attached {attached_count} tensors ({fused_count} from fused layers)")
|
||||
print(
|
||||
f"[Setup] Attached {attached_count} tensors ({fused_count} from fused layers)"
|
||||
)
|
||||
|
||||
if attached_count == 0:
|
||||
print("[Setup] Could not attach any tensors, falling back to regular loading")
|
||||
|
|
@ -322,6 +336,7 @@ def _attach_to_vllm_shared_tensors(
|
|||
|
||||
def _deserialize_ipc_handles(handles_raw: dict) -> dict:
|
||||
"""Deserialize base64-encoded bytes in IPC handles."""
|
||||
|
||||
def deserialize(handles):
|
||||
result = {}
|
||||
for k, v in handles.items():
|
||||
|
|
@ -333,6 +348,7 @@ def _deserialize_ipc_handles(handles_raw: dict) -> dict:
|
|||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
return deserialize(handles_raw)
|
||||
|
||||
|
||||
|
|
@ -387,8 +403,14 @@ def _reconstruct_shared_tensors(
|
|||
event_sync_required = ipc_info["event_sync_required"]
|
||||
|
||||
share_tuple = (
|
||||
device_index, ipc_handle, storage_size, storage_offset_orig,
|
||||
ref_counter_handle, ref_counter_offset, event_handle, event_sync_required,
|
||||
device_index,
|
||||
ipc_handle,
|
||||
storage_size,
|
||||
storage_offset_orig,
|
||||
ref_counter_handle,
|
||||
ref_counter_offset,
|
||||
event_handle,
|
||||
event_sync_required,
|
||||
)
|
||||
|
||||
storage = torch.UntypedStorage._new_shared_cuda(*share_tuple)
|
||||
|
|
@ -424,7 +446,9 @@ def _reconstruct_shared_tensors(
|
|||
if slice_dim == 0:
|
||||
tensor = full_tensor[slice_start:slice_end]
|
||||
else:
|
||||
tensor = full_tensor.narrow(slice_dim, slice_start, slice_end - slice_start)
|
||||
tensor = full_tensor.narrow(
|
||||
slice_dim, slice_start, slice_end - slice_start
|
||||
)
|
||||
|
||||
tensor.requires_grad_(True)
|
||||
hf_state_dict[hf_name] = tensor
|
||||
|
|
@ -459,12 +483,16 @@ def _validate_mapping_coverage(
|
|||
|
||||
# Note: attached_count may be > param_count because state_dict includes buffers
|
||||
# while named_parameters only counts trainable params
|
||||
print(f"[Setup] Mapping coverage: {attached_count} tensors for {hf_param_count} parameters "
|
||||
f"(>100% is OK - includes buffers)")
|
||||
print(
|
||||
f"[Setup] Mapping coverage: {attached_count} tensors for {hf_param_count} parameters "
|
||||
f"(>100% is OK - includes buffers)"
|
||||
)
|
||||
|
||||
if mapping_coverage < 0.90:
|
||||
unmapped_params = set(model.state_dict().keys()) - set(hf_state_dict.keys())
|
||||
warning_msg = f"[Setup] WARNING: Low mapping coverage ({mapping_coverage:.1%})\n"
|
||||
warning_msg = (
|
||||
f"[Setup] WARNING: Low mapping coverage ({mapping_coverage:.1%})\n"
|
||||
)
|
||||
warning_msg += f"Unmapped parameters ({len(unmapped_params)}):\n"
|
||||
for name in list(unmapped_params)[:20]:
|
||||
warning_msg += f" - {name}\n"
|
||||
|
|
@ -484,11 +512,17 @@ def _initialize_meta_tensors(
|
|||
config: TrainingConfig,
|
||||
) -> None:
|
||||
"""Initialize any remaining meta tensors after loading."""
|
||||
meta_params = [name for name, p in model.named_parameters() if p.device.type == "meta"]
|
||||
meta_buffers = [name for name, b in model.named_buffers() if b.device.type == "meta"]
|
||||
meta_params = [
|
||||
name for name, p in model.named_parameters() if p.device.type == "meta"
|
||||
]
|
||||
meta_buffers = [
|
||||
name for name, b in model.named_buffers() if b.device.type == "meta"
|
||||
]
|
||||
|
||||
if config.debug_loading:
|
||||
print(f"\n[DIAGNOSTIC] Meta params: {len(meta_params)}, Meta buffers: {len(meta_buffers)}")
|
||||
print(
|
||||
f"\n[DIAGNOSTIC] Meta params: {len(meta_params)}, Meta buffers: {len(meta_buffers)}"
|
||||
)
|
||||
|
||||
def get_parent_and_name(model, full_name):
|
||||
parts = full_name.split(".")
|
||||
|
|
@ -526,11 +560,15 @@ def _initialize_meta_tensors(
|
|||
dim = buffer.shape[0] * 2
|
||||
# Get rope_theta from model config (default 10000.0 for LLaMA, but Qwen3 uses 5000000!)
|
||||
rope_theta = getattr(model.config, "rope_theta", 10000.0)
|
||||
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
|
||||
inv_freq = 1.0 / (
|
||||
rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
|
||||
)
|
||||
new_buffer = inv_freq.to(dtype=buffer.dtype, device=device)
|
||||
print(f"[Setup] Initialized {name} with rope_theta={rope_theta}")
|
||||
else:
|
||||
new_buffer = torch.zeros(buffer.shape, dtype=buffer.dtype, device=device)
|
||||
new_buffer = torch.zeros(
|
||||
buffer.shape, dtype=buffer.dtype, device=device
|
||||
)
|
||||
|
||||
parent, attr_name = get_parent_and_name(model, name)
|
||||
parent.register_buffer(attr_name, new_buffer)
|
||||
|
|
@ -544,8 +582,12 @@ def _initialize_meta_tensors(
|
|||
|
||||
def _validate_no_meta_tensors(model: torch.nn.Module) -> None:
|
||||
"""Ensure no parameters or buffers are still on meta device."""
|
||||
final_meta_params = [name for name, p in model.named_parameters() if p.device.type == "meta"]
|
||||
final_meta_buffers = [name for name, b in model.named_buffers() if b.device.type == "meta"]
|
||||
final_meta_params = [
|
||||
name for name, p in model.named_parameters() if p.device.type == "meta"
|
||||
]
|
||||
final_meta_buffers = [
|
||||
name for name, b in model.named_buffers() if b.device.type == "meta"
|
||||
]
|
||||
|
||||
if final_meta_params or final_meta_buffers:
|
||||
error_msg = "[Setup] CRITICAL ERROR: Some tensors are still on meta device!\n"
|
||||
|
|
@ -588,7 +630,9 @@ def _create_vllm_to_hf_mapping(
|
|||
model_config = model.config
|
||||
hidden_size = getattr(model_config, "hidden_size", 4096)
|
||||
num_attention_heads = getattr(model_config, "num_attention_heads", 32)
|
||||
num_key_value_heads = getattr(model_config, "num_key_value_heads", num_attention_heads)
|
||||
num_key_value_heads = getattr(
|
||||
model_config, "num_key_value_heads", num_attention_heads
|
||||
)
|
||||
intermediate_size = getattr(model_config, "intermediate_size", hidden_size * 4)
|
||||
|
||||
# Try to get head_dim from config (some models like Qwen3 have this)
|
||||
|
|
@ -639,8 +683,10 @@ def _create_vllm_to_hf_mapping(
|
|||
up_size = intermediate_size
|
||||
|
||||
# Always print sizes for debugging weight sharing issues
|
||||
print(f"[Mapping] Model config: hidden={hidden_size}, heads={num_attention_heads}, "
|
||||
f"kv_heads={num_key_value_heads}, head_dim={head_dim}")
|
||||
print(
|
||||
f"[Mapping] Model config: hidden={hidden_size}, heads={num_attention_heads}, "
|
||||
f"kv_heads={num_key_value_heads}, head_dim={head_dim}"
|
||||
)
|
||||
print(f"[Mapping] QKV sizes from HF model: q={q_size}, k={k_size}, v={v_size}")
|
||||
print(f"[Mapping] Gate/Up sizes from HF model: gate={gate_size}, up={up_size}")
|
||||
|
||||
|
|
@ -714,7 +760,8 @@ def _create_vllm_to_hf_mapping(
|
|||
if debug:
|
||||
direct = sum(1 for v in mapping.values() if isinstance(v, str))
|
||||
fused = sum(1 for v in mapping.values() if isinstance(v, dict))
|
||||
print(f"[Mapping] Total: {len(mapping)} mapped ({direct} direct, {fused} fused)")
|
||||
print(
|
||||
f"[Mapping] Total: {len(mapping)} mapped ({direct} direct, {fused} fused)"
|
||||
)
|
||||
|
||||
return mapping
|
||||
|
||||
|
|
|
|||
|
|
@ -60,10 +60,13 @@ def wait_for_bridge_config(config_path: str, timeout: int = 60) -> bool:
|
|||
if os.path.exists(config_path):
|
||||
try:
|
||||
import json
|
||||
with open(config_path, 'r') as f:
|
||||
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
if config.get('ipc_handles') and len(config['ipc_handles']) > 0:
|
||||
print(f"[Run] ✓ Bridge config ready with {len(config['ipc_handles'])} IPC handles")
|
||||
if config.get("ipc_handles") and len(config["ipc_handles"]) > 0:
|
||||
print(
|
||||
f"[Run] ✓ Bridge config ready with {len(config['ipc_handles'])} IPC handles"
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
|
@ -79,7 +82,7 @@ def main():
|
|||
args = parser.parse_args()
|
||||
|
||||
# Create log directory
|
||||
log_dir = getattr(args, 'log_dir', './logs')
|
||||
log_dir = getattr(args, "log_dir", "./logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
# Bridge config path
|
||||
|
|
@ -91,16 +94,16 @@ def main():
|
|||
print("[Run] Removed old bridge config")
|
||||
|
||||
# === Print Configuration ===
|
||||
print("\n" + "="*60)
|
||||
print("\n" + "=" * 60)
|
||||
print("STARTING UNIFIED GRPO TRAINER (shared_vllm mode)")
|
||||
print("="*60)
|
||||
print("=" * 60)
|
||||
print(f"Model: {args.model_name}")
|
||||
print(f"vLLM port: {args.vllm_port}")
|
||||
print(f"GPU memory utilization: {args.gpu_memory_utilization}")
|
||||
print(f"Training steps: {args.training_steps}")
|
||||
print(f"Optimizer: {args.optimizer}")
|
||||
print(f"GRPO: kl_coef={args.kl_coef}, clip_eps={args.clip_eps}")
|
||||
print("="*60 + "\n")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
# Get the path to vllm_api_server.py
|
||||
script_dir = Path(__file__).parent
|
||||
|
|
@ -126,12 +129,19 @@ def main():
|
|||
|
||||
# Build vLLM command
|
||||
vllm_cmd = [
|
||||
sys.executable, "-u", str(vllm_server_script),
|
||||
"--model", args.model_name,
|
||||
"--port", str(args.vllm_port),
|
||||
"--dtype", args.dtype,
|
||||
"--gpu-memory-utilization", str(args.gpu_memory_utilization),
|
||||
"--max-model-len", str(args.max_model_len),
|
||||
sys.executable,
|
||||
"-u",
|
||||
str(vllm_server_script),
|
||||
"--model",
|
||||
args.model_name,
|
||||
"--port",
|
||||
str(args.vllm_port),
|
||||
"--dtype",
|
||||
args.dtype,
|
||||
"--gpu-memory-utilization",
|
||||
str(args.gpu_memory_utilization),
|
||||
"--max-model-len",
|
||||
str(args.max_model_len),
|
||||
"--enforce-eager", # Required for shared weights
|
||||
]
|
||||
|
||||
|
|
@ -212,6 +222,7 @@ def main():
|
|||
except Exception as e:
|
||||
print(f"\n[Run] ✗ Training failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
|
|
|||
|
|
@ -138,4 +138,3 @@ if [ -d "$LOG_DIR/checkpoints" ]; then
|
|||
}' | jq '.text[0]' | tee "${LOG_DIR}/trained_response.txt"
|
||||
fi
|
||||
fi
|
||||
|
||||
|
|
|
|||
|
|
@ -143,4 +143,3 @@ curl -s -X POST "http://localhost:${VLLM_PORT}/generate" \
|
|||
"max_tokens": 100,
|
||||
"temperature": 0.1
|
||||
}' | jq '.text[0]' | tee "${LOG_DIR}/trained_response.txt"
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,10 @@ class CPUOffloadAdamW(torch.optim.Optimizer):
|
|||
Full precision (no quantization), but states stay on CPU RAM instead of GPU.
|
||||
Trade-off: Slower (~2x) but uses ~0GB GPU memory for optimizer states.
|
||||
"""
|
||||
def __init__(self, params, lr=1e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
|
||||
|
||||
def __init__(
|
||||
self, params, lr=1e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
|
||||
):
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
|
|
@ -33,10 +36,10 @@ class CPUOffloadAdamW(torch.optim.Optimizer):
|
|||
"""Lazily initialize state on CPU."""
|
||||
state = self.state[p]
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
state["step"] = 0
|
||||
# Store on CPU in FP32
|
||||
state['exp_avg'] = torch.zeros_like(p, device='cpu', dtype=torch.float32)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p, device='cpu', dtype=torch.float32)
|
||||
state["exp_avg"] = torch.zeros_like(p, device="cpu", dtype=torch.float32)
|
||||
state["exp_avg_sq"] = torch.zeros_like(p, device="cpu", dtype=torch.float32)
|
||||
return state
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
@ -47,41 +50,41 @@ class CPUOffloadAdamW(torch.optim.Optimizer):
|
|||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
beta1, beta2 = group['betas']
|
||||
beta1, beta2 = group["betas"]
|
||||
|
||||
for p in group['params']:
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
|
||||
grad = p.grad
|
||||
state = self._init_state(p)
|
||||
|
||||
state['step'] += 1
|
||||
state["step"] += 1
|
||||
|
||||
# Move states to GPU for computation
|
||||
exp_avg = state['exp_avg'].to(p.device)
|
||||
exp_avg_sq = state['exp_avg_sq'].to(p.device)
|
||||
exp_avg = state["exp_avg"].to(p.device)
|
||||
exp_avg_sq = state["exp_avg_sq"].to(p.device)
|
||||
|
||||
# AdamW update
|
||||
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
||||
|
||||
# Bias correction
|
||||
bias_correction1 = 1 - beta1 ** state['step']
|
||||
bias_correction2 = 1 - beta2 ** state['step']
|
||||
step_size = group['lr'] / bias_correction1
|
||||
bias_correction1 = 1 - beta1 ** state["step"]
|
||||
bias_correction2 = 1 - beta2 ** state["step"]
|
||||
step_size = group["lr"] / bias_correction1
|
||||
|
||||
# Update weights
|
||||
denom = (exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)).add_(group['eps'])
|
||||
denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5)).add_(group["eps"])
|
||||
p.addcdiv_(exp_avg, denom, value=-step_size)
|
||||
|
||||
# Weight decay
|
||||
if group['weight_decay'] != 0:
|
||||
p.add_(p, alpha=-group['lr'] * group['weight_decay'])
|
||||
if group["weight_decay"] != 0:
|
||||
p.add_(p, alpha=-group["lr"] * group["weight_decay"])
|
||||
|
||||
# Move states back to CPU (non-blocking for better perf)
|
||||
state['exp_avg'].copy_(exp_avg.cpu())
|
||||
state['exp_avg_sq'].copy_(exp_avg_sq.cpu())
|
||||
state["exp_avg"].copy_(exp_avg.cpu())
|
||||
state["exp_avg_sq"].copy_(exp_avg_sq.cpu())
|
||||
|
||||
return loss
|
||||
|
||||
|
|
@ -99,6 +102,7 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
|
|||
if config.optimizer == "adamw_8bit":
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
|
||||
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.lr)
|
||||
print("[Setup] Using 8-bit AdamW (saves ~24GB optimizer memory)")
|
||||
return optimizer
|
||||
|
|
@ -108,13 +112,18 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
|
|||
|
||||
if config.optimizer == "adamw_cpu":
|
||||
optimizer = CPUOffloadAdamW(model.parameters(), lr=config.lr)
|
||||
print("[Setup] Using AdamW with CPU offload (full precision, ~0GB GPU for states)")
|
||||
print("[Setup] NOTE: ~2x slower due to CPU<->GPU transfers, but no quantization")
|
||||
print(
|
||||
"[Setup] Using AdamW with CPU offload (full precision, ~0GB GPU for states)"
|
||||
)
|
||||
print(
|
||||
"[Setup] NOTE: ~2x slower due to CPU<->GPU transfers, but no quantization"
|
||||
)
|
||||
return optimizer
|
||||
|
||||
if config.optimizer == "adafactor":
|
||||
try:
|
||||
from transformers.optimization import Adafactor
|
||||
|
||||
optimizer = Adafactor(
|
||||
model.parameters(),
|
||||
lr=config.lr,
|
||||
|
|
@ -135,7 +144,7 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
|
|||
from .checkpointing import save_checkpoint, save_lora_checkpoint # noqa: E402
|
||||
from .config import TrainingConfig # noqa: E402
|
||||
from .data import get_data # noqa: E402
|
||||
from .model import load_model_and_tokenizer, PEFT_AVAILABLE # noqa: E402
|
||||
from .model import PEFT_AVAILABLE, load_model_and_tokenizer # noqa: E402
|
||||
from .training import ( # noqa: E402
|
||||
finalize_training,
|
||||
log_metrics,
|
||||
|
|
@ -146,8 +155,8 @@ from .vllm_manager import ( # noqa: E402
|
|||
check_vllm_health,
|
||||
check_vllm_process_health,
|
||||
launch_vllm_server,
|
||||
terminate_vllm_process,
|
||||
set_vllm_process,
|
||||
terminate_vllm_process,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -171,13 +180,13 @@ def train_legacy(config: TrainingConfig):
|
|||
model, tokenizer = load_model_and_tokenizer(config)
|
||||
optimizer = create_optimizer(model, config)
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("\n" + "=" * 60)
|
||||
print("LEGACY MODE (checkpoint + vLLM restart)")
|
||||
print("="*60)
|
||||
print("=" * 60)
|
||||
print(f"Training for {config.training_steps} steps on {config.device}")
|
||||
print(f"vLLM restart interval: every {config.vllm_restart_interval} steps")
|
||||
print(f"Save path: {config.save_path}")
|
||||
print("="*60 + "\n")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
os.makedirs(config.save_path, exist_ok=True)
|
||||
|
||||
|
|
@ -206,24 +215,36 @@ def train_legacy(config: TrainingConfig):
|
|||
# Fetch data (with inference logprobs for proper GRPO)
|
||||
data_fetch_start = time.time()
|
||||
if len(batches) == 0:
|
||||
batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url,
|
||||
extract_inference_logprobs=True)
|
||||
batches, _ = get_data(
|
||||
config.batch_size,
|
||||
config.seq_len,
|
||||
config.atropos_url,
|
||||
extract_inference_logprobs=True,
|
||||
)
|
||||
batch_data = batches.pop(0)
|
||||
token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4]
|
||||
token_batches, label_batches, advantage_batches, temperature_batches = (
|
||||
batch_data[:4]
|
||||
)
|
||||
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
|
||||
data_fetch_time = time.time() - data_fetch_start
|
||||
benchmark_stats["data_fetch_times"].append(data_fetch_time)
|
||||
|
||||
# Check if we should sync (save checkpoint + restart vLLM)
|
||||
should_sync = (step + 1) % config.vllm_restart_interval == 0 or step == config.training_steps - 1
|
||||
should_sync = (
|
||||
step + 1
|
||||
) % config.vllm_restart_interval == 0 or step == config.training_steps - 1
|
||||
if should_sync:
|
||||
terminate_vllm_process()
|
||||
|
||||
# Training step (with proper GRPO using inference logprobs)
|
||||
step_start = time.time()
|
||||
metrics = run_training_step(
|
||||
model, optimizer,
|
||||
token_batches, label_batches, advantage_batches, temperature_batches,
|
||||
model,
|
||||
optimizer,
|
||||
token_batches,
|
||||
label_batches,
|
||||
advantage_batches,
|
||||
temperature_batches,
|
||||
config,
|
||||
inference_logprob_batches=inference_logprob_batches,
|
||||
)
|
||||
|
|
@ -231,15 +252,21 @@ def train_legacy(config: TrainingConfig):
|
|||
benchmark_stats["step_times"].append(step_time)
|
||||
|
||||
# GPU memory tracking
|
||||
gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
|
||||
gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
|
||||
gpu_mem_gb = (
|
||||
torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
|
||||
)
|
||||
gpu_mem_reserved_gb = (
|
||||
torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
|
||||
)
|
||||
benchmark_stats["gpu_memories"].append(gpu_mem_gb)
|
||||
|
||||
# Sync (checkpoint + restart)
|
||||
sync_time = 0
|
||||
if should_sync:
|
||||
sync_start = time.time()
|
||||
checkpoint_path = save_checkpoint(model, tokenizer, config.save_path, step + 1)
|
||||
checkpoint_path = save_checkpoint(
|
||||
model, tokenizer, config.save_path, step + 1
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
vllm_proc = launch_vllm_server(config, checkpoint_path)
|
||||
set_vllm_process(vllm_proc)
|
||||
|
|
@ -247,20 +274,31 @@ def train_legacy(config: TrainingConfig):
|
|||
benchmark_stats["sync_times"].append(sync_time)
|
||||
|
||||
# Update metrics
|
||||
metrics.update({
|
||||
"step_time": step_time,
|
||||
"sync_time": sync_time,
|
||||
"data_fetch_time": data_fetch_time,
|
||||
"gpu_memory_gb": gpu_mem_gb,
|
||||
"gpu_memory_reserved_gb": gpu_mem_reserved_gb,
|
||||
})
|
||||
metrics.update(
|
||||
{
|
||||
"step_time": step_time,
|
||||
"sync_time": sync_time,
|
||||
"data_fetch_time": data_fetch_time,
|
||||
"gpu_memory_gb": gpu_mem_gb,
|
||||
"gpu_memory_reserved_gb": gpu_mem_reserved_gb,
|
||||
}
|
||||
)
|
||||
|
||||
log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)
|
||||
check_vllm_process_health()
|
||||
|
||||
# === Cleanup ===
|
||||
save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True)
|
||||
finalize_training(use_wandb, training_start_time, "legacy", config.training_steps, benchmark_stats, config.benchmark)
|
||||
save_checkpoint(
|
||||
model, tokenizer, config.save_path, config.training_steps, is_final=True
|
||||
)
|
||||
finalize_training(
|
||||
use_wandb,
|
||||
training_start_time,
|
||||
"legacy",
|
||||
config.training_steps,
|
||||
benchmark_stats,
|
||||
config.benchmark,
|
||||
)
|
||||
|
||||
|
||||
def train_shared_vllm(config: TrainingConfig):
|
||||
|
|
@ -281,13 +319,13 @@ def train_shared_vllm(config: TrainingConfig):
|
|||
# === Setup ===
|
||||
use_wandb = setup_wandb(config)
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("\n" + "=" * 60)
|
||||
print("SINGLE-COPY MODE (CUDA IPC)")
|
||||
print(">>> Trainer uses vLLM's tensors directly!")
|
||||
print("="*60)
|
||||
print("=" * 60)
|
||||
print(f"Model: {config.model_name}")
|
||||
print(f"Save path: {config.save_path}")
|
||||
print("="*60 + "\n")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
# Attach to vLLM's shared tensors
|
||||
print("[1/2] Attaching to vLLM's shared tensors...")
|
||||
|
|
@ -331,11 +369,15 @@ def train_shared_vllm(config: TrainingConfig):
|
|||
data_fetch_start = time.time()
|
||||
if len(batches) == 0:
|
||||
batches, _ = get_data(
|
||||
config.batch_size, config.seq_len, config.atropos_url,
|
||||
config.batch_size,
|
||||
config.seq_len,
|
||||
config.atropos_url,
|
||||
extract_inference_logprobs=True, # Enable proper GRPO with reference logprobs
|
||||
)
|
||||
batch_data = batches.pop(0)
|
||||
token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4]
|
||||
token_batches, label_batches, advantage_batches, temperature_batches = (
|
||||
batch_data[:4]
|
||||
)
|
||||
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
|
||||
data_fetch_time = time.time() - data_fetch_start
|
||||
benchmark_stats["data_fetch_times"].append(data_fetch_time)
|
||||
|
|
@ -343,8 +385,12 @@ def train_shared_vllm(config: TrainingConfig):
|
|||
# Training step with proper GRPO (importance sampling + KL penalty)
|
||||
step_start = time.time()
|
||||
metrics = run_training_step(
|
||||
model, optimizer,
|
||||
token_batches, label_batches, advantage_batches, temperature_batches,
|
||||
model,
|
||||
optimizer,
|
||||
token_batches,
|
||||
label_batches,
|
||||
advantage_batches,
|
||||
temperature_batches,
|
||||
config,
|
||||
inference_logprob_batches=inference_logprob_batches, # Pass for GRPO ratio computation
|
||||
)
|
||||
|
|
@ -352,8 +398,12 @@ def train_shared_vllm(config: TrainingConfig):
|
|||
benchmark_stats["step_times"].append(step_time)
|
||||
|
||||
# GPU memory tracking
|
||||
gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
|
||||
gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
|
||||
gpu_mem_gb = (
|
||||
torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
|
||||
)
|
||||
gpu_mem_reserved_gb = (
|
||||
torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
|
||||
)
|
||||
benchmark_stats["gpu_memories"].append(gpu_mem_gb)
|
||||
|
||||
# In single-copy mode, weights are updated in-place (no sync needed!)
|
||||
|
|
@ -362,23 +412,37 @@ def train_shared_vllm(config: TrainingConfig):
|
|||
benchmark_stats["sync_times"].append(sync_time)
|
||||
|
||||
# Update metrics
|
||||
metrics.update({
|
||||
"step_time": step_time,
|
||||
"sync_time": sync_time,
|
||||
"data_fetch_time": data_fetch_time,
|
||||
"gpu_memory_gb": gpu_mem_gb,
|
||||
"gpu_memory_reserved_gb": gpu_mem_reserved_gb,
|
||||
})
|
||||
metrics.update(
|
||||
{
|
||||
"step_time": step_time,
|
||||
"sync_time": sync_time,
|
||||
"data_fetch_time": data_fetch_time,
|
||||
"gpu_memory_gb": gpu_mem_gb,
|
||||
"gpu_memory_reserved_gb": gpu_mem_reserved_gb,
|
||||
}
|
||||
)
|
||||
|
||||
log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)
|
||||
|
||||
# Periodic checkpoint (for recovery, not for vLLM sync)
|
||||
if config.checkpoint_interval > 0 and (step + 1) % config.checkpoint_interval == 0:
|
||||
if (
|
||||
config.checkpoint_interval > 0
|
||||
and (step + 1) % config.checkpoint_interval == 0
|
||||
):
|
||||
save_checkpoint(model, tokenizer, config.save_path, step + 1)
|
||||
|
||||
# === Cleanup ===
|
||||
save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True)
|
||||
finalize_training(use_wandb, training_start_time, "shared_vllm", config.training_steps, benchmark_stats, config.benchmark)
|
||||
save_checkpoint(
|
||||
model, tokenizer, config.save_path, config.training_steps, is_final=True
|
||||
)
|
||||
finalize_training(
|
||||
use_wandb,
|
||||
training_start_time,
|
||||
"shared_vllm",
|
||||
config.training_steps,
|
||||
benchmark_stats,
|
||||
config.benchmark,
|
||||
)
|
||||
|
||||
|
||||
def train_lora(config: TrainingConfig):
|
||||
|
|
@ -399,29 +463,33 @@ def train_lora(config: TrainingConfig):
|
|||
- External vLLM server running with --enable-lora
|
||||
"""
|
||||
if not PEFT_AVAILABLE:
|
||||
raise RuntimeError("PEFT library required for LoRA mode. Install with: pip install peft")
|
||||
raise RuntimeError(
|
||||
"PEFT library required for LoRA mode. Install with: pip install peft"
|
||||
)
|
||||
|
||||
training_start_time = time.time()
|
||||
|
||||
# === Setup ===
|
||||
use_wandb = setup_wandb(config)
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("\n" + "=" * 60)
|
||||
print("LORA MODE (adapter-only training)")
|
||||
print("="*60)
|
||||
print("=" * 60)
|
||||
print(f"Base model: {config.model_name}")
|
||||
print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}")
|
||||
print(f"Save path: {config.save_path}")
|
||||
print(f"vLLM port: {config.vllm_port}")
|
||||
print("="*60 + "\n")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
# Check external vLLM server
|
||||
print("[1/3] Checking external vLLM server...")
|
||||
if not check_vllm_health(config.vllm_port):
|
||||
print(f"\nERROR: vLLM server not running on port {config.vllm_port}")
|
||||
print("\nLoRA mode requires an external vLLM server. Start it first:")
|
||||
print(f" python example_trainer/vllm_api_server.py --model {config.model_name} "
|
||||
f"--port {config.vllm_port} --enable-lora --enforce-eager")
|
||||
print(
|
||||
f" python example_trainer/vllm_api_server.py --model {config.model_name} "
|
||||
f"--port {config.vllm_port} --enable-lora --enforce-eager"
|
||||
)
|
||||
raise RuntimeError(f"External vLLM server required on port {config.vllm_port}")
|
||||
print(f"vLLM server healthy on port {config.vllm_port}")
|
||||
|
||||
|
|
@ -459,10 +527,16 @@ def train_lora(config: TrainingConfig):
|
|||
# Fetch data (with inference logprobs for proper GRPO)
|
||||
data_fetch_start = time.time()
|
||||
if len(batches) == 0:
|
||||
batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url,
|
||||
extract_inference_logprobs=True)
|
||||
batches, _ = get_data(
|
||||
config.batch_size,
|
||||
config.seq_len,
|
||||
config.atropos_url,
|
||||
extract_inference_logprobs=True,
|
||||
)
|
||||
batch_data = batches.pop(0)
|
||||
token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4]
|
||||
token_batches, label_batches, advantage_batches, temperature_batches = (
|
||||
batch_data[:4]
|
||||
)
|
||||
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
|
||||
data_fetch_time = time.time() - data_fetch_start
|
||||
benchmark_stats["data_fetch_times"].append(data_fetch_time)
|
||||
|
|
@ -470,8 +544,12 @@ def train_lora(config: TrainingConfig):
|
|||
# Training step with proper GRPO
|
||||
step_start = time.time()
|
||||
metrics = run_training_step(
|
||||
model, optimizer,
|
||||
token_batches, label_batches, advantage_batches, temperature_batches,
|
||||
model,
|
||||
optimizer,
|
||||
token_batches,
|
||||
label_batches,
|
||||
advantage_batches,
|
||||
temperature_batches,
|
||||
config,
|
||||
inference_logprob_batches=inference_logprob_batches,
|
||||
)
|
||||
|
|
@ -479,8 +557,12 @@ def train_lora(config: TrainingConfig):
|
|||
benchmark_stats["step_times"].append(step_time)
|
||||
|
||||
# GPU memory tracking
|
||||
gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
|
||||
gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
|
||||
gpu_mem_gb = (
|
||||
torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
|
||||
)
|
||||
gpu_mem_reserved_gb = (
|
||||
torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
|
||||
)
|
||||
benchmark_stats["gpu_memories"].append(gpu_mem_gb)
|
||||
|
||||
# Periodic adapter save + hot-swap
|
||||
|
|
@ -494,24 +576,35 @@ def train_lora(config: TrainingConfig):
|
|||
benchmark_stats["sync_times"].append(sync_time)
|
||||
|
||||
# Update metrics
|
||||
metrics.update({
|
||||
"step_time": step_time,
|
||||
"sync_time": sync_time,
|
||||
"data_fetch_time": data_fetch_time,
|
||||
"gpu_memory_gb": gpu_mem_gb,
|
||||
"gpu_memory_reserved_gb": gpu_mem_reserved_gb,
|
||||
})
|
||||
metrics.update(
|
||||
{
|
||||
"step_time": step_time,
|
||||
"sync_time": sync_time,
|
||||
"data_fetch_time": data_fetch_time,
|
||||
"gpu_memory_gb": gpu_mem_gb,
|
||||
"gpu_memory_reserved_gb": gpu_mem_reserved_gb,
|
||||
}
|
||||
)
|
||||
|
||||
log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)
|
||||
|
||||
# === Cleanup ===
|
||||
final_sync_start = time.time()
|
||||
final_adapter_path = save_lora_checkpoint(model, config.save_path, config.training_steps, is_final=True)
|
||||
final_adapter_path = save_lora_checkpoint(
|
||||
model, config.save_path, config.training_steps, is_final=True
|
||||
)
|
||||
_hotswap_lora_adapter(config.vllm_port, final_adapter_path, "final")
|
||||
final_sync_time = time.time() - final_sync_start
|
||||
benchmark_stats["sync_times"].append(final_sync_time)
|
||||
|
||||
finalize_training(use_wandb, training_start_time, "lora_only", config.training_steps, benchmark_stats, config.benchmark)
|
||||
finalize_training(
|
||||
use_wandb,
|
||||
training_start_time,
|
||||
"lora_only",
|
||||
config.training_steps,
|
||||
benchmark_stats,
|
||||
config.benchmark,
|
||||
)
|
||||
|
||||
# Save tokenizer
|
||||
tokenizer_path = os.path.join(config.save_path, "tokenizer")
|
||||
|
|
@ -563,4 +656,3 @@ def _hotswap_lora_adapter(
|
|||
except Exception as e:
|
||||
print(f" [LORA] ✗ Hot-swap request failed: {e}")
|
||||
return False
|
||||
|
||||
|
|
|
|||
|
|
@ -18,10 +18,10 @@ import wandb
|
|||
|
||||
from .config import TrainingConfig
|
||||
|
||||
|
||||
# Global storage for logprob alignment stats
|
||||
_logprob_alignment_stats: Dict[str, float] = {}
|
||||
|
||||
|
||||
def setup_wandb(config: TrainingConfig) -> bool:
|
||||
"""
|
||||
Initialize Weights & Biases logging if enabled.
|
||||
|
|
@ -134,7 +134,9 @@ def compute_grpo_loss(
|
|||
# === GRPO/PPO Loss Computation ===
|
||||
if use_reference_logprobs and inference_logprobs is not None:
|
||||
# Move inference logprobs to correct device/dtype
|
||||
ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype)
|
||||
ref_logprobs = inference_logprobs.to(
|
||||
logp_per_token.device, logp_per_token.dtype
|
||||
)
|
||||
|
||||
# NOTE: inference_logprobs uses 1.0 for masked (prompt) positions, actual negative values for generated
|
||||
with torch.no_grad():
|
||||
|
|
@ -156,10 +158,16 @@ def compute_grpo_loss(
|
|||
# Check if ref logprobs are negative (as they should be for generated tokens)
|
||||
# If ref_at_generated is close to 1.0, that means the 1.0 placeholder is being used
|
||||
if ref_at_generated > 0.5:
|
||||
print(f" [WARNING] ref_logprobs avg {ref_at_generated:.3f} (should be negative!)")
|
||||
print(" [WARNING] This suggests inference_logprobs alignment is wrong")
|
||||
print(
|
||||
f" [WARNING] ref_logprobs avg {ref_at_generated:.3f} (should be negative!)"
|
||||
)
|
||||
print(
|
||||
" [WARNING] This suggests inference_logprobs alignment is wrong"
|
||||
)
|
||||
elif abs(ref_at_generated - train_at_generated) > 2.0:
|
||||
print(f" [DEBUG] Logprob gap: ref={ref_at_generated:.3f}, train={train_at_generated:.3f}")
|
||||
print(
|
||||
f" [DEBUG] Logprob gap: ref={ref_at_generated:.3f}, train={train_at_generated:.3f}"
|
||||
)
|
||||
|
||||
# Compute importance sampling ratio: policy(a|s) / policy_old(a|s) = exp(log policy - log policy_old)
|
||||
log_ratio = logp_per_token - ref_logprobs
|
||||
|
|
@ -192,7 +200,9 @@ def compute_grpo_loss(
|
|||
# = exp(-log_ratio) + log_ratio - 1
|
||||
kl_per_token = torch.exp(-log_ratio) + log_ratio - 1.0
|
||||
kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean()
|
||||
total_loss = (policy_loss + kl_coef * kl_penalty) / gradient_accumulation_steps
|
||||
total_loss = (
|
||||
policy_loss + kl_coef * kl_penalty
|
||||
) / gradient_accumulation_steps
|
||||
else:
|
||||
kl_penalty = torch.tensor(0.0, device=logp_per_token.device)
|
||||
total_loss = policy_loss / gradient_accumulation_steps
|
||||
|
|
@ -200,7 +210,9 @@ def compute_grpo_loss(
|
|||
# Compute metrics for logging
|
||||
with torch.no_grad():
|
||||
# Fraction of tokens where ratio was clipped
|
||||
clipped_fraction = ((ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)).float()
|
||||
clipped_fraction = (
|
||||
(ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)
|
||||
).float()
|
||||
clipped_fraction = (clipped_fraction * mask).sum() / mask.sum()
|
||||
|
||||
# Mean ratio and KL for monitoring (using Schulman's estimator)
|
||||
|
|
@ -256,7 +268,11 @@ def compute_grpo_loss(
|
|||
"kl_penalty": kl_penalty.item() if torch.is_tensor(kl_penalty) else kl_penalty,
|
||||
"mean_ratio": mean_ratio.item() if torch.is_tensor(mean_ratio) else mean_ratio,
|
||||
"mean_kl": mean_kl.item() if torch.is_tensor(mean_kl) else mean_kl,
|
||||
"clipped_fraction": clipped_fraction.item() if torch.is_tensor(clipped_fraction) else clipped_fraction,
|
||||
"clipped_fraction": (
|
||||
clipped_fraction.item()
|
||||
if torch.is_tensor(clipped_fraction)
|
||||
else clipped_fraction
|
||||
),
|
||||
# Token-level alignment metrics (key for verifying weight sharing)
|
||||
"logprob_diff_mean": logprob_diff_mean,
|
||||
"logprob_diff_abs_mean": logprob_diff_abs_mean,
|
||||
|
|
@ -315,23 +331,25 @@ def run_training_step(
|
|||
all_inference_logprobs: List[torch.Tensor] = []
|
||||
|
||||
# Get GRPO hyperparameters from config
|
||||
kl_coef = getattr(config, 'kl_coef', 0.1)
|
||||
clip_eps = getattr(config, 'clip_eps', 0.2)
|
||||
use_reference_logprobs = getattr(config, 'use_reference_logprobs', True)
|
||||
kl_coef = getattr(config, "kl_coef", 0.1)
|
||||
clip_eps = getattr(config, "clip_eps", 0.2)
|
||||
use_reference_logprobs = getattr(config, "use_reference_logprobs", True)
|
||||
|
||||
# Accumulate gradients over micro-batches
|
||||
num_batches = len(token_batches) if token_batches else 1
|
||||
|
||||
for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(zip(
|
||||
token_batches, label_batches, advantage_batches, temperature_batches
|
||||
)):
|
||||
for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(
|
||||
zip(token_batches, label_batches, advantage_batches, temperature_batches)
|
||||
):
|
||||
tokens = tokens.to(config.device)
|
||||
labels = labels.to(config.device)
|
||||
advantages = advantages.to(config.device)
|
||||
|
||||
# Get corresponding inference logprobs batch if available
|
||||
inf_logprobs = None
|
||||
if inference_logprob_batches is not None and batch_idx < len(inference_logprob_batches):
|
||||
if inference_logprob_batches is not None and batch_idx < len(
|
||||
inference_logprob_batches
|
||||
):
|
||||
inf_logprobs = inference_logprob_batches[batch_idx]
|
||||
|
||||
loss, metrics = compute_grpo_loss(
|
||||
|
|
@ -363,12 +381,17 @@ def run_training_step(
|
|||
# Accumulate token-level alignment metrics
|
||||
total_logprob_diff_mean += metrics.get("logprob_diff_mean", 0.0)
|
||||
total_logprob_diff_abs_mean += metrics.get("logprob_diff_abs_mean", 0.0)
|
||||
total_logprob_diff_max = max(total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0))
|
||||
total_logprob_diff_max = max(
|
||||
total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0)
|
||||
)
|
||||
|
||||
# Collect logprobs for alignment monitoring
|
||||
if "training_logprobs" in metrics and metrics["training_logprobs"] is not None:
|
||||
all_training_logprobs.append(metrics["training_logprobs"])
|
||||
if "inference_logprobs" in metrics and metrics["inference_logprobs"] is not None:
|
||||
if (
|
||||
"inference_logprobs" in metrics
|
||||
and metrics["inference_logprobs"] is not None
|
||||
):
|
||||
all_inference_logprobs.append(metrics["inference_logprobs"])
|
||||
|
||||
# Gradient clipping and optimizer step
|
||||
|
|
@ -387,7 +410,7 @@ def run_training_step(
|
|||
|
||||
result = {
|
||||
"loss": total_loss,
|
||||
"grad_norm": grad_norm.item() if hasattr(grad_norm, 'item') else grad_norm,
|
||||
"grad_norm": grad_norm.item() if hasattr(grad_norm, "item") else grad_norm,
|
||||
"pos_logp": total_pos_logp,
|
||||
"neg_logp": total_neg_logp,
|
||||
"pos_count": total_pos,
|
||||
|
|
@ -404,7 +427,9 @@ def run_training_step(
|
|||
if all_training_logprobs:
|
||||
train_flat = torch.cat(all_training_logprobs)
|
||||
if train_flat.numel() > 0:
|
||||
_logprob_alignment_stats["logprobs/training_mean"] = train_flat.mean().item()
|
||||
_logprob_alignment_stats["logprobs/training_mean"] = (
|
||||
train_flat.mean().item()
|
||||
)
|
||||
_logprob_alignment_stats["logprobs/training_std"] = train_flat.std().item()
|
||||
|
||||
if all_inference_logprobs:
|
||||
|
|
@ -415,8 +440,12 @@ def run_training_step(
|
|||
|
||||
# Token-level alignment metrics - THE key metric for verifying weight sharing
|
||||
# diff_abs_mean close to 0 = weights are truly shared
|
||||
_logprob_alignment_stats["alignment/diff_mean"] = total_logprob_diff_mean / num_batches
|
||||
_logprob_alignment_stats["alignment/diff_abs_mean"] = total_logprob_diff_abs_mean / num_batches
|
||||
_logprob_alignment_stats["alignment/diff_mean"] = (
|
||||
total_logprob_diff_mean / num_batches
|
||||
)
|
||||
_logprob_alignment_stats["alignment/diff_abs_mean"] = (
|
||||
total_logprob_diff_abs_mean / num_batches
|
||||
)
|
||||
_logprob_alignment_stats["alignment/diff_max"] = total_logprob_diff_max
|
||||
|
||||
return result
|
||||
|
|
@ -495,8 +524,13 @@ def log_metrics(
|
|||
"grpo/clipped_fraction": clipped_frac,
|
||||
}
|
||||
# Add timing metrics if present
|
||||
for key in ["step_time", "sync_time", "data_fetch_time",
|
||||
"gpu_memory_gb", "gpu_memory_reserved_gb"]:
|
||||
for key in [
|
||||
"step_time",
|
||||
"sync_time",
|
||||
"data_fetch_time",
|
||||
"gpu_memory_gb",
|
||||
"gpu_memory_reserved_gb",
|
||||
]:
|
||||
if key in metrics:
|
||||
log_dict[f"train/{key}"] = metrics[key]
|
||||
|
||||
|
|
@ -549,7 +583,9 @@ def finalize_training(
|
|||
total_step_time = sum(step_times)
|
||||
avg_sync_time = sum(sync_times) / len(sync_times) if sync_times else 0
|
||||
total_sync_time = sum(sync_times)
|
||||
avg_data_fetch = sum(data_fetch_times) / len(data_fetch_times) if data_fetch_times else 0
|
||||
avg_data_fetch = (
|
||||
sum(data_fetch_times) / len(data_fetch_times) if data_fetch_times else 0
|
||||
)
|
||||
total_data_fetch = sum(data_fetch_times)
|
||||
avg_gpu_mem = sum(gpu_memories) / len(gpu_memories) if gpu_memories else 0
|
||||
|
||||
|
|
@ -557,13 +593,17 @@ def finalize_training(
|
|||
print(f"\n{'='*70}")
|
||||
print(f"BENCHMARK SUMMARY ({mode})")
|
||||
print(f"{'='*70}")
|
||||
print(f" Total training time: {total_time:.2f}s ({total_time/60:.2f} min)")
|
||||
print(
|
||||
f" Total training time: {total_time:.2f}s ({total_time/60:.2f} min)"
|
||||
)
|
||||
print(f" Total steps: {total_steps}")
|
||||
print(" ")
|
||||
print(" TIMING BREAKDOWN:")
|
||||
print(f" Avg step time: {avg_step_time:.2f}s")
|
||||
print(f" Total step time: {total_step_time:.2f}s")
|
||||
print(f" Avg sync time: {avg_sync_time:.2f}s (x{len(sync_times)} syncs)")
|
||||
print(
|
||||
f" Avg sync time: {avg_sync_time:.2f}s (x{len(sync_times)} syncs)"
|
||||
)
|
||||
print(f" Total sync time: {total_sync_time:.2f}s")
|
||||
print(f" Avg data fetch time: {avg_data_fetch:.2f}s")
|
||||
print(f" Total data fetch time: {total_data_fetch:.2f}s")
|
||||
|
|
@ -584,4 +624,3 @@ def finalize_training(
|
|||
wandb.finish()
|
||||
elif use_wandb:
|
||||
wandb.finish()
|
||||
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ os.environ.setdefault("VLLM_USE_V1", "0")
|
|||
# Set spawn method for multiprocessing (required for CUDA)
|
||||
os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||
try:
|
||||
multiprocessing.set_start_method('spawn', force=True)
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
except RuntimeError:
|
||||
pass # Already set
|
||||
|
||||
|
|
@ -86,6 +86,7 @@ def _apply_patches_early() -> bool:
|
|||
try:
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path so we can import vllm_patching
|
||||
script_dir = Path(__file__).parent
|
||||
if str(script_dir) not in sys.path:
|
||||
|
|
@ -106,6 +107,7 @@ def _apply_patches_early() -> bool:
|
|||
except Exception as e:
|
||||
print(f"[vLLM Server] Error applying patches: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
|
@ -145,17 +147,20 @@ except ImportError:
|
|||
|
||||
def add_argument(self, *args, **kwargs):
|
||||
# Remove 'deprecated' kwarg if present (not supported before Python 3.13)
|
||||
kwargs.pop('deprecated', None)
|
||||
kwargs.pop("deprecated", None)
|
||||
return super().add_argument(*args, **kwargs)
|
||||
|
||||
|
||||
# set_ulimit might not exist in all vLLM versions
|
||||
try:
|
||||
from vllm.utils import set_ulimit
|
||||
except ImportError:
|
||||
|
||||
def set_ulimit() -> None:
|
||||
"""No-op fallback for set_ulimit."""
|
||||
pass
|
||||
|
||||
|
||||
from vllm.outputs import RequestOutput # noqa: F401, E402
|
||||
from vllm.version import __version__ as VLLM_VERSION # noqa: E402
|
||||
|
||||
|
|
@ -602,7 +607,9 @@ async def lora_load(request: LoraLoadRequest) -> JSONResponse:
|
|||
) # vLLM needs unique int ID
|
||||
bridge_state.lora_load_count += 1
|
||||
|
||||
logger.info(f"LoRA adapter loaded: {request.adapter_path} (id={bridge_state.active_lora_id})")
|
||||
logger.info(
|
||||
f"LoRA adapter loaded: {request.adapter_path} (id={bridge_state.active_lora_id})"
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ import requests
|
|||
|
||||
from .config import TrainingConfig
|
||||
|
||||
|
||||
# Global variable to keep track of the vLLM process
|
||||
_vllm_process: Optional[subprocess.Popen] = None
|
||||
|
||||
|
|
@ -25,7 +24,7 @@ _vllm_process: Optional[subprocess.Popen] = None
|
|||
def is_port_in_use(port: int) -> bool:
|
||||
"""Check if a port is already in use."""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
return s.connect_ex(('localhost', port)) == 0
|
||||
return s.connect_ex(("localhost", port)) == 0
|
||||
|
||||
|
||||
def kill_process_on_port(port: int, timeout: float = 5.0) -> bool:
|
||||
|
|
@ -42,13 +41,10 @@ def kill_process_on_port(port: int, timeout: float = 5.0) -> bool:
|
|||
try:
|
||||
# Try to find and kill the process using lsof (Linux/Mac)
|
||||
result = subprocess.run(
|
||||
["lsof", "-t", "-i", f":{port}"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
["lsof", "-t", "-i", f":{port}"], capture_output=True, text=True, timeout=5
|
||||
)
|
||||
if result.stdout.strip():
|
||||
pids = result.stdout.strip().split('\n')
|
||||
pids = result.stdout.strip().split("\n")
|
||||
for pid in pids:
|
||||
try:
|
||||
os.kill(int(pid), signal.SIGTERM)
|
||||
|
|
@ -135,7 +131,9 @@ def launch_vllm_server(
|
|||
if is_port_in_use(config.vllm_port):
|
||||
print(f" WARNING: Port {config.vllm_port} is already in use!")
|
||||
if not kill_process_on_port(config.vllm_port):
|
||||
print(f" ERROR: Could not free port {config.vllm_port}. Please manually kill the process.")
|
||||
print(
|
||||
f" ERROR: Could not free port {config.vllm_port}. Please manually kill the process."
|
||||
)
|
||||
print(f" Try: lsof -i :{config.vllm_port} | grep LISTEN")
|
||||
print(f" Or: pkill -f 'vllm.*{config.vllm_port}'")
|
||||
return None
|
||||
|
|
@ -209,7 +207,9 @@ def check_vllm_process_health() -> None:
|
|||
global _vllm_process
|
||||
|
||||
if _vllm_process is not None and _vllm_process.poll() is not None:
|
||||
print(f" WARNING: vLLM terminated unexpectedly (code: {_vllm_process.returncode})")
|
||||
print(
|
||||
f" WARNING: vLLM terminated unexpectedly (code: {_vllm_process.returncode})"
|
||||
)
|
||||
_vllm_process = None
|
||||
|
||||
|
||||
|
|
@ -299,7 +299,9 @@ def hotswap_lora_adapter(
|
|||
print(f" [LORA] ✓ Hot-swapped adapter: {adapter_name} ({adapter_path})")
|
||||
return True
|
||||
else:
|
||||
print(f" [LORA] ✗ Hot-swap failed: {response.status_code} - {response.text}")
|
||||
print(
|
||||
f" [LORA] ✗ Hot-swap failed: {response.status_code} - {response.text}"
|
||||
)
|
||||
return False
|
||||
|
||||
except requests.exceptions.ConnectionError:
|
||||
|
|
@ -308,4 +310,3 @@ def hotswap_lora_adapter(
|
|||
except Exception as e:
|
||||
print(f" [LORA] ✗ Error during hot-swap: {e}")
|
||||
return False
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ def _patch_lora_triton_for_blackwell() -> bool:
|
|||
"""
|
||||
try:
|
||||
import vllm
|
||||
|
||||
vllm_path = vllm.__path__[0]
|
||||
kernel_utils_path = f"{vllm_path}/lora/ops/triton_ops/kernel_utils.py"
|
||||
|
||||
|
|
@ -45,42 +46,42 @@ def _patch_lora_triton_for_blackwell() -> bool:
|
|||
print("[vLLM Patch] LoRA kernel_utils.py not found, skipping GDC patch")
|
||||
return False
|
||||
|
||||
with open(kernel_utils_path, 'r') as f:
|
||||
with open(kernel_utils_path, "r") as f:
|
||||
content = f.read()
|
||||
|
||||
# Check if already patched
|
||||
if 'PATCHED FOR B200' in content:
|
||||
if "PATCHED FOR B200" in content:
|
||||
print("[vLLM Patch] LoRA GDC already patched for B200")
|
||||
return True
|
||||
|
||||
modified = False
|
||||
|
||||
# Patch USE_GDC = True -> False
|
||||
if 'USE_GDC = True' in content:
|
||||
if "USE_GDC = True" in content:
|
||||
content = content.replace(
|
||||
'USE_GDC = True',
|
||||
'USE_GDC = False # PATCHED FOR B200 - GDC causes Triton compilation failure'
|
||||
"USE_GDC = True",
|
||||
"USE_GDC = False # PATCHED FOR B200 - GDC causes Triton compilation failure",
|
||||
)
|
||||
modified = True
|
||||
|
||||
# Patch USE_GDC: tl.constexpr = True -> False
|
||||
if 'USE_GDC: tl.constexpr = True' in content:
|
||||
if "USE_GDC: tl.constexpr = True" in content:
|
||||
content = content.replace(
|
||||
'USE_GDC: tl.constexpr = True',
|
||||
'USE_GDC: tl.constexpr = False # PATCHED FOR B200'
|
||||
"USE_GDC: tl.constexpr = True",
|
||||
"USE_GDC: tl.constexpr = False # PATCHED FOR B200",
|
||||
)
|
||||
modified = True
|
||||
|
||||
# Patch the gdc_wait call itself
|
||||
if 'tl.extra.cuda.gdc_wait()' in content:
|
||||
if "tl.extra.cuda.gdc_wait()" in content:
|
||||
content = content.replace(
|
||||
'tl.extra.cuda.gdc_wait()',
|
||||
'pass # tl.extra.cuda.gdc_wait() PATCHED FOR B200 - disabled'
|
||||
"tl.extra.cuda.gdc_wait()",
|
||||
"pass # tl.extra.cuda.gdc_wait() PATCHED FOR B200 - disabled",
|
||||
)
|
||||
modified = True
|
||||
|
||||
if modified:
|
||||
with open(kernel_utils_path, 'w') as f:
|
||||
with open(kernel_utils_path, "w") as f:
|
||||
f.write(content)
|
||||
print(f"[vLLM Patch] ✓ Patched LoRA Triton GDC in {kernel_utils_path}")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue