mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
d07ab3e3ce
commit
5cfd1929f1
19 changed files with 708 additions and 452 deletions
|
|
@ -373,7 +373,9 @@ class GSM8KEvalEnv(BaseEnv):
|
||||||
|
|
||||||
# Create evaluation tasks
|
# Create evaluation tasks
|
||||||
async def eval_task(item):
|
async def eval_task(item):
|
||||||
return await self.rollout_and_score_eval(item, self.server.servers[0].config)
|
return await self.rollout_and_score_eval(
|
||||||
|
item, self.server.servers[0].config
|
||||||
|
)
|
||||||
|
|
||||||
tasks = [eval_task(item) for item in self.eval_items]
|
tasks = [eval_task(item) for item in self.eval_items]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -299,6 +299,7 @@ class MathEnv(BaseEnv):
|
||||||
if not self.config.run_evaluation:
|
if not self.config.run_evaluation:
|
||||||
return
|
return
|
||||||
import time
|
import time
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
eval_tasks = []
|
eval_tasks = []
|
||||||
|
|
@ -320,9 +321,7 @@ class MathEnv(BaseEnv):
|
||||||
metrics[f"{subset}_accuracy"] = accuracy
|
metrics[f"{subset}_accuracy"] = accuracy
|
||||||
metrics[f"{subset}_total"] = len(scores)
|
metrics[f"{subset}_total"] = len(scores)
|
||||||
metrics[f"{subset}_correct"] = sum(scores)
|
metrics[f"{subset}_correct"] = sum(scores)
|
||||||
self.eval_metrics.append(
|
self.eval_metrics.append((f"eval/{subset}_percent_correct", accuracy))
|
||||||
(f"eval/{subset}_percent_correct", accuracy)
|
|
||||||
)
|
|
||||||
|
|
||||||
# overall score
|
# overall score
|
||||||
all_scores = []
|
all_scores = []
|
||||||
|
|
@ -332,9 +331,7 @@ class MathEnv(BaseEnv):
|
||||||
metrics["overall_accuracy"] = overall_accuracy
|
metrics["overall_accuracy"] = overall_accuracy
|
||||||
metrics["overall_total"] = len(all_scores)
|
metrics["overall_total"] = len(all_scores)
|
||||||
metrics["overall_correct"] = sum(all_scores)
|
metrics["overall_correct"] = sum(all_scores)
|
||||||
self.eval_metrics.append(
|
self.eval_metrics.append(("eval/overall_percent_correct", overall_accuracy))
|
||||||
("eval/overall_percent_correct", overall_accuracy)
|
|
||||||
)
|
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
|
|
@ -342,7 +339,9 @@ class MathEnv(BaseEnv):
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("Math Zero Evaluation Results")
|
print("Math Zero Evaluation Results")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print(f"Overall Accuracy: {overall_accuracy:.2%} ({sum(all_scores)}/{len(all_scores)})")
|
print(
|
||||||
|
f"Overall Accuracy: {overall_accuracy:.2%} ({sum(all_scores)}/{len(all_scores)})"
|
||||||
|
)
|
||||||
print("\nPer-subset breakdown:")
|
print("\nPer-subset breakdown:")
|
||||||
for subset, scores in sorted(task_lists.items()):
|
for subset, scores in sorted(task_lists.items()):
|
||||||
acc = sum(scores) / len(scores)
|
acc = sum(scores) / len(scores)
|
||||||
|
|
|
||||||
|
|
@ -542,4 +542,3 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands
|
||||||
| `vllm_api_server.py` | Streamlined vLLM server for training |
|
| `vllm_api_server.py` | Streamlined vLLM server for training |
|
||||||
| `vllm_manager.py` | vLLM process lifecycle management |
|
| `vllm_manager.py` | vLLM process lifecycle management |
|
||||||
| `checkpointing.py` | Save/load checkpoints and adapters |
|
| `checkpointing.py` | Save/load checkpoints and adapters |
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,9 +20,9 @@ Usage:
|
||||||
train_legacy(config)
|
train_legacy(config)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from .cli import config_from_args, parse_args
|
||||||
from .config import TrainingConfig
|
from .config import TrainingConfig
|
||||||
from .trainers import train_legacy, train_shared_vllm, train_lora
|
from .trainers import train_legacy, train_lora, train_shared_vllm
|
||||||
from .cli import parse_args, config_from_args
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TrainingConfig",
|
"TrainingConfig",
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,9 @@ from tenacity import retry, stop_after_attempt, wait_exponential
|
||||||
from .config import TrainingConfig
|
from .config import TrainingConfig
|
||||||
|
|
||||||
|
|
||||||
def check_atropos_api(url: str = "http://localhost:8000", timeout: float = 30.0) -> bool:
|
def check_atropos_api(
|
||||||
|
url: str = "http://localhost:8000", timeout: float = 30.0
|
||||||
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the Atropos API server is reachable.
|
Check if the Atropos API server is reachable.
|
||||||
|
|
||||||
|
|
@ -99,4 +101,3 @@ def get_batch(url: str = "http://localhost:8000"):
|
||||||
raise RuntimeError(f"Atropos API error: {data.get('message', 'Unknown error')}")
|
raise RuntimeError(f"Atropos API error: {data.get('message', 'Unknown error')}")
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -89,11 +89,14 @@ def save_checkpoint(
|
||||||
|
|
||||||
# Count how many were non-contiguous (views into fused tensors)
|
# Count how many were non-contiguous (views into fused tensors)
|
||||||
view_count = sum(
|
view_count = sum(
|
||||||
1 for name, param in model.named_parameters()
|
1
|
||||||
|
for name, param in model.named_parameters()
|
||||||
if not param.is_contiguous() or param.storage_offset() != 0
|
if not param.is_contiguous() or param.storage_offset() != 0
|
||||||
)
|
)
|
||||||
if view_count > 0:
|
if view_count > 0:
|
||||||
print(f" [Checkpoint] Unfused {view_count} view tensors (qkv/gate_up fusions)")
|
print(
|
||||||
|
f" [Checkpoint] Unfused {view_count} view tensors (qkv/gate_up fusions)"
|
||||||
|
)
|
||||||
|
|
||||||
# Save state dict manually, then save config separately
|
# Save state dict manually, then save config separately
|
||||||
torch.save(state_dict, os.path.join(checkpoint_path, "pytorch_model.bin"))
|
torch.save(state_dict, os.path.join(checkpoint_path, "pytorch_model.bin"))
|
||||||
|
|
@ -102,6 +105,7 @@ def save_checkpoint(
|
||||||
# CRITICAL: Clean up the copied state_dict to free ~8GB GPU memory!
|
# CRITICAL: Clean up the copied state_dict to free ~8GB GPU memory!
|
||||||
del state_dict
|
del state_dict
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
else:
|
else:
|
||||||
|
|
@ -151,4 +155,3 @@ def save_lora_checkpoint(
|
||||||
|
|
||||||
print(" Adapter saved.")
|
print(" Adapter saved.")
|
||||||
return adapter_path
|
return adapter_path
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,16 +11,17 @@ import torch
|
||||||
|
|
||||||
from .config import TrainingConfig
|
from .config import TrainingConfig
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Argument Group Builders (modular, reusable)
|
# Argument Group Builders (modular, reusable)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def add_model_args(parser: argparse.ArgumentParser) -> None:
|
def add_model_args(parser: argparse.ArgumentParser) -> None:
|
||||||
"""Add model-related arguments."""
|
"""Add model-related arguments."""
|
||||||
group = parser.add_argument_group("Model")
|
group = parser.add_argument_group("Model")
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--model", "--model-name",
|
"--model",
|
||||||
|
"--model-name",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
dest="model_name",
|
dest="model_name",
|
||||||
|
|
@ -67,7 +68,7 @@ def add_training_args(parser: argparse.ArgumentParser) -> None:
|
||||||
choices=["adamw", "adamw_8bit", "adamw_cpu", "adafactor"],
|
choices=["adamw", "adamw_8bit", "adamw_cpu", "adafactor"],
|
||||||
default="adamw_8bit",
|
default="adamw_8bit",
|
||||||
help="Optimizer: 'adamw' (full precision), 'adamw_8bit' (8-bit states), "
|
help="Optimizer: 'adamw' (full precision), 'adamw_8bit' (8-bit states), "
|
||||||
"'adamw_cpu' (CPU offload), 'adafactor' (no momentum)",
|
"'adamw_cpu' (CPU offload), 'adafactor' (no momentum)",
|
||||||
)
|
)
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--device",
|
"--device",
|
||||||
|
|
@ -121,7 +122,8 @@ def add_vllm_args(parser: argparse.ArgumentParser) -> None:
|
||||||
help="Port for the vLLM server",
|
help="Port for the vLLM server",
|
||||||
)
|
)
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--gpu-memory-utilization", "--vllm-gpu-memory-utilization",
|
"--gpu-memory-utilization",
|
||||||
|
"--vllm-gpu-memory-utilization",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.45,
|
default=0.45,
|
||||||
dest="gpu_memory_utilization",
|
dest="gpu_memory_utilization",
|
||||||
|
|
@ -203,7 +205,9 @@ def add_lora_args(parser: argparse.ArgumentParser) -> None:
|
||||||
"""Add LoRA-specific arguments."""
|
"""Add LoRA-specific arguments."""
|
||||||
group = parser.add_argument_group("LoRA Configuration")
|
group = parser.add_argument_group("LoRA Configuration")
|
||||||
group.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
|
group.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
|
||||||
group.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha (scaling factor)")
|
group.add_argument(
|
||||||
|
"--lora-alpha", type=int, default=32, help="LoRA alpha (scaling factor)"
|
||||||
|
)
|
||||||
group.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")
|
group.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--lora-target-modules",
|
"--lora-target-modules",
|
||||||
|
|
@ -219,8 +223,12 @@ def add_distributed_args(parser: argparse.ArgumentParser) -> None:
|
||||||
group = parser.add_argument_group("Distributed Training")
|
group = parser.add_argument_group("Distributed Training")
|
||||||
group.add_argument("--trainer-rank", type=int, default=0, help="Trainer rank")
|
group.add_argument("--trainer-rank", type=int, default=0, help="Trainer rank")
|
||||||
group.add_argument("--world-size", type=int, default=1, help="World size")
|
group.add_argument("--world-size", type=int, default=1, help="World size")
|
||||||
group.add_argument("--init-method", type=str, default="env://", help="Distributed init method")
|
group.add_argument(
|
||||||
group.add_argument("--num-inference-nodes", type=int, default=0, help="Number of inference nodes")
|
"--init-method", type=str, default="env://", help="Distributed init method"
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--num-inference-nodes", type=int, default=0, help="Number of inference nodes"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_debug_args(parser: argparse.ArgumentParser) -> None:
|
def add_debug_args(parser: argparse.ArgumentParser) -> None:
|
||||||
|
|
@ -248,6 +256,7 @@ def add_debug_args(parser: argparse.ArgumentParser) -> None:
|
||||||
# Parser Builders
|
# Parser Builders
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def create_base_parser(description: str) -> argparse.ArgumentParser:
|
def create_base_parser(description: str) -> argparse.ArgumentParser:
|
||||||
"""Create a base parser with common formatting."""
|
"""Create a base parser with common formatting."""
|
||||||
return argparse.ArgumentParser(
|
return argparse.ArgumentParser(
|
||||||
|
|
@ -299,6 +308,7 @@ def create_unified_parser() -> argparse.ArgumentParser:
|
||||||
# Legacy API (backwards compatibility)
|
# Legacy API (backwards compatibility)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
def parse_args() -> argparse.Namespace:
|
||||||
"""
|
"""
|
||||||
Parse command-line arguments for the GRPO trainer (grpo.py).
|
Parse command-line arguments for the GRPO trainer (grpo.py).
|
||||||
|
|
|
||||||
|
|
@ -35,9 +35,9 @@ class TrainingConfig(BaseModel):
|
||||||
optimizer: Literal["adamw", "adamw_8bit", "adamw_cpu", "adafactor"] = Field(
|
optimizer: Literal["adamw", "adamw_8bit", "adamw_cpu", "adafactor"] = Field(
|
||||||
"adamw_8bit",
|
"adamw_8bit",
|
||||||
description="Optimizer to use: 'adamw' (full precision, ~32GB GPU), "
|
description="Optimizer to use: 'adamw' (full precision, ~32GB GPU), "
|
||||||
"'adamw_8bit' (8-bit states, ~8GB GPU, requires bitsandbytes), "
|
"'adamw_8bit' (8-bit states, ~8GB GPU, requires bitsandbytes), "
|
||||||
"'adamw_cpu' (CPU offload, ~0GB GPU, slower), "
|
"'adamw_cpu' (CPU offload, ~0GB GPU, slower), "
|
||||||
"'adafactor' (no momentum, ~8GB GPU)"
|
"'adafactor' (no momentum, ~8GB GPU)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# === GRPO/PPO Hyperparameters ===
|
# === GRPO/PPO Hyperparameters ===
|
||||||
|
|
@ -69,12 +69,10 @@ class TrainingConfig(BaseModel):
|
||||||
|
|
||||||
# === Device & Storage ===
|
# === Device & Storage ===
|
||||||
device: str = Field(
|
device: str = Field(
|
||||||
"cuda" if torch.cuda.is_available() else "cpu",
|
"cuda" if torch.cuda.is_available() else "cpu", description="Device to train on"
|
||||||
description="Device to train on"
|
|
||||||
)
|
)
|
||||||
save_path: str = Field(
|
save_path: str = Field(
|
||||||
"trained_model_checkpoints",
|
"trained_model_checkpoints", description="Base path to save model checkpoints"
|
||||||
description="Base path to save model checkpoints"
|
|
||||||
)
|
)
|
||||||
checkpoint_interval: int = Field(
|
checkpoint_interval: int = Field(
|
||||||
3,
|
3,
|
||||||
|
|
@ -121,9 +119,7 @@ class TrainingConfig(BaseModel):
|
||||||
trainer_rank: int = Field(
|
trainer_rank: int = Field(
|
||||||
0, description="Rank of this trainer in the distributed group"
|
0, description="Rank of this trainer in the distributed group"
|
||||||
)
|
)
|
||||||
world_size: int = Field(
|
world_size: int = Field(1, description="Total processes in the distributed group")
|
||||||
1, description="Total processes in the distributed group"
|
|
||||||
)
|
|
||||||
init_method: str = Field(
|
init_method: str = Field(
|
||||||
"env://",
|
"env://",
|
||||||
description=(
|
description=(
|
||||||
|
|
@ -198,4 +194,3 @@ class TrainingConfig(BaseModel):
|
||||||
"Default is http://localhost:8000. Change for concurrent tests."
|
"Default is http://localhost:8000. Change for concurrent tests."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -92,28 +92,30 @@ def pad_data_to_good_offset(
|
||||||
# Process each sample in the item
|
# Process each sample in the item
|
||||||
for i in range(len(item["tokens"])):
|
for i in range(len(item["tokens"])):
|
||||||
seq_len = len(item["tokens"][i])
|
seq_len = len(item["tokens"][i])
|
||||||
lengths.append(
|
lengths.append(math.ceil((seq_len - 1) / good_multiple) * good_multiple)
|
||||||
math.ceil((seq_len - 1) / good_multiple) * good_multiple
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create labels with padding (-100 for masked positions)
|
# Create labels with padding (-100 for masked positions)
|
||||||
label_item = np.concatenate([
|
label_item = np.concatenate(
|
||||||
np.array(item["masks"][i]),
|
[
|
||||||
np.full(
|
np.array(item["masks"][i]),
|
||||||
max(0, token_setup_len - seq_len),
|
np.full(
|
||||||
-100,
|
max(0, token_setup_len - seq_len),
|
||||||
dtype=np.int32,
|
-100,
|
||||||
),
|
dtype=np.int32,
|
||||||
])
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
# Pad tokens
|
# Pad tokens
|
||||||
item["tokens"][i] = np.concatenate([
|
item["tokens"][i] = np.concatenate(
|
||||||
np.array(item["tokens"][i]),
|
[
|
||||||
np.zeros(
|
np.array(item["tokens"][i]),
|
||||||
max(0, token_setup_len - seq_len),
|
np.zeros(
|
||||||
dtype=np.int32,
|
max(0, token_setup_len - seq_len),
|
||||||
),
|
dtype=np.int32,
|
||||||
])
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
input_ids.append(item["tokens"][i][:-1]) # Remove last for causal
|
input_ids.append(item["tokens"][i][:-1]) # Remove last for causal
|
||||||
labels.append(label_item[1:]) # Shift by 1 for causal
|
labels.append(label_item[1:]) # Shift by 1 for causal
|
||||||
|
|
@ -126,7 +128,9 @@ def pad_data_to_good_offset(
|
||||||
# We just need to pad to match the sequence length
|
# We just need to pad to match the sequence length
|
||||||
if extract_inference_logprobs and "inference_logprobs" in item:
|
if extract_inference_logprobs and "inference_logprobs" in item:
|
||||||
if i < len(item["inference_logprobs"]):
|
if i < len(item["inference_logprobs"]):
|
||||||
raw_logprobs = np.array(item["inference_logprobs"][i], dtype=np.float32)
|
raw_logprobs = np.array(
|
||||||
|
item["inference_logprobs"][i], dtype=np.float32
|
||||||
|
)
|
||||||
has_any_logprobs = True
|
has_any_logprobs = True
|
||||||
|
|
||||||
# Create padded logprobs array matching token_setup_len
|
# Create padded logprobs array matching token_setup_len
|
||||||
|
|
@ -141,10 +145,14 @@ def pad_data_to_good_offset(
|
||||||
inference_logprobs_padded.append(padded_logprobs[1:])
|
inference_logprobs_padded.append(padded_logprobs[1:])
|
||||||
else:
|
else:
|
||||||
# No logprobs for this sample, use 1.0
|
# No logprobs for this sample, use 1.0
|
||||||
inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32))
|
inference_logprobs_padded.append(
|
||||||
|
np.full(token_setup_len - 1, 1.0, dtype=np.float32)
|
||||||
|
)
|
||||||
elif extract_inference_logprobs:
|
elif extract_inference_logprobs:
|
||||||
# No inference_logprobs in item, use 1.0
|
# No inference_logprobs in item, use 1.0
|
||||||
inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32))
|
inference_logprobs_padded.append(
|
||||||
|
np.full(token_setup_len - 1, 1.0, dtype=np.float32)
|
||||||
|
)
|
||||||
|
|
||||||
# Extract temperature (priority: override > generation_params > group_overrides > 1.0)
|
# Extract temperature (priority: override > generation_params > group_overrides > 1.0)
|
||||||
t = 1.0
|
t = 1.0
|
||||||
|
|
@ -155,9 +163,13 @@ def pad_data_to_good_offset(
|
||||||
and ("temperature" in item["overrides"][i])
|
and ("temperature" in item["overrides"][i])
|
||||||
):
|
):
|
||||||
t = float(item["overrides"][i]["temperature"])
|
t = float(item["overrides"][i]["temperature"])
|
||||||
elif item.get("generation_params") and ("temperature" in item["generation_params"]):
|
elif item.get("generation_params") and (
|
||||||
|
"temperature" in item["generation_params"]
|
||||||
|
):
|
||||||
t = float(item["generation_params"]["temperature"])
|
t = float(item["generation_params"]["temperature"])
|
||||||
elif item.get("group_overrides") and ("temperature" in item["group_overrides"]):
|
elif item.get("group_overrides") and (
|
||||||
|
"temperature" in item["group_overrides"]
|
||||||
|
):
|
||||||
t = float(item["group_overrides"]["temperature"])
|
t = float(item["group_overrides"]["temperature"])
|
||||||
temperatures.append(t)
|
temperatures.append(t)
|
||||||
|
|
||||||
|
|
@ -172,19 +184,15 @@ def pad_data_to_good_offset(
|
||||||
start = i * batch_size
|
start = i * batch_size
|
||||||
end = (i + 1) * batch_size
|
end = (i + 1) * batch_size
|
||||||
|
|
||||||
token_batches.append(
|
token_batches.append(torch.tensor(np.stack(input_ids[start:end], axis=0)))
|
||||||
torch.tensor(np.stack(input_ids[start:end], axis=0))
|
label_batches.append(torch.tensor(np.stack(labels[start:end], axis=0)))
|
||||||
)
|
|
||||||
label_batches.append(
|
|
||||||
torch.tensor(np.stack(labels[start:end], axis=0))
|
|
||||||
)
|
|
||||||
advantage_batches.append(
|
advantage_batches.append(
|
||||||
torch.tensor(np.stack(advantages[start:end], axis=0)).view(-1, 1)
|
torch.tensor(np.stack(advantages[start:end], axis=0)).view(-1, 1)
|
||||||
)
|
)
|
||||||
temperature_batches.append(
|
temperature_batches.append(
|
||||||
torch.tensor(
|
torch.tensor(np.array(temperatures[start:end], dtype=np.float32)).view(
|
||||||
np.array(temperatures[start:end], dtype=np.float32)
|
-1, 1, 1
|
||||||
).view(-1, 1, 1)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Batch inference logprobs (same shape as labels)
|
# Batch inference logprobs (same shape as labels)
|
||||||
|
|
@ -194,9 +202,19 @@ def pad_data_to_good_offset(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return inference logprob batches if we have any real logprobs
|
# Return inference logprob batches if we have any real logprobs
|
||||||
final_logprob_batches = inference_logprob_batches if (has_any_logprobs and inference_logprob_batches) else None
|
final_logprob_batches = (
|
||||||
|
inference_logprob_batches
|
||||||
|
if (has_any_logprobs and inference_logprob_batches)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
return token_batches, label_batches, advantage_batches, temperature_batches, final_logprob_batches
|
return (
|
||||||
|
token_batches,
|
||||||
|
label_batches,
|
||||||
|
advantage_batches,
|
||||||
|
temperature_batches,
|
||||||
|
final_logprob_batches,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_data(
|
def get_data(
|
||||||
|
|
@ -205,13 +223,15 @@ def get_data(
|
||||||
atropos_url: str = "http://localhost:8000",
|
atropos_url: str = "http://localhost:8000",
|
||||||
extract_inference_logprobs: bool = True,
|
extract_inference_logprobs: bool = True,
|
||||||
) -> Tuple[
|
) -> Tuple[
|
||||||
List[Tuple[
|
List[
|
||||||
List[torch.Tensor], # token_batches
|
Tuple[
|
||||||
List[torch.Tensor], # label_batches
|
List[torch.Tensor], # token_batches
|
||||||
List[torch.Tensor], # advantage_batches
|
List[torch.Tensor], # label_batches
|
||||||
List[torch.Tensor], # temperature_batches
|
List[torch.Tensor], # advantage_batches
|
||||||
Optional[List[torch.Tensor]], # inference_logprob_batches
|
List[torch.Tensor], # temperature_batches
|
||||||
]],
|
Optional[List[torch.Tensor]], # inference_logprob_batches
|
||||||
|
]
|
||||||
|
],
|
||||||
None, # Legacy return (no longer used)
|
None, # Legacy return (no longer used)
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -241,18 +261,37 @@ def get_data(
|
||||||
if data["batch"] is not None:
|
if data["batch"] is not None:
|
||||||
# DEBUG: Check if inference_logprobs exists in the data
|
# DEBUG: Check if inference_logprobs exists in the data
|
||||||
if not _logged_logprob_warning:
|
if not _logged_logprob_warning:
|
||||||
has_logprobs = any("inference_logprobs" in item for item in data["batch"])
|
has_logprobs = any(
|
||||||
|
"inference_logprobs" in item for item in data["batch"]
|
||||||
|
)
|
||||||
if has_logprobs:
|
if has_logprobs:
|
||||||
# Check if they're non-empty
|
# Check if they're non-empty
|
||||||
sample_item = next((item for item in data["batch"] if "inference_logprobs" in item), None)
|
sample_item = next(
|
||||||
|
(
|
||||||
|
item
|
||||||
|
for item in data["batch"]
|
||||||
|
if "inference_logprobs" in item
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
if sample_item and sample_item.get("inference_logprobs"):
|
if sample_item and sample_item.get("inference_logprobs"):
|
||||||
sample_lp = sample_item["inference_logprobs"][0] if sample_item["inference_logprobs"] else []
|
sample_lp = (
|
||||||
print(f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})")
|
sample_item["inference_logprobs"][0]
|
||||||
|
if sample_item["inference_logprobs"]
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print(" [Data] ⚠ inference_logprobs key exists but is empty!")
|
print(
|
||||||
|
" [Data] ⚠ inference_logprobs key exists but is empty!"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print(" [Data] ⚠ NO inference_logprobs in batch data!")
|
print(" [Data] ⚠ NO inference_logprobs in batch data!")
|
||||||
print(f" [Data] Keys in first item: {list(data['batch'][0].keys())}")
|
print(
|
||||||
|
f" [Data] Keys in first item: {list(data['batch'][0].keys())}"
|
||||||
|
)
|
||||||
_logged_logprob_warning = True
|
_logged_logprob_warning = True
|
||||||
|
|
||||||
# Save batch for debugging
|
# Save batch for debugging
|
||||||
|
|
@ -260,11 +299,24 @@ def get_data(
|
||||||
json.dump(data, f)
|
json.dump(data, f)
|
||||||
|
|
||||||
# Process and accumulate batches (now includes batched inference logprobs)
|
# Process and accumulate batches (now includes batched inference logprobs)
|
||||||
token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches = \
|
(
|
||||||
pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
|
token_batches,
|
||||||
|
label_batches,
|
||||||
|
adv_batches,
|
||||||
|
temp_batches,
|
||||||
|
inf_logprob_batches,
|
||||||
|
) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
|
||||||
|
|
||||||
# Include inference logprob batches in the tuple
|
# Include inference logprob batches in the tuple
|
||||||
batches.append((token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches))
|
batches.append(
|
||||||
|
(
|
||||||
|
token_batches,
|
||||||
|
label_batches,
|
||||||
|
adv_batches,
|
||||||
|
temp_batches,
|
||||||
|
inf_logprob_batches,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
elif len(batches) > 0:
|
elif len(batches) > 0:
|
||||||
# Return accumulated batches when no more data
|
# Return accumulated batches when no more data
|
||||||
|
|
@ -272,4 +324,3 @@ def get_data(
|
||||||
else:
|
else:
|
||||||
# Wait for data
|
# Wait for data
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,8 +19,8 @@ Usage:
|
||||||
--weight-bridge-mode lora_only --lora-r 16 --lora-alpha 32
|
--weight-bridge-mode lora_only --lora-r 16 --lora-alpha 32
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .cli import parse_args, config_from_args
|
from .cli import config_from_args, parse_args
|
||||||
from .trainers import train_legacy, train_shared_vllm, train_lora
|
from .trainers import train_legacy, train_lora, train_shared_vllm
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
@ -28,9 +28,9 @@ def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
config = config_from_args(args)
|
config = config_from_args(args)
|
||||||
|
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("GRPO TRAINER")
|
print("GRPO TRAINER")
|
||||||
print("="*60)
|
print("=" * 60)
|
||||||
print(f"Model: {config.model_name}")
|
print(f"Model: {config.model_name}")
|
||||||
print(f"Mode: {config.weight_bridge_mode}")
|
print(f"Mode: {config.weight_bridge_mode}")
|
||||||
print(f"Training steps: {config.training_steps}")
|
print(f"Training steps: {config.training_steps}")
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ from .config import TrainingConfig
|
||||||
# Import PEFT for LoRA training
|
# Import PEFT for LoRA training
|
||||||
try:
|
try:
|
||||||
from peft import LoraConfig, TaskType, get_peft_model
|
from peft import LoraConfig, TaskType, get_peft_model
|
||||||
|
|
||||||
PEFT_AVAILABLE = True
|
PEFT_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
PEFT_AVAILABLE = False
|
PEFT_AVAILABLE = False
|
||||||
|
|
@ -37,6 +38,7 @@ def _get_attention_implementation() -> str:
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import flash_attn # noqa: F401
|
import flash_attn # noqa: F401
|
||||||
|
|
||||||
return "flash_attention_2"
|
return "flash_attention_2"
|
||||||
except ImportError:
|
except ImportError:
|
||||||
return "sdpa"
|
return "sdpa"
|
||||||
|
|
@ -61,12 +63,19 @@ def _load_model_with_attention(
|
||||||
"""
|
"""
|
||||||
# Select the loader function based on mode
|
# Select the loader function based on mode
|
||||||
# from_config: creates empty shell (meta device), from_pretrained: loads weights
|
# from_config: creates empty shell (meta device), from_pretrained: loads weights
|
||||||
loader = AutoModelForCausalLM.from_config if from_config else AutoModelForCausalLM.from_pretrained
|
loader = (
|
||||||
|
AutoModelForCausalLM.from_config
|
||||||
|
if from_config
|
||||||
|
else AutoModelForCausalLM.from_pretrained
|
||||||
|
)
|
||||||
|
|
||||||
# Try attention implementations in order of preference
|
# Try attention implementations in order of preference
|
||||||
for attn_impl in ["flash_attention_2", "sdpa"]:
|
for attn_impl in ["flash_attention_2", "sdpa"]:
|
||||||
# Skip flash_attention_2 if not available
|
# Skip flash_attention_2 if not available
|
||||||
if attn_impl == "flash_attention_2" and _get_attention_implementation() != "flash_attention_2":
|
if (
|
||||||
|
attn_impl == "flash_attention_2"
|
||||||
|
and _get_attention_implementation() != "flash_attention_2"
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -86,6 +95,7 @@ def _load_model_with_attention(
|
||||||
# Should never reach here, but just in case
|
# Should never reach here, but just in case
|
||||||
raise RuntimeError("Failed to load model with any attention implementation")
|
raise RuntimeError("Failed to load model with any attention implementation")
|
||||||
|
|
||||||
|
|
||||||
def load_model_and_tokenizer(
|
def load_model_and_tokenizer(
|
||||||
config: TrainingConfig,
|
config: TrainingConfig,
|
||||||
single_copy: bool = False,
|
single_copy: bool = False,
|
||||||
|
|
@ -178,7 +188,7 @@ def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module:
|
||||||
Returns:
|
Returns:
|
||||||
PEFT model with LoRA adapters applied
|
PEFT model with LoRA adapters applied
|
||||||
"""
|
"""
|
||||||
if not PEFT_AVAILABLE: # Yeah no PEFT is needed no matter what bless huggingface
|
if not PEFT_AVAILABLE: # Yeah no PEFT is needed no matter what bless huggingface
|
||||||
raise RuntimeError("PEFT library not available. Install with: pip install peft")
|
raise RuntimeError("PEFT library not available. Install with: pip install peft")
|
||||||
|
|
||||||
print("[Setup] Loading base model for LoRA mode...")
|
print("[Setup] Loading base model for LoRA mode...")
|
||||||
|
|
@ -208,7 +218,9 @@ def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def _setup_gradient_checkpointing(model: torch.nn.Module, config: TrainingConfig) -> None:
|
def _setup_gradient_checkpointing(
|
||||||
|
model: torch.nn.Module, config: TrainingConfig
|
||||||
|
) -> None:
|
||||||
"""Configure gradient checkpointing for the model."""
|
"""Configure gradient checkpointing for the model."""
|
||||||
# Disable KV cache - incompatible with gradient checkpointing
|
# Disable KV cache - incompatible with gradient checkpointing
|
||||||
model.config.use_cache = False
|
model.config.use_cache = False
|
||||||
|
|
@ -297,7 +309,9 @@ def _attach_to_vllm_shared_tensors(
|
||||||
ipc_handles, vllm_to_hf_mapping, config
|
ipc_handles, vllm_to_hf_mapping, config
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"[Setup] Attached {attached_count} tensors ({fused_count} from fused layers)")
|
print(
|
||||||
|
f"[Setup] Attached {attached_count} tensors ({fused_count} from fused layers)"
|
||||||
|
)
|
||||||
|
|
||||||
if attached_count == 0:
|
if attached_count == 0:
|
||||||
print("[Setup] Could not attach any tensors, falling back to regular loading")
|
print("[Setup] Could not attach any tensors, falling back to regular loading")
|
||||||
|
|
@ -322,6 +336,7 @@ def _attach_to_vllm_shared_tensors(
|
||||||
|
|
||||||
def _deserialize_ipc_handles(handles_raw: dict) -> dict:
|
def _deserialize_ipc_handles(handles_raw: dict) -> dict:
|
||||||
"""Deserialize base64-encoded bytes in IPC handles."""
|
"""Deserialize base64-encoded bytes in IPC handles."""
|
||||||
|
|
||||||
def deserialize(handles):
|
def deserialize(handles):
|
||||||
result = {}
|
result = {}
|
||||||
for k, v in handles.items():
|
for k, v in handles.items():
|
||||||
|
|
@ -333,6 +348,7 @@ def _deserialize_ipc_handles(handles_raw: dict) -> dict:
|
||||||
else:
|
else:
|
||||||
result[k] = v
|
result[k] = v
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return deserialize(handles_raw)
|
return deserialize(handles_raw)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -387,8 +403,14 @@ def _reconstruct_shared_tensors(
|
||||||
event_sync_required = ipc_info["event_sync_required"]
|
event_sync_required = ipc_info["event_sync_required"]
|
||||||
|
|
||||||
share_tuple = (
|
share_tuple = (
|
||||||
device_index, ipc_handle, storage_size, storage_offset_orig,
|
device_index,
|
||||||
ref_counter_handle, ref_counter_offset, event_handle, event_sync_required,
|
ipc_handle,
|
||||||
|
storage_size,
|
||||||
|
storage_offset_orig,
|
||||||
|
ref_counter_handle,
|
||||||
|
ref_counter_offset,
|
||||||
|
event_handle,
|
||||||
|
event_sync_required,
|
||||||
)
|
)
|
||||||
|
|
||||||
storage = torch.UntypedStorage._new_shared_cuda(*share_tuple)
|
storage = torch.UntypedStorage._new_shared_cuda(*share_tuple)
|
||||||
|
|
@ -424,7 +446,9 @@ def _reconstruct_shared_tensors(
|
||||||
if slice_dim == 0:
|
if slice_dim == 0:
|
||||||
tensor = full_tensor[slice_start:slice_end]
|
tensor = full_tensor[slice_start:slice_end]
|
||||||
else:
|
else:
|
||||||
tensor = full_tensor.narrow(slice_dim, slice_start, slice_end - slice_start)
|
tensor = full_tensor.narrow(
|
||||||
|
slice_dim, slice_start, slice_end - slice_start
|
||||||
|
)
|
||||||
|
|
||||||
tensor.requires_grad_(True)
|
tensor.requires_grad_(True)
|
||||||
hf_state_dict[hf_name] = tensor
|
hf_state_dict[hf_name] = tensor
|
||||||
|
|
@ -459,12 +483,16 @@ def _validate_mapping_coverage(
|
||||||
|
|
||||||
# Note: attached_count may be > param_count because state_dict includes buffers
|
# Note: attached_count may be > param_count because state_dict includes buffers
|
||||||
# while named_parameters only counts trainable params
|
# while named_parameters only counts trainable params
|
||||||
print(f"[Setup] Mapping coverage: {attached_count} tensors for {hf_param_count} parameters "
|
print(
|
||||||
f"(>100% is OK - includes buffers)")
|
f"[Setup] Mapping coverage: {attached_count} tensors for {hf_param_count} parameters "
|
||||||
|
f"(>100% is OK - includes buffers)"
|
||||||
|
)
|
||||||
|
|
||||||
if mapping_coverage < 0.90:
|
if mapping_coverage < 0.90:
|
||||||
unmapped_params = set(model.state_dict().keys()) - set(hf_state_dict.keys())
|
unmapped_params = set(model.state_dict().keys()) - set(hf_state_dict.keys())
|
||||||
warning_msg = f"[Setup] WARNING: Low mapping coverage ({mapping_coverage:.1%})\n"
|
warning_msg = (
|
||||||
|
f"[Setup] WARNING: Low mapping coverage ({mapping_coverage:.1%})\n"
|
||||||
|
)
|
||||||
warning_msg += f"Unmapped parameters ({len(unmapped_params)}):\n"
|
warning_msg += f"Unmapped parameters ({len(unmapped_params)}):\n"
|
||||||
for name in list(unmapped_params)[:20]:
|
for name in list(unmapped_params)[:20]:
|
||||||
warning_msg += f" - {name}\n"
|
warning_msg += f" - {name}\n"
|
||||||
|
|
@ -484,11 +512,17 @@ def _initialize_meta_tensors(
|
||||||
config: TrainingConfig,
|
config: TrainingConfig,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize any remaining meta tensors after loading."""
|
"""Initialize any remaining meta tensors after loading."""
|
||||||
meta_params = [name for name, p in model.named_parameters() if p.device.type == "meta"]
|
meta_params = [
|
||||||
meta_buffers = [name for name, b in model.named_buffers() if b.device.type == "meta"]
|
name for name, p in model.named_parameters() if p.device.type == "meta"
|
||||||
|
]
|
||||||
|
meta_buffers = [
|
||||||
|
name for name, b in model.named_buffers() if b.device.type == "meta"
|
||||||
|
]
|
||||||
|
|
||||||
if config.debug_loading:
|
if config.debug_loading:
|
||||||
print(f"\n[DIAGNOSTIC] Meta params: {len(meta_params)}, Meta buffers: {len(meta_buffers)}")
|
print(
|
||||||
|
f"\n[DIAGNOSTIC] Meta params: {len(meta_params)}, Meta buffers: {len(meta_buffers)}"
|
||||||
|
)
|
||||||
|
|
||||||
def get_parent_and_name(model, full_name):
|
def get_parent_and_name(model, full_name):
|
||||||
parts = full_name.split(".")
|
parts = full_name.split(".")
|
||||||
|
|
@ -526,11 +560,15 @@ def _initialize_meta_tensors(
|
||||||
dim = buffer.shape[0] * 2
|
dim = buffer.shape[0] * 2
|
||||||
# Get rope_theta from model config (default 10000.0 for LLaMA, but Qwen3 uses 5000000!)
|
# Get rope_theta from model config (default 10000.0 for LLaMA, but Qwen3 uses 5000000!)
|
||||||
rope_theta = getattr(model.config, "rope_theta", 10000.0)
|
rope_theta = getattr(model.config, "rope_theta", 10000.0)
|
||||||
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
|
inv_freq = 1.0 / (
|
||||||
|
rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
|
||||||
|
)
|
||||||
new_buffer = inv_freq.to(dtype=buffer.dtype, device=device)
|
new_buffer = inv_freq.to(dtype=buffer.dtype, device=device)
|
||||||
print(f"[Setup] Initialized {name} with rope_theta={rope_theta}")
|
print(f"[Setup] Initialized {name} with rope_theta={rope_theta}")
|
||||||
else:
|
else:
|
||||||
new_buffer = torch.zeros(buffer.shape, dtype=buffer.dtype, device=device)
|
new_buffer = torch.zeros(
|
||||||
|
buffer.shape, dtype=buffer.dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
parent, attr_name = get_parent_and_name(model, name)
|
parent, attr_name = get_parent_and_name(model, name)
|
||||||
parent.register_buffer(attr_name, new_buffer)
|
parent.register_buffer(attr_name, new_buffer)
|
||||||
|
|
@ -544,8 +582,12 @@ def _initialize_meta_tensors(
|
||||||
|
|
||||||
def _validate_no_meta_tensors(model: torch.nn.Module) -> None:
|
def _validate_no_meta_tensors(model: torch.nn.Module) -> None:
|
||||||
"""Ensure no parameters or buffers are still on meta device."""
|
"""Ensure no parameters or buffers are still on meta device."""
|
||||||
final_meta_params = [name for name, p in model.named_parameters() if p.device.type == "meta"]
|
final_meta_params = [
|
||||||
final_meta_buffers = [name for name, b in model.named_buffers() if b.device.type == "meta"]
|
name for name, p in model.named_parameters() if p.device.type == "meta"
|
||||||
|
]
|
||||||
|
final_meta_buffers = [
|
||||||
|
name for name, b in model.named_buffers() if b.device.type == "meta"
|
||||||
|
]
|
||||||
|
|
||||||
if final_meta_params or final_meta_buffers:
|
if final_meta_params or final_meta_buffers:
|
||||||
error_msg = "[Setup] CRITICAL ERROR: Some tensors are still on meta device!\n"
|
error_msg = "[Setup] CRITICAL ERROR: Some tensors are still on meta device!\n"
|
||||||
|
|
@ -588,7 +630,9 @@ def _create_vllm_to_hf_mapping(
|
||||||
model_config = model.config
|
model_config = model.config
|
||||||
hidden_size = getattr(model_config, "hidden_size", 4096)
|
hidden_size = getattr(model_config, "hidden_size", 4096)
|
||||||
num_attention_heads = getattr(model_config, "num_attention_heads", 32)
|
num_attention_heads = getattr(model_config, "num_attention_heads", 32)
|
||||||
num_key_value_heads = getattr(model_config, "num_key_value_heads", num_attention_heads)
|
num_key_value_heads = getattr(
|
||||||
|
model_config, "num_key_value_heads", num_attention_heads
|
||||||
|
)
|
||||||
intermediate_size = getattr(model_config, "intermediate_size", hidden_size * 4)
|
intermediate_size = getattr(model_config, "intermediate_size", hidden_size * 4)
|
||||||
|
|
||||||
# Try to get head_dim from config (some models like Qwen3 have this)
|
# Try to get head_dim from config (some models like Qwen3 have this)
|
||||||
|
|
@ -639,8 +683,10 @@ def _create_vllm_to_hf_mapping(
|
||||||
up_size = intermediate_size
|
up_size = intermediate_size
|
||||||
|
|
||||||
# Always print sizes for debugging weight sharing issues
|
# Always print sizes for debugging weight sharing issues
|
||||||
print(f"[Mapping] Model config: hidden={hidden_size}, heads={num_attention_heads}, "
|
print(
|
||||||
f"kv_heads={num_key_value_heads}, head_dim={head_dim}")
|
f"[Mapping] Model config: hidden={hidden_size}, heads={num_attention_heads}, "
|
||||||
|
f"kv_heads={num_key_value_heads}, head_dim={head_dim}"
|
||||||
|
)
|
||||||
print(f"[Mapping] QKV sizes from HF model: q={q_size}, k={k_size}, v={v_size}")
|
print(f"[Mapping] QKV sizes from HF model: q={q_size}, k={k_size}, v={v_size}")
|
||||||
print(f"[Mapping] Gate/Up sizes from HF model: gate={gate_size}, up={up_size}")
|
print(f"[Mapping] Gate/Up sizes from HF model: gate={gate_size}, up={up_size}")
|
||||||
|
|
||||||
|
|
@ -714,7 +760,8 @@ def _create_vllm_to_hf_mapping(
|
||||||
if debug:
|
if debug:
|
||||||
direct = sum(1 for v in mapping.values() if isinstance(v, str))
|
direct = sum(1 for v in mapping.values() if isinstance(v, str))
|
||||||
fused = sum(1 for v in mapping.values() if isinstance(v, dict))
|
fused = sum(1 for v in mapping.values() if isinstance(v, dict))
|
||||||
print(f"[Mapping] Total: {len(mapping)} mapped ({direct} direct, {fused} fused)")
|
print(
|
||||||
|
f"[Mapping] Total: {len(mapping)} mapped ({direct} direct, {fused} fused)"
|
||||||
|
)
|
||||||
|
|
||||||
return mapping
|
return mapping
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -60,10 +60,13 @@ def wait_for_bridge_config(config_path: str, timeout: int = 60) -> bool:
|
||||||
if os.path.exists(config_path):
|
if os.path.exists(config_path):
|
||||||
try:
|
try:
|
||||||
import json
|
import json
|
||||||
with open(config_path, 'r') as f:
|
|
||||||
|
with open(config_path, "r") as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
if config.get('ipc_handles') and len(config['ipc_handles']) > 0:
|
if config.get("ipc_handles") and len(config["ipc_handles"]) > 0:
|
||||||
print(f"[Run] ✓ Bridge config ready with {len(config['ipc_handles'])} IPC handles")
|
print(
|
||||||
|
f"[Run] ✓ Bridge config ready with {len(config['ipc_handles'])} IPC handles"
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
@ -79,7 +82,7 @@ def main():
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Create log directory
|
# Create log directory
|
||||||
log_dir = getattr(args, 'log_dir', './logs')
|
log_dir = getattr(args, "log_dir", "./logs")
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
|
||||||
# Bridge config path
|
# Bridge config path
|
||||||
|
|
@ -91,16 +94,16 @@ def main():
|
||||||
print("[Run] Removed old bridge config")
|
print("[Run] Removed old bridge config")
|
||||||
|
|
||||||
# === Print Configuration ===
|
# === Print Configuration ===
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("STARTING UNIFIED GRPO TRAINER (shared_vllm mode)")
|
print("STARTING UNIFIED GRPO TRAINER (shared_vllm mode)")
|
||||||
print("="*60)
|
print("=" * 60)
|
||||||
print(f"Model: {args.model_name}")
|
print(f"Model: {args.model_name}")
|
||||||
print(f"vLLM port: {args.vllm_port}")
|
print(f"vLLM port: {args.vllm_port}")
|
||||||
print(f"GPU memory utilization: {args.gpu_memory_utilization}")
|
print(f"GPU memory utilization: {args.gpu_memory_utilization}")
|
||||||
print(f"Training steps: {args.training_steps}")
|
print(f"Training steps: {args.training_steps}")
|
||||||
print(f"Optimizer: {args.optimizer}")
|
print(f"Optimizer: {args.optimizer}")
|
||||||
print(f"GRPO: kl_coef={args.kl_coef}, clip_eps={args.clip_eps}")
|
print(f"GRPO: kl_coef={args.kl_coef}, clip_eps={args.clip_eps}")
|
||||||
print("="*60 + "\n")
|
print("=" * 60 + "\n")
|
||||||
|
|
||||||
# Get the path to vllm_api_server.py
|
# Get the path to vllm_api_server.py
|
||||||
script_dir = Path(__file__).parent
|
script_dir = Path(__file__).parent
|
||||||
|
|
@ -126,12 +129,19 @@ def main():
|
||||||
|
|
||||||
# Build vLLM command
|
# Build vLLM command
|
||||||
vllm_cmd = [
|
vllm_cmd = [
|
||||||
sys.executable, "-u", str(vllm_server_script),
|
sys.executable,
|
||||||
"--model", args.model_name,
|
"-u",
|
||||||
"--port", str(args.vllm_port),
|
str(vllm_server_script),
|
||||||
"--dtype", args.dtype,
|
"--model",
|
||||||
"--gpu-memory-utilization", str(args.gpu_memory_utilization),
|
args.model_name,
|
||||||
"--max-model-len", str(args.max_model_len),
|
"--port",
|
||||||
|
str(args.vllm_port),
|
||||||
|
"--dtype",
|
||||||
|
args.dtype,
|
||||||
|
"--gpu-memory-utilization",
|
||||||
|
str(args.gpu_memory_utilization),
|
||||||
|
"--max-model-len",
|
||||||
|
str(args.max_model_len),
|
||||||
"--enforce-eager", # Required for shared weights
|
"--enforce-eager", # Required for shared weights
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -212,6 +222,7 @@ def main():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n[Run] ✗ Training failed: {e}")
|
print(f"\n[Run] ✗ Training failed: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -138,4 +138,3 @@ if [ -d "$LOG_DIR/checkpoints" ]; then
|
||||||
}' | jq '.text[0]' | tee "${LOG_DIR}/trained_response.txt"
|
}' | jq '.text[0]' | tee "${LOG_DIR}/trained_response.txt"
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -143,4 +143,3 @@ curl -s -X POST "http://localhost:${VLLM_PORT}/generate" \
|
||||||
"max_tokens": 100,
|
"max_tokens": 100,
|
||||||
"temperature": 0.1
|
"temperature": 0.1
|
||||||
}' | jq '.text[0]' | tee "${LOG_DIR}/trained_response.txt"
|
}' | jq '.text[0]' | tee "${LOG_DIR}/trained_response.txt"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,10 @@ class CPUOffloadAdamW(torch.optim.Optimizer):
|
||||||
Full precision (no quantization), but states stay on CPU RAM instead of GPU.
|
Full precision (no quantization), but states stay on CPU RAM instead of GPU.
|
||||||
Trade-off: Slower (~2x) but uses ~0GB GPU memory for optimizer states.
|
Trade-off: Slower (~2x) but uses ~0GB GPU memory for optimizer states.
|
||||||
"""
|
"""
|
||||||
def __init__(self, params, lr=1e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
|
|
||||||
|
def __init__(
|
||||||
|
self, params, lr=1e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
|
||||||
|
):
|
||||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
||||||
super().__init__(params, defaults)
|
super().__init__(params, defaults)
|
||||||
|
|
||||||
|
|
@ -33,10 +36,10 @@ class CPUOffloadAdamW(torch.optim.Optimizer):
|
||||||
"""Lazily initialize state on CPU."""
|
"""Lazily initialize state on CPU."""
|
||||||
state = self.state[p]
|
state = self.state[p]
|
||||||
if len(state) == 0:
|
if len(state) == 0:
|
||||||
state['step'] = 0
|
state["step"] = 0
|
||||||
# Store on CPU in FP32
|
# Store on CPU in FP32
|
||||||
state['exp_avg'] = torch.zeros_like(p, device='cpu', dtype=torch.float32)
|
state["exp_avg"] = torch.zeros_like(p, device="cpu", dtype=torch.float32)
|
||||||
state['exp_avg_sq'] = torch.zeros_like(p, device='cpu', dtype=torch.float32)
|
state["exp_avg_sq"] = torch.zeros_like(p, device="cpu", dtype=torch.float32)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|
@ -47,41 +50,41 @@ class CPUOffloadAdamW(torch.optim.Optimizer):
|
||||||
loss = closure()
|
loss = closure()
|
||||||
|
|
||||||
for group in self.param_groups:
|
for group in self.param_groups:
|
||||||
beta1, beta2 = group['betas']
|
beta1, beta2 = group["betas"]
|
||||||
|
|
||||||
for p in group['params']:
|
for p in group["params"]:
|
||||||
if p.grad is None:
|
if p.grad is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
grad = p.grad
|
grad = p.grad
|
||||||
state = self._init_state(p)
|
state = self._init_state(p)
|
||||||
|
|
||||||
state['step'] += 1
|
state["step"] += 1
|
||||||
|
|
||||||
# Move states to GPU for computation
|
# Move states to GPU for computation
|
||||||
exp_avg = state['exp_avg'].to(p.device)
|
exp_avg = state["exp_avg"].to(p.device)
|
||||||
exp_avg_sq = state['exp_avg_sq'].to(p.device)
|
exp_avg_sq = state["exp_avg_sq"].to(p.device)
|
||||||
|
|
||||||
# AdamW update
|
# AdamW update
|
||||||
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
||||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
||||||
|
|
||||||
# Bias correction
|
# Bias correction
|
||||||
bias_correction1 = 1 - beta1 ** state['step']
|
bias_correction1 = 1 - beta1 ** state["step"]
|
||||||
bias_correction2 = 1 - beta2 ** state['step']
|
bias_correction2 = 1 - beta2 ** state["step"]
|
||||||
step_size = group['lr'] / bias_correction1
|
step_size = group["lr"] / bias_correction1
|
||||||
|
|
||||||
# Update weights
|
# Update weights
|
||||||
denom = (exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)).add_(group['eps'])
|
denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5)).add_(group["eps"])
|
||||||
p.addcdiv_(exp_avg, denom, value=-step_size)
|
p.addcdiv_(exp_avg, denom, value=-step_size)
|
||||||
|
|
||||||
# Weight decay
|
# Weight decay
|
||||||
if group['weight_decay'] != 0:
|
if group["weight_decay"] != 0:
|
||||||
p.add_(p, alpha=-group['lr'] * group['weight_decay'])
|
p.add_(p, alpha=-group["lr"] * group["weight_decay"])
|
||||||
|
|
||||||
# Move states back to CPU (non-blocking for better perf)
|
# Move states back to CPU (non-blocking for better perf)
|
||||||
state['exp_avg'].copy_(exp_avg.cpu())
|
state["exp_avg"].copy_(exp_avg.cpu())
|
||||||
state['exp_avg_sq'].copy_(exp_avg_sq.cpu())
|
state["exp_avg_sq"].copy_(exp_avg_sq.cpu())
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
@ -99,6 +102,7 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
|
||||||
if config.optimizer == "adamw_8bit":
|
if config.optimizer == "adamw_8bit":
|
||||||
try:
|
try:
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
|
|
||||||
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.lr)
|
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.lr)
|
||||||
print("[Setup] Using 8-bit AdamW (saves ~24GB optimizer memory)")
|
print("[Setup] Using 8-bit AdamW (saves ~24GB optimizer memory)")
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
@ -108,13 +112,18 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
|
||||||
|
|
||||||
if config.optimizer == "adamw_cpu":
|
if config.optimizer == "adamw_cpu":
|
||||||
optimizer = CPUOffloadAdamW(model.parameters(), lr=config.lr)
|
optimizer = CPUOffloadAdamW(model.parameters(), lr=config.lr)
|
||||||
print("[Setup] Using AdamW with CPU offload (full precision, ~0GB GPU for states)")
|
print(
|
||||||
print("[Setup] NOTE: ~2x slower due to CPU<->GPU transfers, but no quantization")
|
"[Setup] Using AdamW with CPU offload (full precision, ~0GB GPU for states)"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"[Setup] NOTE: ~2x slower due to CPU<->GPU transfers, but no quantization"
|
||||||
|
)
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
if config.optimizer == "adafactor":
|
if config.optimizer == "adafactor":
|
||||||
try:
|
try:
|
||||||
from transformers.optimization import Adafactor
|
from transformers.optimization import Adafactor
|
||||||
|
|
||||||
optimizer = Adafactor(
|
optimizer = Adafactor(
|
||||||
model.parameters(),
|
model.parameters(),
|
||||||
lr=config.lr,
|
lr=config.lr,
|
||||||
|
|
@ -135,7 +144,7 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
|
||||||
from .checkpointing import save_checkpoint, save_lora_checkpoint # noqa: E402
|
from .checkpointing import save_checkpoint, save_lora_checkpoint # noqa: E402
|
||||||
from .config import TrainingConfig # noqa: E402
|
from .config import TrainingConfig # noqa: E402
|
||||||
from .data import get_data # noqa: E402
|
from .data import get_data # noqa: E402
|
||||||
from .model import load_model_and_tokenizer, PEFT_AVAILABLE # noqa: E402
|
from .model import PEFT_AVAILABLE, load_model_and_tokenizer # noqa: E402
|
||||||
from .training import ( # noqa: E402
|
from .training import ( # noqa: E402
|
||||||
finalize_training,
|
finalize_training,
|
||||||
log_metrics,
|
log_metrics,
|
||||||
|
|
@ -146,8 +155,8 @@ from .vllm_manager import ( # noqa: E402
|
||||||
check_vllm_health,
|
check_vllm_health,
|
||||||
check_vllm_process_health,
|
check_vllm_process_health,
|
||||||
launch_vllm_server,
|
launch_vllm_server,
|
||||||
terminate_vllm_process,
|
|
||||||
set_vllm_process,
|
set_vllm_process,
|
||||||
|
terminate_vllm_process,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -171,13 +180,13 @@ def train_legacy(config: TrainingConfig):
|
||||||
model, tokenizer = load_model_and_tokenizer(config)
|
model, tokenizer = load_model_and_tokenizer(config)
|
||||||
optimizer = create_optimizer(model, config)
|
optimizer = create_optimizer(model, config)
|
||||||
|
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("LEGACY MODE (checkpoint + vLLM restart)")
|
print("LEGACY MODE (checkpoint + vLLM restart)")
|
||||||
print("="*60)
|
print("=" * 60)
|
||||||
print(f"Training for {config.training_steps} steps on {config.device}")
|
print(f"Training for {config.training_steps} steps on {config.device}")
|
||||||
print(f"vLLM restart interval: every {config.vllm_restart_interval} steps")
|
print(f"vLLM restart interval: every {config.vllm_restart_interval} steps")
|
||||||
print(f"Save path: {config.save_path}")
|
print(f"Save path: {config.save_path}")
|
||||||
print("="*60 + "\n")
|
print("=" * 60 + "\n")
|
||||||
|
|
||||||
os.makedirs(config.save_path, exist_ok=True)
|
os.makedirs(config.save_path, exist_ok=True)
|
||||||
|
|
||||||
|
|
@ -206,24 +215,36 @@ def train_legacy(config: TrainingConfig):
|
||||||
# Fetch data (with inference logprobs for proper GRPO)
|
# Fetch data (with inference logprobs for proper GRPO)
|
||||||
data_fetch_start = time.time()
|
data_fetch_start = time.time()
|
||||||
if len(batches) == 0:
|
if len(batches) == 0:
|
||||||
batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url,
|
batches, _ = get_data(
|
||||||
extract_inference_logprobs=True)
|
config.batch_size,
|
||||||
|
config.seq_len,
|
||||||
|
config.atropos_url,
|
||||||
|
extract_inference_logprobs=True,
|
||||||
|
)
|
||||||
batch_data = batches.pop(0)
|
batch_data = batches.pop(0)
|
||||||
token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4]
|
token_batches, label_batches, advantage_batches, temperature_batches = (
|
||||||
|
batch_data[:4]
|
||||||
|
)
|
||||||
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
|
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
|
||||||
data_fetch_time = time.time() - data_fetch_start
|
data_fetch_time = time.time() - data_fetch_start
|
||||||
benchmark_stats["data_fetch_times"].append(data_fetch_time)
|
benchmark_stats["data_fetch_times"].append(data_fetch_time)
|
||||||
|
|
||||||
# Check if we should sync (save checkpoint + restart vLLM)
|
# Check if we should sync (save checkpoint + restart vLLM)
|
||||||
should_sync = (step + 1) % config.vllm_restart_interval == 0 or step == config.training_steps - 1
|
should_sync = (
|
||||||
|
step + 1
|
||||||
|
) % config.vllm_restart_interval == 0 or step == config.training_steps - 1
|
||||||
if should_sync:
|
if should_sync:
|
||||||
terminate_vllm_process()
|
terminate_vllm_process()
|
||||||
|
|
||||||
# Training step (with proper GRPO using inference logprobs)
|
# Training step (with proper GRPO using inference logprobs)
|
||||||
step_start = time.time()
|
step_start = time.time()
|
||||||
metrics = run_training_step(
|
metrics = run_training_step(
|
||||||
model, optimizer,
|
model,
|
||||||
token_batches, label_batches, advantage_batches, temperature_batches,
|
optimizer,
|
||||||
|
token_batches,
|
||||||
|
label_batches,
|
||||||
|
advantage_batches,
|
||||||
|
temperature_batches,
|
||||||
config,
|
config,
|
||||||
inference_logprob_batches=inference_logprob_batches,
|
inference_logprob_batches=inference_logprob_batches,
|
||||||
)
|
)
|
||||||
|
|
@ -231,15 +252,21 @@ def train_legacy(config: TrainingConfig):
|
||||||
benchmark_stats["step_times"].append(step_time)
|
benchmark_stats["step_times"].append(step_time)
|
||||||
|
|
||||||
# GPU memory tracking
|
# GPU memory tracking
|
||||||
gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
|
gpu_mem_gb = (
|
||||||
gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
|
torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
|
||||||
|
)
|
||||||
|
gpu_mem_reserved_gb = (
|
||||||
|
torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
|
||||||
|
)
|
||||||
benchmark_stats["gpu_memories"].append(gpu_mem_gb)
|
benchmark_stats["gpu_memories"].append(gpu_mem_gb)
|
||||||
|
|
||||||
# Sync (checkpoint + restart)
|
# Sync (checkpoint + restart)
|
||||||
sync_time = 0
|
sync_time = 0
|
||||||
if should_sync:
|
if should_sync:
|
||||||
sync_start = time.time()
|
sync_start = time.time()
|
||||||
checkpoint_path = save_checkpoint(model, tokenizer, config.save_path, step + 1)
|
checkpoint_path = save_checkpoint(
|
||||||
|
model, tokenizer, config.save_path, step + 1
|
||||||
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
vllm_proc = launch_vllm_server(config, checkpoint_path)
|
vllm_proc = launch_vllm_server(config, checkpoint_path)
|
||||||
set_vllm_process(vllm_proc)
|
set_vllm_process(vllm_proc)
|
||||||
|
|
@ -247,20 +274,31 @@ def train_legacy(config: TrainingConfig):
|
||||||
benchmark_stats["sync_times"].append(sync_time)
|
benchmark_stats["sync_times"].append(sync_time)
|
||||||
|
|
||||||
# Update metrics
|
# Update metrics
|
||||||
metrics.update({
|
metrics.update(
|
||||||
"step_time": step_time,
|
{
|
||||||
"sync_time": sync_time,
|
"step_time": step_time,
|
||||||
"data_fetch_time": data_fetch_time,
|
"sync_time": sync_time,
|
||||||
"gpu_memory_gb": gpu_mem_gb,
|
"data_fetch_time": data_fetch_time,
|
||||||
"gpu_memory_reserved_gb": gpu_mem_reserved_gb,
|
"gpu_memory_gb": gpu_mem_gb,
|
||||||
})
|
"gpu_memory_reserved_gb": gpu_mem_reserved_gb,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)
|
log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)
|
||||||
check_vllm_process_health()
|
check_vllm_process_health()
|
||||||
|
|
||||||
# === Cleanup ===
|
# === Cleanup ===
|
||||||
save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True)
|
save_checkpoint(
|
||||||
finalize_training(use_wandb, training_start_time, "legacy", config.training_steps, benchmark_stats, config.benchmark)
|
model, tokenizer, config.save_path, config.training_steps, is_final=True
|
||||||
|
)
|
||||||
|
finalize_training(
|
||||||
|
use_wandb,
|
||||||
|
training_start_time,
|
||||||
|
"legacy",
|
||||||
|
config.training_steps,
|
||||||
|
benchmark_stats,
|
||||||
|
config.benchmark,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def train_shared_vllm(config: TrainingConfig):
|
def train_shared_vllm(config: TrainingConfig):
|
||||||
|
|
@ -281,13 +319,13 @@ def train_shared_vllm(config: TrainingConfig):
|
||||||
# === Setup ===
|
# === Setup ===
|
||||||
use_wandb = setup_wandb(config)
|
use_wandb = setup_wandb(config)
|
||||||
|
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("SINGLE-COPY MODE (CUDA IPC)")
|
print("SINGLE-COPY MODE (CUDA IPC)")
|
||||||
print(">>> Trainer uses vLLM's tensors directly!")
|
print(">>> Trainer uses vLLM's tensors directly!")
|
||||||
print("="*60)
|
print("=" * 60)
|
||||||
print(f"Model: {config.model_name}")
|
print(f"Model: {config.model_name}")
|
||||||
print(f"Save path: {config.save_path}")
|
print(f"Save path: {config.save_path}")
|
||||||
print("="*60 + "\n")
|
print("=" * 60 + "\n")
|
||||||
|
|
||||||
# Attach to vLLM's shared tensors
|
# Attach to vLLM's shared tensors
|
||||||
print("[1/2] Attaching to vLLM's shared tensors...")
|
print("[1/2] Attaching to vLLM's shared tensors...")
|
||||||
|
|
@ -331,11 +369,15 @@ def train_shared_vllm(config: TrainingConfig):
|
||||||
data_fetch_start = time.time()
|
data_fetch_start = time.time()
|
||||||
if len(batches) == 0:
|
if len(batches) == 0:
|
||||||
batches, _ = get_data(
|
batches, _ = get_data(
|
||||||
config.batch_size, config.seq_len, config.atropos_url,
|
config.batch_size,
|
||||||
|
config.seq_len,
|
||||||
|
config.atropos_url,
|
||||||
extract_inference_logprobs=True, # Enable proper GRPO with reference logprobs
|
extract_inference_logprobs=True, # Enable proper GRPO with reference logprobs
|
||||||
)
|
)
|
||||||
batch_data = batches.pop(0)
|
batch_data = batches.pop(0)
|
||||||
token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4]
|
token_batches, label_batches, advantage_batches, temperature_batches = (
|
||||||
|
batch_data[:4]
|
||||||
|
)
|
||||||
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
|
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
|
||||||
data_fetch_time = time.time() - data_fetch_start
|
data_fetch_time = time.time() - data_fetch_start
|
||||||
benchmark_stats["data_fetch_times"].append(data_fetch_time)
|
benchmark_stats["data_fetch_times"].append(data_fetch_time)
|
||||||
|
|
@ -343,8 +385,12 @@ def train_shared_vllm(config: TrainingConfig):
|
||||||
# Training step with proper GRPO (importance sampling + KL penalty)
|
# Training step with proper GRPO (importance sampling + KL penalty)
|
||||||
step_start = time.time()
|
step_start = time.time()
|
||||||
metrics = run_training_step(
|
metrics = run_training_step(
|
||||||
model, optimizer,
|
model,
|
||||||
token_batches, label_batches, advantage_batches, temperature_batches,
|
optimizer,
|
||||||
|
token_batches,
|
||||||
|
label_batches,
|
||||||
|
advantage_batches,
|
||||||
|
temperature_batches,
|
||||||
config,
|
config,
|
||||||
inference_logprob_batches=inference_logprob_batches, # Pass for GRPO ratio computation
|
inference_logprob_batches=inference_logprob_batches, # Pass for GRPO ratio computation
|
||||||
)
|
)
|
||||||
|
|
@ -352,8 +398,12 @@ def train_shared_vllm(config: TrainingConfig):
|
||||||
benchmark_stats["step_times"].append(step_time)
|
benchmark_stats["step_times"].append(step_time)
|
||||||
|
|
||||||
# GPU memory tracking
|
# GPU memory tracking
|
||||||
gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
|
gpu_mem_gb = (
|
||||||
gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
|
torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
|
||||||
|
)
|
||||||
|
gpu_mem_reserved_gb = (
|
||||||
|
torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
|
||||||
|
)
|
||||||
benchmark_stats["gpu_memories"].append(gpu_mem_gb)
|
benchmark_stats["gpu_memories"].append(gpu_mem_gb)
|
||||||
|
|
||||||
# In single-copy mode, weights are updated in-place (no sync needed!)
|
# In single-copy mode, weights are updated in-place (no sync needed!)
|
||||||
|
|
@ -362,23 +412,37 @@ def train_shared_vllm(config: TrainingConfig):
|
||||||
benchmark_stats["sync_times"].append(sync_time)
|
benchmark_stats["sync_times"].append(sync_time)
|
||||||
|
|
||||||
# Update metrics
|
# Update metrics
|
||||||
metrics.update({
|
metrics.update(
|
||||||
"step_time": step_time,
|
{
|
||||||
"sync_time": sync_time,
|
"step_time": step_time,
|
||||||
"data_fetch_time": data_fetch_time,
|
"sync_time": sync_time,
|
||||||
"gpu_memory_gb": gpu_mem_gb,
|
"data_fetch_time": data_fetch_time,
|
||||||
"gpu_memory_reserved_gb": gpu_mem_reserved_gb,
|
"gpu_memory_gb": gpu_mem_gb,
|
||||||
})
|
"gpu_memory_reserved_gb": gpu_mem_reserved_gb,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)
|
log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)
|
||||||
|
|
||||||
# Periodic checkpoint (for recovery, not for vLLM sync)
|
# Periodic checkpoint (for recovery, not for vLLM sync)
|
||||||
if config.checkpoint_interval > 0 and (step + 1) % config.checkpoint_interval == 0:
|
if (
|
||||||
|
config.checkpoint_interval > 0
|
||||||
|
and (step + 1) % config.checkpoint_interval == 0
|
||||||
|
):
|
||||||
save_checkpoint(model, tokenizer, config.save_path, step + 1)
|
save_checkpoint(model, tokenizer, config.save_path, step + 1)
|
||||||
|
|
||||||
# === Cleanup ===
|
# === Cleanup ===
|
||||||
save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True)
|
save_checkpoint(
|
||||||
finalize_training(use_wandb, training_start_time, "shared_vllm", config.training_steps, benchmark_stats, config.benchmark)
|
model, tokenizer, config.save_path, config.training_steps, is_final=True
|
||||||
|
)
|
||||||
|
finalize_training(
|
||||||
|
use_wandb,
|
||||||
|
training_start_time,
|
||||||
|
"shared_vllm",
|
||||||
|
config.training_steps,
|
||||||
|
benchmark_stats,
|
||||||
|
config.benchmark,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def train_lora(config: TrainingConfig):
|
def train_lora(config: TrainingConfig):
|
||||||
|
|
@ -399,29 +463,33 @@ def train_lora(config: TrainingConfig):
|
||||||
- External vLLM server running with --enable-lora
|
- External vLLM server running with --enable-lora
|
||||||
"""
|
"""
|
||||||
if not PEFT_AVAILABLE:
|
if not PEFT_AVAILABLE:
|
||||||
raise RuntimeError("PEFT library required for LoRA mode. Install with: pip install peft")
|
raise RuntimeError(
|
||||||
|
"PEFT library required for LoRA mode. Install with: pip install peft"
|
||||||
|
)
|
||||||
|
|
||||||
training_start_time = time.time()
|
training_start_time = time.time()
|
||||||
|
|
||||||
# === Setup ===
|
# === Setup ===
|
||||||
use_wandb = setup_wandb(config)
|
use_wandb = setup_wandb(config)
|
||||||
|
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
print("LORA MODE (adapter-only training)")
|
print("LORA MODE (adapter-only training)")
|
||||||
print("="*60)
|
print("=" * 60)
|
||||||
print(f"Base model: {config.model_name}")
|
print(f"Base model: {config.model_name}")
|
||||||
print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}")
|
print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}")
|
||||||
print(f"Save path: {config.save_path}")
|
print(f"Save path: {config.save_path}")
|
||||||
print(f"vLLM port: {config.vllm_port}")
|
print(f"vLLM port: {config.vllm_port}")
|
||||||
print("="*60 + "\n")
|
print("=" * 60 + "\n")
|
||||||
|
|
||||||
# Check external vLLM server
|
# Check external vLLM server
|
||||||
print("[1/3] Checking external vLLM server...")
|
print("[1/3] Checking external vLLM server...")
|
||||||
if not check_vllm_health(config.vllm_port):
|
if not check_vllm_health(config.vllm_port):
|
||||||
print(f"\nERROR: vLLM server not running on port {config.vllm_port}")
|
print(f"\nERROR: vLLM server not running on port {config.vllm_port}")
|
||||||
print("\nLoRA mode requires an external vLLM server. Start it first:")
|
print("\nLoRA mode requires an external vLLM server. Start it first:")
|
||||||
print(f" python example_trainer/vllm_api_server.py --model {config.model_name} "
|
print(
|
||||||
f"--port {config.vllm_port} --enable-lora --enforce-eager")
|
f" python example_trainer/vllm_api_server.py --model {config.model_name} "
|
||||||
|
f"--port {config.vllm_port} --enable-lora --enforce-eager"
|
||||||
|
)
|
||||||
raise RuntimeError(f"External vLLM server required on port {config.vllm_port}")
|
raise RuntimeError(f"External vLLM server required on port {config.vllm_port}")
|
||||||
print(f"vLLM server healthy on port {config.vllm_port}")
|
print(f"vLLM server healthy on port {config.vllm_port}")
|
||||||
|
|
||||||
|
|
@ -459,10 +527,16 @@ def train_lora(config: TrainingConfig):
|
||||||
# Fetch data (with inference logprobs for proper GRPO)
|
# Fetch data (with inference logprobs for proper GRPO)
|
||||||
data_fetch_start = time.time()
|
data_fetch_start = time.time()
|
||||||
if len(batches) == 0:
|
if len(batches) == 0:
|
||||||
batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url,
|
batches, _ = get_data(
|
||||||
extract_inference_logprobs=True)
|
config.batch_size,
|
||||||
|
config.seq_len,
|
||||||
|
config.atropos_url,
|
||||||
|
extract_inference_logprobs=True,
|
||||||
|
)
|
||||||
batch_data = batches.pop(0)
|
batch_data = batches.pop(0)
|
||||||
token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4]
|
token_batches, label_batches, advantage_batches, temperature_batches = (
|
||||||
|
batch_data[:4]
|
||||||
|
)
|
||||||
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
|
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
|
||||||
data_fetch_time = time.time() - data_fetch_start
|
data_fetch_time = time.time() - data_fetch_start
|
||||||
benchmark_stats["data_fetch_times"].append(data_fetch_time)
|
benchmark_stats["data_fetch_times"].append(data_fetch_time)
|
||||||
|
|
@ -470,8 +544,12 @@ def train_lora(config: TrainingConfig):
|
||||||
# Training step with proper GRPO
|
# Training step with proper GRPO
|
||||||
step_start = time.time()
|
step_start = time.time()
|
||||||
metrics = run_training_step(
|
metrics = run_training_step(
|
||||||
model, optimizer,
|
model,
|
||||||
token_batches, label_batches, advantage_batches, temperature_batches,
|
optimizer,
|
||||||
|
token_batches,
|
||||||
|
label_batches,
|
||||||
|
advantage_batches,
|
||||||
|
temperature_batches,
|
||||||
config,
|
config,
|
||||||
inference_logprob_batches=inference_logprob_batches,
|
inference_logprob_batches=inference_logprob_batches,
|
||||||
)
|
)
|
||||||
|
|
@ -479,8 +557,12 @@ def train_lora(config: TrainingConfig):
|
||||||
benchmark_stats["step_times"].append(step_time)
|
benchmark_stats["step_times"].append(step_time)
|
||||||
|
|
||||||
# GPU memory tracking
|
# GPU memory tracking
|
||||||
gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
|
gpu_mem_gb = (
|
||||||
gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
|
torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
|
||||||
|
)
|
||||||
|
gpu_mem_reserved_gb = (
|
||||||
|
torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
|
||||||
|
)
|
||||||
benchmark_stats["gpu_memories"].append(gpu_mem_gb)
|
benchmark_stats["gpu_memories"].append(gpu_mem_gb)
|
||||||
|
|
||||||
# Periodic adapter save + hot-swap
|
# Periodic adapter save + hot-swap
|
||||||
|
|
@ -494,24 +576,35 @@ def train_lora(config: TrainingConfig):
|
||||||
benchmark_stats["sync_times"].append(sync_time)
|
benchmark_stats["sync_times"].append(sync_time)
|
||||||
|
|
||||||
# Update metrics
|
# Update metrics
|
||||||
metrics.update({
|
metrics.update(
|
||||||
"step_time": step_time,
|
{
|
||||||
"sync_time": sync_time,
|
"step_time": step_time,
|
||||||
"data_fetch_time": data_fetch_time,
|
"sync_time": sync_time,
|
||||||
"gpu_memory_gb": gpu_mem_gb,
|
"data_fetch_time": data_fetch_time,
|
||||||
"gpu_memory_reserved_gb": gpu_mem_reserved_gb,
|
"gpu_memory_gb": gpu_mem_gb,
|
||||||
})
|
"gpu_memory_reserved_gb": gpu_mem_reserved_gb,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)
|
log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)
|
||||||
|
|
||||||
# === Cleanup ===
|
# === Cleanup ===
|
||||||
final_sync_start = time.time()
|
final_sync_start = time.time()
|
||||||
final_adapter_path = save_lora_checkpoint(model, config.save_path, config.training_steps, is_final=True)
|
final_adapter_path = save_lora_checkpoint(
|
||||||
|
model, config.save_path, config.training_steps, is_final=True
|
||||||
|
)
|
||||||
_hotswap_lora_adapter(config.vllm_port, final_adapter_path, "final")
|
_hotswap_lora_adapter(config.vllm_port, final_adapter_path, "final")
|
||||||
final_sync_time = time.time() - final_sync_start
|
final_sync_time = time.time() - final_sync_start
|
||||||
benchmark_stats["sync_times"].append(final_sync_time)
|
benchmark_stats["sync_times"].append(final_sync_time)
|
||||||
|
|
||||||
finalize_training(use_wandb, training_start_time, "lora_only", config.training_steps, benchmark_stats, config.benchmark)
|
finalize_training(
|
||||||
|
use_wandb,
|
||||||
|
training_start_time,
|
||||||
|
"lora_only",
|
||||||
|
config.training_steps,
|
||||||
|
benchmark_stats,
|
||||||
|
config.benchmark,
|
||||||
|
)
|
||||||
|
|
||||||
# Save tokenizer
|
# Save tokenizer
|
||||||
tokenizer_path = os.path.join(config.save_path, "tokenizer")
|
tokenizer_path = os.path.join(config.save_path, "tokenizer")
|
||||||
|
|
@ -563,4 +656,3 @@ def _hotswap_lora_adapter(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" [LORA] ✗ Hot-swap request failed: {e}")
|
print(f" [LORA] ✗ Hot-swap request failed: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,10 +18,10 @@ import wandb
|
||||||
|
|
||||||
from .config import TrainingConfig
|
from .config import TrainingConfig
|
||||||
|
|
||||||
|
|
||||||
# Global storage for logprob alignment stats
|
# Global storage for logprob alignment stats
|
||||||
_logprob_alignment_stats: Dict[str, float] = {}
|
_logprob_alignment_stats: Dict[str, float] = {}
|
||||||
|
|
||||||
|
|
||||||
def setup_wandb(config: TrainingConfig) -> bool:
|
def setup_wandb(config: TrainingConfig) -> bool:
|
||||||
"""
|
"""
|
||||||
Initialize Weights & Biases logging if enabled.
|
Initialize Weights & Biases logging if enabled.
|
||||||
|
|
@ -134,7 +134,9 @@ def compute_grpo_loss(
|
||||||
# === GRPO/PPO Loss Computation ===
|
# === GRPO/PPO Loss Computation ===
|
||||||
if use_reference_logprobs and inference_logprobs is not None:
|
if use_reference_logprobs and inference_logprobs is not None:
|
||||||
# Move inference logprobs to correct device/dtype
|
# Move inference logprobs to correct device/dtype
|
||||||
ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype)
|
ref_logprobs = inference_logprobs.to(
|
||||||
|
logp_per_token.device, logp_per_token.dtype
|
||||||
|
)
|
||||||
|
|
||||||
# NOTE: inference_logprobs uses 1.0 for masked (prompt) positions, actual negative values for generated
|
# NOTE: inference_logprobs uses 1.0 for masked (prompt) positions, actual negative values for generated
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
@ -156,10 +158,16 @@ def compute_grpo_loss(
|
||||||
# Check if ref logprobs are negative (as they should be for generated tokens)
|
# Check if ref logprobs are negative (as they should be for generated tokens)
|
||||||
# If ref_at_generated is close to 1.0, that means the 1.0 placeholder is being used
|
# If ref_at_generated is close to 1.0, that means the 1.0 placeholder is being used
|
||||||
if ref_at_generated > 0.5:
|
if ref_at_generated > 0.5:
|
||||||
print(f" [WARNING] ref_logprobs avg {ref_at_generated:.3f} (should be negative!)")
|
print(
|
||||||
print(" [WARNING] This suggests inference_logprobs alignment is wrong")
|
f" [WARNING] ref_logprobs avg {ref_at_generated:.3f} (should be negative!)"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
" [WARNING] This suggests inference_logprobs alignment is wrong"
|
||||||
|
)
|
||||||
elif abs(ref_at_generated - train_at_generated) > 2.0:
|
elif abs(ref_at_generated - train_at_generated) > 2.0:
|
||||||
print(f" [DEBUG] Logprob gap: ref={ref_at_generated:.3f}, train={train_at_generated:.3f}")
|
print(
|
||||||
|
f" [DEBUG] Logprob gap: ref={ref_at_generated:.3f}, train={train_at_generated:.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
# Compute importance sampling ratio: policy(a|s) / policy_old(a|s) = exp(log policy - log policy_old)
|
# Compute importance sampling ratio: policy(a|s) / policy_old(a|s) = exp(log policy - log policy_old)
|
||||||
log_ratio = logp_per_token - ref_logprobs
|
log_ratio = logp_per_token - ref_logprobs
|
||||||
|
|
@ -192,7 +200,9 @@ def compute_grpo_loss(
|
||||||
# = exp(-log_ratio) + log_ratio - 1
|
# = exp(-log_ratio) + log_ratio - 1
|
||||||
kl_per_token = torch.exp(-log_ratio) + log_ratio - 1.0
|
kl_per_token = torch.exp(-log_ratio) + log_ratio - 1.0
|
||||||
kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean()
|
kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean()
|
||||||
total_loss = (policy_loss + kl_coef * kl_penalty) / gradient_accumulation_steps
|
total_loss = (
|
||||||
|
policy_loss + kl_coef * kl_penalty
|
||||||
|
) / gradient_accumulation_steps
|
||||||
else:
|
else:
|
||||||
kl_penalty = torch.tensor(0.0, device=logp_per_token.device)
|
kl_penalty = torch.tensor(0.0, device=logp_per_token.device)
|
||||||
total_loss = policy_loss / gradient_accumulation_steps
|
total_loss = policy_loss / gradient_accumulation_steps
|
||||||
|
|
@ -200,7 +210,9 @@ def compute_grpo_loss(
|
||||||
# Compute metrics for logging
|
# Compute metrics for logging
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# Fraction of tokens where ratio was clipped
|
# Fraction of tokens where ratio was clipped
|
||||||
clipped_fraction = ((ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)).float()
|
clipped_fraction = (
|
||||||
|
(ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)
|
||||||
|
).float()
|
||||||
clipped_fraction = (clipped_fraction * mask).sum() / mask.sum()
|
clipped_fraction = (clipped_fraction * mask).sum() / mask.sum()
|
||||||
|
|
||||||
# Mean ratio and KL for monitoring (using Schulman's estimator)
|
# Mean ratio and KL for monitoring (using Schulman's estimator)
|
||||||
|
|
@ -256,7 +268,11 @@ def compute_grpo_loss(
|
||||||
"kl_penalty": kl_penalty.item() if torch.is_tensor(kl_penalty) else kl_penalty,
|
"kl_penalty": kl_penalty.item() if torch.is_tensor(kl_penalty) else kl_penalty,
|
||||||
"mean_ratio": mean_ratio.item() if torch.is_tensor(mean_ratio) else mean_ratio,
|
"mean_ratio": mean_ratio.item() if torch.is_tensor(mean_ratio) else mean_ratio,
|
||||||
"mean_kl": mean_kl.item() if torch.is_tensor(mean_kl) else mean_kl,
|
"mean_kl": mean_kl.item() if torch.is_tensor(mean_kl) else mean_kl,
|
||||||
"clipped_fraction": clipped_fraction.item() if torch.is_tensor(clipped_fraction) else clipped_fraction,
|
"clipped_fraction": (
|
||||||
|
clipped_fraction.item()
|
||||||
|
if torch.is_tensor(clipped_fraction)
|
||||||
|
else clipped_fraction
|
||||||
|
),
|
||||||
# Token-level alignment metrics (key for verifying weight sharing)
|
# Token-level alignment metrics (key for verifying weight sharing)
|
||||||
"logprob_diff_mean": logprob_diff_mean,
|
"logprob_diff_mean": logprob_diff_mean,
|
||||||
"logprob_diff_abs_mean": logprob_diff_abs_mean,
|
"logprob_diff_abs_mean": logprob_diff_abs_mean,
|
||||||
|
|
@ -315,23 +331,25 @@ def run_training_step(
|
||||||
all_inference_logprobs: List[torch.Tensor] = []
|
all_inference_logprobs: List[torch.Tensor] = []
|
||||||
|
|
||||||
# Get GRPO hyperparameters from config
|
# Get GRPO hyperparameters from config
|
||||||
kl_coef = getattr(config, 'kl_coef', 0.1)
|
kl_coef = getattr(config, "kl_coef", 0.1)
|
||||||
clip_eps = getattr(config, 'clip_eps', 0.2)
|
clip_eps = getattr(config, "clip_eps", 0.2)
|
||||||
use_reference_logprobs = getattr(config, 'use_reference_logprobs', True)
|
use_reference_logprobs = getattr(config, "use_reference_logprobs", True)
|
||||||
|
|
||||||
# Accumulate gradients over micro-batches
|
# Accumulate gradients over micro-batches
|
||||||
num_batches = len(token_batches) if token_batches else 1
|
num_batches = len(token_batches) if token_batches else 1
|
||||||
|
|
||||||
for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(zip(
|
for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(
|
||||||
token_batches, label_batches, advantage_batches, temperature_batches
|
zip(token_batches, label_batches, advantage_batches, temperature_batches)
|
||||||
)):
|
):
|
||||||
tokens = tokens.to(config.device)
|
tokens = tokens.to(config.device)
|
||||||
labels = labels.to(config.device)
|
labels = labels.to(config.device)
|
||||||
advantages = advantages.to(config.device)
|
advantages = advantages.to(config.device)
|
||||||
|
|
||||||
# Get corresponding inference logprobs batch if available
|
# Get corresponding inference logprobs batch if available
|
||||||
inf_logprobs = None
|
inf_logprobs = None
|
||||||
if inference_logprob_batches is not None and batch_idx < len(inference_logprob_batches):
|
if inference_logprob_batches is not None and batch_idx < len(
|
||||||
|
inference_logprob_batches
|
||||||
|
):
|
||||||
inf_logprobs = inference_logprob_batches[batch_idx]
|
inf_logprobs = inference_logprob_batches[batch_idx]
|
||||||
|
|
||||||
loss, metrics = compute_grpo_loss(
|
loss, metrics = compute_grpo_loss(
|
||||||
|
|
@ -363,12 +381,17 @@ def run_training_step(
|
||||||
# Accumulate token-level alignment metrics
|
# Accumulate token-level alignment metrics
|
||||||
total_logprob_diff_mean += metrics.get("logprob_diff_mean", 0.0)
|
total_logprob_diff_mean += metrics.get("logprob_diff_mean", 0.0)
|
||||||
total_logprob_diff_abs_mean += metrics.get("logprob_diff_abs_mean", 0.0)
|
total_logprob_diff_abs_mean += metrics.get("logprob_diff_abs_mean", 0.0)
|
||||||
total_logprob_diff_max = max(total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0))
|
total_logprob_diff_max = max(
|
||||||
|
total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0)
|
||||||
|
)
|
||||||
|
|
||||||
# Collect logprobs for alignment monitoring
|
# Collect logprobs for alignment monitoring
|
||||||
if "training_logprobs" in metrics and metrics["training_logprobs"] is not None:
|
if "training_logprobs" in metrics and metrics["training_logprobs"] is not None:
|
||||||
all_training_logprobs.append(metrics["training_logprobs"])
|
all_training_logprobs.append(metrics["training_logprobs"])
|
||||||
if "inference_logprobs" in metrics and metrics["inference_logprobs"] is not None:
|
if (
|
||||||
|
"inference_logprobs" in metrics
|
||||||
|
and metrics["inference_logprobs"] is not None
|
||||||
|
):
|
||||||
all_inference_logprobs.append(metrics["inference_logprobs"])
|
all_inference_logprobs.append(metrics["inference_logprobs"])
|
||||||
|
|
||||||
# Gradient clipping and optimizer step
|
# Gradient clipping and optimizer step
|
||||||
|
|
@ -387,7 +410,7 @@ def run_training_step(
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"loss": total_loss,
|
"loss": total_loss,
|
||||||
"grad_norm": grad_norm.item() if hasattr(grad_norm, 'item') else grad_norm,
|
"grad_norm": grad_norm.item() if hasattr(grad_norm, "item") else grad_norm,
|
||||||
"pos_logp": total_pos_logp,
|
"pos_logp": total_pos_logp,
|
||||||
"neg_logp": total_neg_logp,
|
"neg_logp": total_neg_logp,
|
||||||
"pos_count": total_pos,
|
"pos_count": total_pos,
|
||||||
|
|
@ -404,7 +427,9 @@ def run_training_step(
|
||||||
if all_training_logprobs:
|
if all_training_logprobs:
|
||||||
train_flat = torch.cat(all_training_logprobs)
|
train_flat = torch.cat(all_training_logprobs)
|
||||||
if train_flat.numel() > 0:
|
if train_flat.numel() > 0:
|
||||||
_logprob_alignment_stats["logprobs/training_mean"] = train_flat.mean().item()
|
_logprob_alignment_stats["logprobs/training_mean"] = (
|
||||||
|
train_flat.mean().item()
|
||||||
|
)
|
||||||
_logprob_alignment_stats["logprobs/training_std"] = train_flat.std().item()
|
_logprob_alignment_stats["logprobs/training_std"] = train_flat.std().item()
|
||||||
|
|
||||||
if all_inference_logprobs:
|
if all_inference_logprobs:
|
||||||
|
|
@ -415,8 +440,12 @@ def run_training_step(
|
||||||
|
|
||||||
# Token-level alignment metrics - THE key metric for verifying weight sharing
|
# Token-level alignment metrics - THE key metric for verifying weight sharing
|
||||||
# diff_abs_mean close to 0 = weights are truly shared
|
# diff_abs_mean close to 0 = weights are truly shared
|
||||||
_logprob_alignment_stats["alignment/diff_mean"] = total_logprob_diff_mean / num_batches
|
_logprob_alignment_stats["alignment/diff_mean"] = (
|
||||||
_logprob_alignment_stats["alignment/diff_abs_mean"] = total_logprob_diff_abs_mean / num_batches
|
total_logprob_diff_mean / num_batches
|
||||||
|
)
|
||||||
|
_logprob_alignment_stats["alignment/diff_abs_mean"] = (
|
||||||
|
total_logprob_diff_abs_mean / num_batches
|
||||||
|
)
|
||||||
_logprob_alignment_stats["alignment/diff_max"] = total_logprob_diff_max
|
_logprob_alignment_stats["alignment/diff_max"] = total_logprob_diff_max
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
@ -495,8 +524,13 @@ def log_metrics(
|
||||||
"grpo/clipped_fraction": clipped_frac,
|
"grpo/clipped_fraction": clipped_frac,
|
||||||
}
|
}
|
||||||
# Add timing metrics if present
|
# Add timing metrics if present
|
||||||
for key in ["step_time", "sync_time", "data_fetch_time",
|
for key in [
|
||||||
"gpu_memory_gb", "gpu_memory_reserved_gb"]:
|
"step_time",
|
||||||
|
"sync_time",
|
||||||
|
"data_fetch_time",
|
||||||
|
"gpu_memory_gb",
|
||||||
|
"gpu_memory_reserved_gb",
|
||||||
|
]:
|
||||||
if key in metrics:
|
if key in metrics:
|
||||||
log_dict[f"train/{key}"] = metrics[key]
|
log_dict[f"train/{key}"] = metrics[key]
|
||||||
|
|
||||||
|
|
@ -549,7 +583,9 @@ def finalize_training(
|
||||||
total_step_time = sum(step_times)
|
total_step_time = sum(step_times)
|
||||||
avg_sync_time = sum(sync_times) / len(sync_times) if sync_times else 0
|
avg_sync_time = sum(sync_times) / len(sync_times) if sync_times else 0
|
||||||
total_sync_time = sum(sync_times)
|
total_sync_time = sum(sync_times)
|
||||||
avg_data_fetch = sum(data_fetch_times) / len(data_fetch_times) if data_fetch_times else 0
|
avg_data_fetch = (
|
||||||
|
sum(data_fetch_times) / len(data_fetch_times) if data_fetch_times else 0
|
||||||
|
)
|
||||||
total_data_fetch = sum(data_fetch_times)
|
total_data_fetch = sum(data_fetch_times)
|
||||||
avg_gpu_mem = sum(gpu_memories) / len(gpu_memories) if gpu_memories else 0
|
avg_gpu_mem = sum(gpu_memories) / len(gpu_memories) if gpu_memories else 0
|
||||||
|
|
||||||
|
|
@ -557,13 +593,17 @@ def finalize_training(
|
||||||
print(f"\n{'='*70}")
|
print(f"\n{'='*70}")
|
||||||
print(f"BENCHMARK SUMMARY ({mode})")
|
print(f"BENCHMARK SUMMARY ({mode})")
|
||||||
print(f"{'='*70}")
|
print(f"{'='*70}")
|
||||||
print(f" Total training time: {total_time:.2f}s ({total_time/60:.2f} min)")
|
print(
|
||||||
|
f" Total training time: {total_time:.2f}s ({total_time/60:.2f} min)"
|
||||||
|
)
|
||||||
print(f" Total steps: {total_steps}")
|
print(f" Total steps: {total_steps}")
|
||||||
print(" ")
|
print(" ")
|
||||||
print(" TIMING BREAKDOWN:")
|
print(" TIMING BREAKDOWN:")
|
||||||
print(f" Avg step time: {avg_step_time:.2f}s")
|
print(f" Avg step time: {avg_step_time:.2f}s")
|
||||||
print(f" Total step time: {total_step_time:.2f}s")
|
print(f" Total step time: {total_step_time:.2f}s")
|
||||||
print(f" Avg sync time: {avg_sync_time:.2f}s (x{len(sync_times)} syncs)")
|
print(
|
||||||
|
f" Avg sync time: {avg_sync_time:.2f}s (x{len(sync_times)} syncs)"
|
||||||
|
)
|
||||||
print(f" Total sync time: {total_sync_time:.2f}s")
|
print(f" Total sync time: {total_sync_time:.2f}s")
|
||||||
print(f" Avg data fetch time: {avg_data_fetch:.2f}s")
|
print(f" Avg data fetch time: {avg_data_fetch:.2f}s")
|
||||||
print(f" Total data fetch time: {total_data_fetch:.2f}s")
|
print(f" Total data fetch time: {total_data_fetch:.2f}s")
|
||||||
|
|
@ -584,4 +624,3 @@ def finalize_training(
|
||||||
wandb.finish()
|
wandb.finish()
|
||||||
elif use_wandb:
|
elif use_wandb:
|
||||||
wandb.finish()
|
wandb.finish()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@ os.environ.setdefault("VLLM_USE_V1", "0")
|
||||||
# Set spawn method for multiprocessing (required for CUDA)
|
# Set spawn method for multiprocessing (required for CUDA)
|
||||||
os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||||
try:
|
try:
|
||||||
multiprocessing.set_start_method('spawn', force=True)
|
multiprocessing.set_start_method("spawn", force=True)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
pass # Already set
|
pass # Already set
|
||||||
|
|
||||||
|
|
@ -86,6 +86,7 @@ def _apply_patches_early() -> bool:
|
||||||
try:
|
try:
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# Add parent directory to path so we can import vllm_patching
|
# Add parent directory to path so we can import vllm_patching
|
||||||
script_dir = Path(__file__).parent
|
script_dir = Path(__file__).parent
|
||||||
if str(script_dir) not in sys.path:
|
if str(script_dir) not in sys.path:
|
||||||
|
|
@ -106,6 +107,7 @@ def _apply_patches_early() -> bool:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[vLLM Server] Error applying patches: {e}")
|
print(f"[vLLM Server] Error applying patches: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -145,17 +147,20 @@ except ImportError:
|
||||||
|
|
||||||
def add_argument(self, *args, **kwargs):
|
def add_argument(self, *args, **kwargs):
|
||||||
# Remove 'deprecated' kwarg if present (not supported before Python 3.13)
|
# Remove 'deprecated' kwarg if present (not supported before Python 3.13)
|
||||||
kwargs.pop('deprecated', None)
|
kwargs.pop("deprecated", None)
|
||||||
return super().add_argument(*args, **kwargs)
|
return super().add_argument(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
# set_ulimit might not exist in all vLLM versions
|
# set_ulimit might not exist in all vLLM versions
|
||||||
try:
|
try:
|
||||||
from vllm.utils import set_ulimit
|
from vllm.utils import set_ulimit
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
||||||
def set_ulimit() -> None:
|
def set_ulimit() -> None:
|
||||||
"""No-op fallback for set_ulimit."""
|
"""No-op fallback for set_ulimit."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
from vllm.outputs import RequestOutput # noqa: F401, E402
|
from vllm.outputs import RequestOutput # noqa: F401, E402
|
||||||
from vllm.version import __version__ as VLLM_VERSION # noqa: E402
|
from vllm.version import __version__ as VLLM_VERSION # noqa: E402
|
||||||
|
|
||||||
|
|
@ -602,7 +607,9 @@ async def lora_load(request: LoraLoadRequest) -> JSONResponse:
|
||||||
) # vLLM needs unique int ID
|
) # vLLM needs unique int ID
|
||||||
bridge_state.lora_load_count += 1
|
bridge_state.lora_load_count += 1
|
||||||
|
|
||||||
logger.info(f"LoRA adapter loaded: {request.adapter_path} (id={bridge_state.active_lora_id})")
|
logger.info(
|
||||||
|
f"LoRA adapter loaded: {request.adapter_path} (id={bridge_state.active_lora_id})"
|
||||||
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,6 @@ import requests
|
||||||
|
|
||||||
from .config import TrainingConfig
|
from .config import TrainingConfig
|
||||||
|
|
||||||
|
|
||||||
# Global variable to keep track of the vLLM process
|
# Global variable to keep track of the vLLM process
|
||||||
_vllm_process: Optional[subprocess.Popen] = None
|
_vllm_process: Optional[subprocess.Popen] = None
|
||||||
|
|
||||||
|
|
@ -25,7 +24,7 @@ _vllm_process: Optional[subprocess.Popen] = None
|
||||||
def is_port_in_use(port: int) -> bool:
|
def is_port_in_use(port: int) -> bool:
|
||||||
"""Check if a port is already in use."""
|
"""Check if a port is already in use."""
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
return s.connect_ex(('localhost', port)) == 0
|
return s.connect_ex(("localhost", port)) == 0
|
||||||
|
|
||||||
|
|
||||||
def kill_process_on_port(port: int, timeout: float = 5.0) -> bool:
|
def kill_process_on_port(port: int, timeout: float = 5.0) -> bool:
|
||||||
|
|
@ -42,13 +41,10 @@ def kill_process_on_port(port: int, timeout: float = 5.0) -> bool:
|
||||||
try:
|
try:
|
||||||
# Try to find and kill the process using lsof (Linux/Mac)
|
# Try to find and kill the process using lsof (Linux/Mac)
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
["lsof", "-t", "-i", f":{port}"],
|
["lsof", "-t", "-i", f":{port}"], capture_output=True, text=True, timeout=5
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
timeout=5
|
|
||||||
)
|
)
|
||||||
if result.stdout.strip():
|
if result.stdout.strip():
|
||||||
pids = result.stdout.strip().split('\n')
|
pids = result.stdout.strip().split("\n")
|
||||||
for pid in pids:
|
for pid in pids:
|
||||||
try:
|
try:
|
||||||
os.kill(int(pid), signal.SIGTERM)
|
os.kill(int(pid), signal.SIGTERM)
|
||||||
|
|
@ -135,7 +131,9 @@ def launch_vllm_server(
|
||||||
if is_port_in_use(config.vllm_port):
|
if is_port_in_use(config.vllm_port):
|
||||||
print(f" WARNING: Port {config.vllm_port} is already in use!")
|
print(f" WARNING: Port {config.vllm_port} is already in use!")
|
||||||
if not kill_process_on_port(config.vllm_port):
|
if not kill_process_on_port(config.vllm_port):
|
||||||
print(f" ERROR: Could not free port {config.vllm_port}. Please manually kill the process.")
|
print(
|
||||||
|
f" ERROR: Could not free port {config.vllm_port}. Please manually kill the process."
|
||||||
|
)
|
||||||
print(f" Try: lsof -i :{config.vllm_port} | grep LISTEN")
|
print(f" Try: lsof -i :{config.vllm_port} | grep LISTEN")
|
||||||
print(f" Or: pkill -f 'vllm.*{config.vllm_port}'")
|
print(f" Or: pkill -f 'vllm.*{config.vllm_port}'")
|
||||||
return None
|
return None
|
||||||
|
|
@ -209,7 +207,9 @@ def check_vllm_process_health() -> None:
|
||||||
global _vllm_process
|
global _vllm_process
|
||||||
|
|
||||||
if _vllm_process is not None and _vllm_process.poll() is not None:
|
if _vllm_process is not None and _vllm_process.poll() is not None:
|
||||||
print(f" WARNING: vLLM terminated unexpectedly (code: {_vllm_process.returncode})")
|
print(
|
||||||
|
f" WARNING: vLLM terminated unexpectedly (code: {_vllm_process.returncode})"
|
||||||
|
)
|
||||||
_vllm_process = None
|
_vllm_process = None
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -299,7 +299,9 @@ def hotswap_lora_adapter(
|
||||||
print(f" [LORA] ✓ Hot-swapped adapter: {adapter_name} ({adapter_path})")
|
print(f" [LORA] ✓ Hot-swapped adapter: {adapter_name} ({adapter_path})")
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
print(f" [LORA] ✗ Hot-swap failed: {response.status_code} - {response.text}")
|
print(
|
||||||
|
f" [LORA] ✗ Hot-swap failed: {response.status_code} - {response.text}"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
|
|
@ -308,4 +310,3 @@ def hotswap_lora_adapter(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" [LORA] ✗ Error during hot-swap: {e}")
|
print(f" [LORA] ✗ Error during hot-swap: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,6 +37,7 @@ def _patch_lora_triton_for_blackwell() -> bool:
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import vllm
|
import vllm
|
||||||
|
|
||||||
vllm_path = vllm.__path__[0]
|
vllm_path = vllm.__path__[0]
|
||||||
kernel_utils_path = f"{vllm_path}/lora/ops/triton_ops/kernel_utils.py"
|
kernel_utils_path = f"{vllm_path}/lora/ops/triton_ops/kernel_utils.py"
|
||||||
|
|
||||||
|
|
@ -45,42 +46,42 @@ def _patch_lora_triton_for_blackwell() -> bool:
|
||||||
print("[vLLM Patch] LoRA kernel_utils.py not found, skipping GDC patch")
|
print("[vLLM Patch] LoRA kernel_utils.py not found, skipping GDC patch")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
with open(kernel_utils_path, 'r') as f:
|
with open(kernel_utils_path, "r") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
# Check if already patched
|
# Check if already patched
|
||||||
if 'PATCHED FOR B200' in content:
|
if "PATCHED FOR B200" in content:
|
||||||
print("[vLLM Patch] LoRA GDC already patched for B200")
|
print("[vLLM Patch] LoRA GDC already patched for B200")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
modified = False
|
modified = False
|
||||||
|
|
||||||
# Patch USE_GDC = True -> False
|
# Patch USE_GDC = True -> False
|
||||||
if 'USE_GDC = True' in content:
|
if "USE_GDC = True" in content:
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
'USE_GDC = True',
|
"USE_GDC = True",
|
||||||
'USE_GDC = False # PATCHED FOR B200 - GDC causes Triton compilation failure'
|
"USE_GDC = False # PATCHED FOR B200 - GDC causes Triton compilation failure",
|
||||||
)
|
)
|
||||||
modified = True
|
modified = True
|
||||||
|
|
||||||
# Patch USE_GDC: tl.constexpr = True -> False
|
# Patch USE_GDC: tl.constexpr = True -> False
|
||||||
if 'USE_GDC: tl.constexpr = True' in content:
|
if "USE_GDC: tl.constexpr = True" in content:
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
'USE_GDC: tl.constexpr = True',
|
"USE_GDC: tl.constexpr = True",
|
||||||
'USE_GDC: tl.constexpr = False # PATCHED FOR B200'
|
"USE_GDC: tl.constexpr = False # PATCHED FOR B200",
|
||||||
)
|
)
|
||||||
modified = True
|
modified = True
|
||||||
|
|
||||||
# Patch the gdc_wait call itself
|
# Patch the gdc_wait call itself
|
||||||
if 'tl.extra.cuda.gdc_wait()' in content:
|
if "tl.extra.cuda.gdc_wait()" in content:
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
'tl.extra.cuda.gdc_wait()',
|
"tl.extra.cuda.gdc_wait()",
|
||||||
'pass # tl.extra.cuda.gdc_wait() PATCHED FOR B200 - disabled'
|
"pass # tl.extra.cuda.gdc_wait() PATCHED FOR B200 - disabled",
|
||||||
)
|
)
|
||||||
modified = True
|
modified = True
|
||||||
|
|
||||||
if modified:
|
if modified:
|
||||||
with open(kernel_utils_path, 'w') as f:
|
with open(kernel_utils_path, "w") as f:
|
||||||
f.write(content)
|
f.write(content)
|
||||||
print(f"[vLLM Patch] ✓ Patched LoRA Triton GDC in {kernel_utils_path}")
|
print(f"[vLLM Patch] ✓ Patched LoRA Triton GDC in {kernel_utils_path}")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue