[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2026-02-06 06:46:14 +00:00 committed by Jai Suphavadeeprasit
parent d07ab3e3ce
commit 5cfd1929f1
19 changed files with 708 additions and 452 deletions

View file

@ -373,7 +373,9 @@ class GSM8KEvalEnv(BaseEnv):
# Create evaluation tasks # Create evaluation tasks
async def eval_task(item): async def eval_task(item):
return await self.rollout_and_score_eval(item, self.server.servers[0].config) return await self.rollout_and_score_eval(
item, self.server.servers[0].config
)
tasks = [eval_task(item) for item in self.eval_items] tasks = [eval_task(item) for item in self.eval_items]

View file

@ -299,6 +299,7 @@ class MathEnv(BaseEnv):
if not self.config.run_evaluation: if not self.config.run_evaluation:
return return
import time import time
start_time = time.time() start_time = time.time()
eval_tasks = [] eval_tasks = []
@ -320,9 +321,7 @@ class MathEnv(BaseEnv):
metrics[f"{subset}_accuracy"] = accuracy metrics[f"{subset}_accuracy"] = accuracy
metrics[f"{subset}_total"] = len(scores) metrics[f"{subset}_total"] = len(scores)
metrics[f"{subset}_correct"] = sum(scores) metrics[f"{subset}_correct"] = sum(scores)
self.eval_metrics.append( self.eval_metrics.append((f"eval/{subset}_percent_correct", accuracy))
(f"eval/{subset}_percent_correct", accuracy)
)
# overall score # overall score
all_scores = [] all_scores = []
@ -332,9 +331,7 @@ class MathEnv(BaseEnv):
metrics["overall_accuracy"] = overall_accuracy metrics["overall_accuracy"] = overall_accuracy
metrics["overall_total"] = len(all_scores) metrics["overall_total"] = len(all_scores)
metrics["overall_correct"] = sum(all_scores) metrics["overall_correct"] = sum(all_scores)
self.eval_metrics.append( self.eval_metrics.append(("eval/overall_percent_correct", overall_accuracy))
("eval/overall_percent_correct", overall_accuracy)
)
end_time = time.time() end_time = time.time()
@ -342,7 +339,9 @@ class MathEnv(BaseEnv):
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("Math Zero Evaluation Results") print("Math Zero Evaluation Results")
print("=" * 60) print("=" * 60)
print(f"Overall Accuracy: {overall_accuracy:.2%} ({sum(all_scores)}/{len(all_scores)})") print(
f"Overall Accuracy: {overall_accuracy:.2%} ({sum(all_scores)}/{len(all_scores)})"
)
print("\nPer-subset breakdown:") print("\nPer-subset breakdown:")
for subset, scores in sorted(task_lists.items()): for subset, scores in sorted(task_lists.items()):
acc = sum(scores) / len(scores) acc = sum(scores) / len(scores)

View file

@ -542,4 +542,3 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands
| `vllm_api_server.py` | Streamlined vLLM server for training | | `vllm_api_server.py` | Streamlined vLLM server for training |
| `vllm_manager.py` | vLLM process lifecycle management | | `vllm_manager.py` | vLLM process lifecycle management |
| `checkpointing.py` | Save/load checkpoints and adapters | | `checkpointing.py` | Save/load checkpoints and adapters |

View file

@ -20,9 +20,9 @@ Usage:
train_legacy(config) train_legacy(config)
""" """
from .cli import config_from_args, parse_args
from .config import TrainingConfig from .config import TrainingConfig
from .trainers import train_legacy, train_shared_vllm, train_lora from .trainers import train_legacy, train_lora, train_shared_vllm
from .cli import parse_args, config_from_args
__all__ = [ __all__ = [
"TrainingConfig", "TrainingConfig",

View file

@ -15,7 +15,9 @@ from tenacity import retry, stop_after_attempt, wait_exponential
from .config import TrainingConfig from .config import TrainingConfig
def check_atropos_api(url: str = "http://localhost:8000", timeout: float = 30.0) -> bool: def check_atropos_api(
url: str = "http://localhost:8000", timeout: float = 30.0
) -> bool:
""" """
Check if the Atropos API server is reachable. Check if the Atropos API server is reachable.
@ -99,4 +101,3 @@ def get_batch(url: str = "http://localhost:8000"):
raise RuntimeError(f"Atropos API error: {data.get('message', 'Unknown error')}") raise RuntimeError(f"Atropos API error: {data.get('message', 'Unknown error')}")
return data return data

View file

@ -89,11 +89,14 @@ def save_checkpoint(
# Count how many were non-contiguous (views into fused tensors) # Count how many were non-contiguous (views into fused tensors)
view_count = sum( view_count = sum(
1 for name, param in model.named_parameters() 1
for name, param in model.named_parameters()
if not param.is_contiguous() or param.storage_offset() != 0 if not param.is_contiguous() or param.storage_offset() != 0
) )
if view_count > 0: if view_count > 0:
print(f" [Checkpoint] Unfused {view_count} view tensors (qkv/gate_up fusions)") print(
f" [Checkpoint] Unfused {view_count} view tensors (qkv/gate_up fusions)"
)
# Save state dict manually, then save config separately # Save state dict manually, then save config separately
torch.save(state_dict, os.path.join(checkpoint_path, "pytorch_model.bin")) torch.save(state_dict, os.path.join(checkpoint_path, "pytorch_model.bin"))
@ -102,6 +105,7 @@ def save_checkpoint(
# CRITICAL: Clean up the copied state_dict to free ~8GB GPU memory! # CRITICAL: Clean up the copied state_dict to free ~8GB GPU memory!
del state_dict del state_dict
import gc import gc
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
else: else:
@ -151,4 +155,3 @@ def save_lora_checkpoint(
print(" Adapter saved.") print(" Adapter saved.")
return adapter_path return adapter_path

View file

@ -11,16 +11,17 @@ import torch
from .config import TrainingConfig from .config import TrainingConfig
# ============================================================================= # =============================================================================
# Argument Group Builders (modular, reusable) # Argument Group Builders (modular, reusable)
# ============================================================================= # =============================================================================
def add_model_args(parser: argparse.ArgumentParser) -> None: def add_model_args(parser: argparse.ArgumentParser) -> None:
"""Add model-related arguments.""" """Add model-related arguments."""
group = parser.add_argument_group("Model") group = parser.add_argument_group("Model")
group.add_argument( group.add_argument(
"--model", "--model-name", "--model",
"--model-name",
type=str, type=str,
required=True, required=True,
dest="model_name", dest="model_name",
@ -67,7 +68,7 @@ def add_training_args(parser: argparse.ArgumentParser) -> None:
choices=["adamw", "adamw_8bit", "adamw_cpu", "adafactor"], choices=["adamw", "adamw_8bit", "adamw_cpu", "adafactor"],
default="adamw_8bit", default="adamw_8bit",
help="Optimizer: 'adamw' (full precision), 'adamw_8bit' (8-bit states), " help="Optimizer: 'adamw' (full precision), 'adamw_8bit' (8-bit states), "
"'adamw_cpu' (CPU offload), 'adafactor' (no momentum)", "'adamw_cpu' (CPU offload), 'adafactor' (no momentum)",
) )
group.add_argument( group.add_argument(
"--device", "--device",
@ -121,7 +122,8 @@ def add_vllm_args(parser: argparse.ArgumentParser) -> None:
help="Port for the vLLM server", help="Port for the vLLM server",
) )
group.add_argument( group.add_argument(
"--gpu-memory-utilization", "--vllm-gpu-memory-utilization", "--gpu-memory-utilization",
"--vllm-gpu-memory-utilization",
type=float, type=float,
default=0.45, default=0.45,
dest="gpu_memory_utilization", dest="gpu_memory_utilization",
@ -203,7 +205,9 @@ def add_lora_args(parser: argparse.ArgumentParser) -> None:
"""Add LoRA-specific arguments.""" """Add LoRA-specific arguments."""
group = parser.add_argument_group("LoRA Configuration") group = parser.add_argument_group("LoRA Configuration")
group.add_argument("--lora-r", type=int, default=16, help="LoRA rank") group.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
group.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha (scaling factor)") group.add_argument(
"--lora-alpha", type=int, default=32, help="LoRA alpha (scaling factor)"
)
group.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout") group.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")
group.add_argument( group.add_argument(
"--lora-target-modules", "--lora-target-modules",
@ -219,8 +223,12 @@ def add_distributed_args(parser: argparse.ArgumentParser) -> None:
group = parser.add_argument_group("Distributed Training") group = parser.add_argument_group("Distributed Training")
group.add_argument("--trainer-rank", type=int, default=0, help="Trainer rank") group.add_argument("--trainer-rank", type=int, default=0, help="Trainer rank")
group.add_argument("--world-size", type=int, default=1, help="World size") group.add_argument("--world-size", type=int, default=1, help="World size")
group.add_argument("--init-method", type=str, default="env://", help="Distributed init method") group.add_argument(
group.add_argument("--num-inference-nodes", type=int, default=0, help="Number of inference nodes") "--init-method", type=str, default="env://", help="Distributed init method"
)
group.add_argument(
"--num-inference-nodes", type=int, default=0, help="Number of inference nodes"
)
def add_debug_args(parser: argparse.ArgumentParser) -> None: def add_debug_args(parser: argparse.ArgumentParser) -> None:
@ -248,6 +256,7 @@ def add_debug_args(parser: argparse.ArgumentParser) -> None:
# Parser Builders # Parser Builders
# ============================================================================= # =============================================================================
def create_base_parser(description: str) -> argparse.ArgumentParser: def create_base_parser(description: str) -> argparse.ArgumentParser:
"""Create a base parser with common formatting.""" """Create a base parser with common formatting."""
return argparse.ArgumentParser( return argparse.ArgumentParser(
@ -299,6 +308,7 @@ def create_unified_parser() -> argparse.ArgumentParser:
# Legacy API (backwards compatibility) # Legacy API (backwards compatibility)
# ============================================================================= # =============================================================================
def parse_args() -> argparse.Namespace: def parse_args() -> argparse.Namespace:
""" """
Parse command-line arguments for the GRPO trainer (grpo.py). Parse command-line arguments for the GRPO trainer (grpo.py).

View file

@ -35,9 +35,9 @@ class TrainingConfig(BaseModel):
optimizer: Literal["adamw", "adamw_8bit", "adamw_cpu", "adafactor"] = Field( optimizer: Literal["adamw", "adamw_8bit", "adamw_cpu", "adafactor"] = Field(
"adamw_8bit", "adamw_8bit",
description="Optimizer to use: 'adamw' (full precision, ~32GB GPU), " description="Optimizer to use: 'adamw' (full precision, ~32GB GPU), "
"'adamw_8bit' (8-bit states, ~8GB GPU, requires bitsandbytes), " "'adamw_8bit' (8-bit states, ~8GB GPU, requires bitsandbytes), "
"'adamw_cpu' (CPU offload, ~0GB GPU, slower), " "'adamw_cpu' (CPU offload, ~0GB GPU, slower), "
"'adafactor' (no momentum, ~8GB GPU)" "'adafactor' (no momentum, ~8GB GPU)",
) )
# === GRPO/PPO Hyperparameters === # === GRPO/PPO Hyperparameters ===
@ -69,12 +69,10 @@ class TrainingConfig(BaseModel):
# === Device & Storage === # === Device & Storage ===
device: str = Field( device: str = Field(
"cuda" if torch.cuda.is_available() else "cpu", "cuda" if torch.cuda.is_available() else "cpu", description="Device to train on"
description="Device to train on"
) )
save_path: str = Field( save_path: str = Field(
"trained_model_checkpoints", "trained_model_checkpoints", description="Base path to save model checkpoints"
description="Base path to save model checkpoints"
) )
checkpoint_interval: int = Field( checkpoint_interval: int = Field(
3, 3,
@ -121,9 +119,7 @@ class TrainingConfig(BaseModel):
trainer_rank: int = Field( trainer_rank: int = Field(
0, description="Rank of this trainer in the distributed group" 0, description="Rank of this trainer in the distributed group"
) )
world_size: int = Field( world_size: int = Field(1, description="Total processes in the distributed group")
1, description="Total processes in the distributed group"
)
init_method: str = Field( init_method: str = Field(
"env://", "env://",
description=( description=(
@ -198,4 +194,3 @@ class TrainingConfig(BaseModel):
"Default is http://localhost:8000. Change for concurrent tests." "Default is http://localhost:8000. Change for concurrent tests."
), ),
) )

View file

@ -92,28 +92,30 @@ def pad_data_to_good_offset(
# Process each sample in the item # Process each sample in the item
for i in range(len(item["tokens"])): for i in range(len(item["tokens"])):
seq_len = len(item["tokens"][i]) seq_len = len(item["tokens"][i])
lengths.append( lengths.append(math.ceil((seq_len - 1) / good_multiple) * good_multiple)
math.ceil((seq_len - 1) / good_multiple) * good_multiple
)
# Create labels with padding (-100 for masked positions) # Create labels with padding (-100 for masked positions)
label_item = np.concatenate([ label_item = np.concatenate(
np.array(item["masks"][i]), [
np.full( np.array(item["masks"][i]),
max(0, token_setup_len - seq_len), np.full(
-100, max(0, token_setup_len - seq_len),
dtype=np.int32, -100,
), dtype=np.int32,
]) ),
]
)
# Pad tokens # Pad tokens
item["tokens"][i] = np.concatenate([ item["tokens"][i] = np.concatenate(
np.array(item["tokens"][i]), [
np.zeros( np.array(item["tokens"][i]),
max(0, token_setup_len - seq_len), np.zeros(
dtype=np.int32, max(0, token_setup_len - seq_len),
), dtype=np.int32,
]) ),
]
)
input_ids.append(item["tokens"][i][:-1]) # Remove last for causal input_ids.append(item["tokens"][i][:-1]) # Remove last for causal
labels.append(label_item[1:]) # Shift by 1 for causal labels.append(label_item[1:]) # Shift by 1 for causal
@ -126,7 +128,9 @@ def pad_data_to_good_offset(
# We just need to pad to match the sequence length # We just need to pad to match the sequence length
if extract_inference_logprobs and "inference_logprobs" in item: if extract_inference_logprobs and "inference_logprobs" in item:
if i < len(item["inference_logprobs"]): if i < len(item["inference_logprobs"]):
raw_logprobs = np.array(item["inference_logprobs"][i], dtype=np.float32) raw_logprobs = np.array(
item["inference_logprobs"][i], dtype=np.float32
)
has_any_logprobs = True has_any_logprobs = True
# Create padded logprobs array matching token_setup_len # Create padded logprobs array matching token_setup_len
@ -141,10 +145,14 @@ def pad_data_to_good_offset(
inference_logprobs_padded.append(padded_logprobs[1:]) inference_logprobs_padded.append(padded_logprobs[1:])
else: else:
# No logprobs for this sample, use 1.0 # No logprobs for this sample, use 1.0
inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32)) inference_logprobs_padded.append(
np.full(token_setup_len - 1, 1.0, dtype=np.float32)
)
elif extract_inference_logprobs: elif extract_inference_logprobs:
# No inference_logprobs in item, use 1.0 # No inference_logprobs in item, use 1.0
inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32)) inference_logprobs_padded.append(
np.full(token_setup_len - 1, 1.0, dtype=np.float32)
)
# Extract temperature (priority: override > generation_params > group_overrides > 1.0) # Extract temperature (priority: override > generation_params > group_overrides > 1.0)
t = 1.0 t = 1.0
@ -155,9 +163,13 @@ def pad_data_to_good_offset(
and ("temperature" in item["overrides"][i]) and ("temperature" in item["overrides"][i])
): ):
t = float(item["overrides"][i]["temperature"]) t = float(item["overrides"][i]["temperature"])
elif item.get("generation_params") and ("temperature" in item["generation_params"]): elif item.get("generation_params") and (
"temperature" in item["generation_params"]
):
t = float(item["generation_params"]["temperature"]) t = float(item["generation_params"]["temperature"])
elif item.get("group_overrides") and ("temperature" in item["group_overrides"]): elif item.get("group_overrides") and (
"temperature" in item["group_overrides"]
):
t = float(item["group_overrides"]["temperature"]) t = float(item["group_overrides"]["temperature"])
temperatures.append(t) temperatures.append(t)
@ -172,19 +184,15 @@ def pad_data_to_good_offset(
start = i * batch_size start = i * batch_size
end = (i + 1) * batch_size end = (i + 1) * batch_size
token_batches.append( token_batches.append(torch.tensor(np.stack(input_ids[start:end], axis=0)))
torch.tensor(np.stack(input_ids[start:end], axis=0)) label_batches.append(torch.tensor(np.stack(labels[start:end], axis=0)))
)
label_batches.append(
torch.tensor(np.stack(labels[start:end], axis=0))
)
advantage_batches.append( advantage_batches.append(
torch.tensor(np.stack(advantages[start:end], axis=0)).view(-1, 1) torch.tensor(np.stack(advantages[start:end], axis=0)).view(-1, 1)
) )
temperature_batches.append( temperature_batches.append(
torch.tensor( torch.tensor(np.array(temperatures[start:end], dtype=np.float32)).view(
np.array(temperatures[start:end], dtype=np.float32) -1, 1, 1
).view(-1, 1, 1) )
) )
# Batch inference logprobs (same shape as labels) # Batch inference logprobs (same shape as labels)
@ -194,9 +202,19 @@ def pad_data_to_good_offset(
) )
# Return inference logprob batches if we have any real logprobs # Return inference logprob batches if we have any real logprobs
final_logprob_batches = inference_logprob_batches if (has_any_logprobs and inference_logprob_batches) else None final_logprob_batches = (
inference_logprob_batches
if (has_any_logprobs and inference_logprob_batches)
else None
)
return token_batches, label_batches, advantage_batches, temperature_batches, final_logprob_batches return (
token_batches,
label_batches,
advantage_batches,
temperature_batches,
final_logprob_batches,
)
def get_data( def get_data(
@ -205,13 +223,15 @@ def get_data(
atropos_url: str = "http://localhost:8000", atropos_url: str = "http://localhost:8000",
extract_inference_logprobs: bool = True, extract_inference_logprobs: bool = True,
) -> Tuple[ ) -> Tuple[
List[Tuple[ List[
List[torch.Tensor], # token_batches Tuple[
List[torch.Tensor], # label_batches List[torch.Tensor], # token_batches
List[torch.Tensor], # advantage_batches List[torch.Tensor], # label_batches
List[torch.Tensor], # temperature_batches List[torch.Tensor], # advantage_batches
Optional[List[torch.Tensor]], # inference_logprob_batches List[torch.Tensor], # temperature_batches
]], Optional[List[torch.Tensor]], # inference_logprob_batches
]
],
None, # Legacy return (no longer used) None, # Legacy return (no longer used)
]: ]:
""" """
@ -241,18 +261,37 @@ def get_data(
if data["batch"] is not None: if data["batch"] is not None:
# DEBUG: Check if inference_logprobs exists in the data # DEBUG: Check if inference_logprobs exists in the data
if not _logged_logprob_warning: if not _logged_logprob_warning:
has_logprobs = any("inference_logprobs" in item for item in data["batch"]) has_logprobs = any(
"inference_logprobs" in item for item in data["batch"]
)
if has_logprobs: if has_logprobs:
# Check if they're non-empty # Check if they're non-empty
sample_item = next((item for item in data["batch"] if "inference_logprobs" in item), None) sample_item = next(
(
item
for item in data["batch"]
if "inference_logprobs" in item
),
None,
)
if sample_item and sample_item.get("inference_logprobs"): if sample_item and sample_item.get("inference_logprobs"):
sample_lp = sample_item["inference_logprobs"][0] if sample_item["inference_logprobs"] else [] sample_lp = (
print(f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})") sample_item["inference_logprobs"][0]
if sample_item["inference_logprobs"]
else []
)
print(
f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})"
)
else: else:
print(" [Data] ⚠ inference_logprobs key exists but is empty!") print(
" [Data] ⚠ inference_logprobs key exists but is empty!"
)
else: else:
print(" [Data] ⚠ NO inference_logprobs in batch data!") print(" [Data] ⚠ NO inference_logprobs in batch data!")
print(f" [Data] Keys in first item: {list(data['batch'][0].keys())}") print(
f" [Data] Keys in first item: {list(data['batch'][0].keys())}"
)
_logged_logprob_warning = True _logged_logprob_warning = True
# Save batch for debugging # Save batch for debugging
@ -260,11 +299,24 @@ def get_data(
json.dump(data, f) json.dump(data, f)
# Process and accumulate batches (now includes batched inference logprobs) # Process and accumulate batches (now includes batched inference logprobs)
token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches = \ (
pad_data_to_good_offset(data, batch_size, extract_inference_logprobs) token_batches,
label_batches,
adv_batches,
temp_batches,
inf_logprob_batches,
) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
# Include inference logprob batches in the tuple # Include inference logprob batches in the tuple
batches.append((token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches)) batches.append(
(
token_batches,
label_batches,
adv_batches,
temp_batches,
inf_logprob_batches,
)
)
elif len(batches) > 0: elif len(batches) > 0:
# Return accumulated batches when no more data # Return accumulated batches when no more data
@ -272,4 +324,3 @@ def get_data(
else: else:
# Wait for data # Wait for data
time.sleep(1) time.sleep(1)

View file

@ -19,8 +19,8 @@ Usage:
--weight-bridge-mode lora_only --lora-r 16 --lora-alpha 32 --weight-bridge-mode lora_only --lora-r 16 --lora-alpha 32
""" """
from .cli import parse_args, config_from_args from .cli import config_from_args, parse_args
from .trainers import train_legacy, train_shared_vllm, train_lora from .trainers import train_legacy, train_lora, train_shared_vllm
def main(): def main():
@ -28,9 +28,9 @@ def main():
args = parse_args() args = parse_args()
config = config_from_args(args) config = config_from_args(args)
print("\n" + "="*60) print("\n" + "=" * 60)
print("GRPO TRAINER") print("GRPO TRAINER")
print("="*60) print("=" * 60)
print(f"Model: {config.model_name}") print(f"Model: {config.model_name}")
print(f"Mode: {config.weight_bridge_mode}") print(f"Mode: {config.weight_bridge_mode}")
print(f"Training steps: {config.training_steps}") print(f"Training steps: {config.training_steps}")

View file

@ -20,6 +20,7 @@ from .config import TrainingConfig
# Import PEFT for LoRA training # Import PEFT for LoRA training
try: try:
from peft import LoraConfig, TaskType, get_peft_model from peft import LoraConfig, TaskType, get_peft_model
PEFT_AVAILABLE = True PEFT_AVAILABLE = True
except ImportError: except ImportError:
PEFT_AVAILABLE = False PEFT_AVAILABLE = False
@ -37,6 +38,7 @@ def _get_attention_implementation() -> str:
""" """
try: try:
import flash_attn # noqa: F401 import flash_attn # noqa: F401
return "flash_attention_2" return "flash_attention_2"
except ImportError: except ImportError:
return "sdpa" return "sdpa"
@ -61,12 +63,19 @@ def _load_model_with_attention(
""" """
# Select the loader function based on mode # Select the loader function based on mode
# from_config: creates empty shell (meta device), from_pretrained: loads weights # from_config: creates empty shell (meta device), from_pretrained: loads weights
loader = AutoModelForCausalLM.from_config if from_config else AutoModelForCausalLM.from_pretrained loader = (
AutoModelForCausalLM.from_config
if from_config
else AutoModelForCausalLM.from_pretrained
)
# Try attention implementations in order of preference # Try attention implementations in order of preference
for attn_impl in ["flash_attention_2", "sdpa"]: for attn_impl in ["flash_attention_2", "sdpa"]:
# Skip flash_attention_2 if not available # Skip flash_attention_2 if not available
if attn_impl == "flash_attention_2" and _get_attention_implementation() != "flash_attention_2": if (
attn_impl == "flash_attention_2"
and _get_attention_implementation() != "flash_attention_2"
):
continue continue
try: try:
@ -86,6 +95,7 @@ def _load_model_with_attention(
# Should never reach here, but just in case # Should never reach here, but just in case
raise RuntimeError("Failed to load model with any attention implementation") raise RuntimeError("Failed to load model with any attention implementation")
def load_model_and_tokenizer( def load_model_and_tokenizer(
config: TrainingConfig, config: TrainingConfig,
single_copy: bool = False, single_copy: bool = False,
@ -178,7 +188,7 @@ def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module:
Returns: Returns:
PEFT model with LoRA adapters applied PEFT model with LoRA adapters applied
""" """
if not PEFT_AVAILABLE: # Yeah no PEFT is needed no matter what bless huggingface if not PEFT_AVAILABLE: # Yeah no PEFT is needed no matter what bless huggingface
raise RuntimeError("PEFT library not available. Install with: pip install peft") raise RuntimeError("PEFT library not available. Install with: pip install peft")
print("[Setup] Loading base model for LoRA mode...") print("[Setup] Loading base model for LoRA mode...")
@ -208,7 +218,9 @@ def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module:
return model return model
def _setup_gradient_checkpointing(model: torch.nn.Module, config: TrainingConfig) -> None: def _setup_gradient_checkpointing(
model: torch.nn.Module, config: TrainingConfig
) -> None:
"""Configure gradient checkpointing for the model.""" """Configure gradient checkpointing for the model."""
# Disable KV cache - incompatible with gradient checkpointing # Disable KV cache - incompatible with gradient checkpointing
model.config.use_cache = False model.config.use_cache = False
@ -297,7 +309,9 @@ def _attach_to_vllm_shared_tensors(
ipc_handles, vllm_to_hf_mapping, config ipc_handles, vllm_to_hf_mapping, config
) )
print(f"[Setup] Attached {attached_count} tensors ({fused_count} from fused layers)") print(
f"[Setup] Attached {attached_count} tensors ({fused_count} from fused layers)"
)
if attached_count == 0: if attached_count == 0:
print("[Setup] Could not attach any tensors, falling back to regular loading") print("[Setup] Could not attach any tensors, falling back to regular loading")
@ -322,6 +336,7 @@ def _attach_to_vllm_shared_tensors(
def _deserialize_ipc_handles(handles_raw: dict) -> dict: def _deserialize_ipc_handles(handles_raw: dict) -> dict:
"""Deserialize base64-encoded bytes in IPC handles.""" """Deserialize base64-encoded bytes in IPC handles."""
def deserialize(handles): def deserialize(handles):
result = {} result = {}
for k, v in handles.items(): for k, v in handles.items():
@ -333,6 +348,7 @@ def _deserialize_ipc_handles(handles_raw: dict) -> dict:
else: else:
result[k] = v result[k] = v
return result return result
return deserialize(handles_raw) return deserialize(handles_raw)
@ -387,8 +403,14 @@ def _reconstruct_shared_tensors(
event_sync_required = ipc_info["event_sync_required"] event_sync_required = ipc_info["event_sync_required"]
share_tuple = ( share_tuple = (
device_index, ipc_handle, storage_size, storage_offset_orig, device_index,
ref_counter_handle, ref_counter_offset, event_handle, event_sync_required, ipc_handle,
storage_size,
storage_offset_orig,
ref_counter_handle,
ref_counter_offset,
event_handle,
event_sync_required,
) )
storage = torch.UntypedStorage._new_shared_cuda(*share_tuple) storage = torch.UntypedStorage._new_shared_cuda(*share_tuple)
@ -424,7 +446,9 @@ def _reconstruct_shared_tensors(
if slice_dim == 0: if slice_dim == 0:
tensor = full_tensor[slice_start:slice_end] tensor = full_tensor[slice_start:slice_end]
else: else:
tensor = full_tensor.narrow(slice_dim, slice_start, slice_end - slice_start) tensor = full_tensor.narrow(
slice_dim, slice_start, slice_end - slice_start
)
tensor.requires_grad_(True) tensor.requires_grad_(True)
hf_state_dict[hf_name] = tensor hf_state_dict[hf_name] = tensor
@ -459,12 +483,16 @@ def _validate_mapping_coverage(
# Note: attached_count may be > param_count because state_dict includes buffers # Note: attached_count may be > param_count because state_dict includes buffers
# while named_parameters only counts trainable params # while named_parameters only counts trainable params
print(f"[Setup] Mapping coverage: {attached_count} tensors for {hf_param_count} parameters " print(
f"(>100% is OK - includes buffers)") f"[Setup] Mapping coverage: {attached_count} tensors for {hf_param_count} parameters "
f"(>100% is OK - includes buffers)"
)
if mapping_coverage < 0.90: if mapping_coverage < 0.90:
unmapped_params = set(model.state_dict().keys()) - set(hf_state_dict.keys()) unmapped_params = set(model.state_dict().keys()) - set(hf_state_dict.keys())
warning_msg = f"[Setup] WARNING: Low mapping coverage ({mapping_coverage:.1%})\n" warning_msg = (
f"[Setup] WARNING: Low mapping coverage ({mapping_coverage:.1%})\n"
)
warning_msg += f"Unmapped parameters ({len(unmapped_params)}):\n" warning_msg += f"Unmapped parameters ({len(unmapped_params)}):\n"
for name in list(unmapped_params)[:20]: for name in list(unmapped_params)[:20]:
warning_msg += f" - {name}\n" warning_msg += f" - {name}\n"
@ -484,11 +512,17 @@ def _initialize_meta_tensors(
config: TrainingConfig, config: TrainingConfig,
) -> None: ) -> None:
"""Initialize any remaining meta tensors after loading.""" """Initialize any remaining meta tensors after loading."""
meta_params = [name for name, p in model.named_parameters() if p.device.type == "meta"] meta_params = [
meta_buffers = [name for name, b in model.named_buffers() if b.device.type == "meta"] name for name, p in model.named_parameters() if p.device.type == "meta"
]
meta_buffers = [
name for name, b in model.named_buffers() if b.device.type == "meta"
]
if config.debug_loading: if config.debug_loading:
print(f"\n[DIAGNOSTIC] Meta params: {len(meta_params)}, Meta buffers: {len(meta_buffers)}") print(
f"\n[DIAGNOSTIC] Meta params: {len(meta_params)}, Meta buffers: {len(meta_buffers)}"
)
def get_parent_and_name(model, full_name): def get_parent_and_name(model, full_name):
parts = full_name.split(".") parts = full_name.split(".")
@ -526,11 +560,15 @@ def _initialize_meta_tensors(
dim = buffer.shape[0] * 2 dim = buffer.shape[0] * 2
# Get rope_theta from model config (default 10000.0 for LLaMA, but Qwen3 uses 5000000!) # Get rope_theta from model config (default 10000.0 for LLaMA, but Qwen3 uses 5000000!)
rope_theta = getattr(model.config, "rope_theta", 10000.0) rope_theta = getattr(model.config, "rope_theta", 10000.0)
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) inv_freq = 1.0 / (
rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
)
new_buffer = inv_freq.to(dtype=buffer.dtype, device=device) new_buffer = inv_freq.to(dtype=buffer.dtype, device=device)
print(f"[Setup] Initialized {name} with rope_theta={rope_theta}") print(f"[Setup] Initialized {name} with rope_theta={rope_theta}")
else: else:
new_buffer = torch.zeros(buffer.shape, dtype=buffer.dtype, device=device) new_buffer = torch.zeros(
buffer.shape, dtype=buffer.dtype, device=device
)
parent, attr_name = get_parent_and_name(model, name) parent, attr_name = get_parent_and_name(model, name)
parent.register_buffer(attr_name, new_buffer) parent.register_buffer(attr_name, new_buffer)
@ -544,8 +582,12 @@ def _initialize_meta_tensors(
def _validate_no_meta_tensors(model: torch.nn.Module) -> None: def _validate_no_meta_tensors(model: torch.nn.Module) -> None:
"""Ensure no parameters or buffers are still on meta device.""" """Ensure no parameters or buffers are still on meta device."""
final_meta_params = [name for name, p in model.named_parameters() if p.device.type == "meta"] final_meta_params = [
final_meta_buffers = [name for name, b in model.named_buffers() if b.device.type == "meta"] name for name, p in model.named_parameters() if p.device.type == "meta"
]
final_meta_buffers = [
name for name, b in model.named_buffers() if b.device.type == "meta"
]
if final_meta_params or final_meta_buffers: if final_meta_params or final_meta_buffers:
error_msg = "[Setup] CRITICAL ERROR: Some tensors are still on meta device!\n" error_msg = "[Setup] CRITICAL ERROR: Some tensors are still on meta device!\n"
@ -588,7 +630,9 @@ def _create_vllm_to_hf_mapping(
model_config = model.config model_config = model.config
hidden_size = getattr(model_config, "hidden_size", 4096) hidden_size = getattr(model_config, "hidden_size", 4096)
num_attention_heads = getattr(model_config, "num_attention_heads", 32) num_attention_heads = getattr(model_config, "num_attention_heads", 32)
num_key_value_heads = getattr(model_config, "num_key_value_heads", num_attention_heads) num_key_value_heads = getattr(
model_config, "num_key_value_heads", num_attention_heads
)
intermediate_size = getattr(model_config, "intermediate_size", hidden_size * 4) intermediate_size = getattr(model_config, "intermediate_size", hidden_size * 4)
# Try to get head_dim from config (some models like Qwen3 have this) # Try to get head_dim from config (some models like Qwen3 have this)
@ -639,8 +683,10 @@ def _create_vllm_to_hf_mapping(
up_size = intermediate_size up_size = intermediate_size
# Always print sizes for debugging weight sharing issues # Always print sizes for debugging weight sharing issues
print(f"[Mapping] Model config: hidden={hidden_size}, heads={num_attention_heads}, " print(
f"kv_heads={num_key_value_heads}, head_dim={head_dim}") f"[Mapping] Model config: hidden={hidden_size}, heads={num_attention_heads}, "
f"kv_heads={num_key_value_heads}, head_dim={head_dim}"
)
print(f"[Mapping] QKV sizes from HF model: q={q_size}, k={k_size}, v={v_size}") print(f"[Mapping] QKV sizes from HF model: q={q_size}, k={k_size}, v={v_size}")
print(f"[Mapping] Gate/Up sizes from HF model: gate={gate_size}, up={up_size}") print(f"[Mapping] Gate/Up sizes from HF model: gate={gate_size}, up={up_size}")
@ -714,7 +760,8 @@ def _create_vllm_to_hf_mapping(
if debug: if debug:
direct = sum(1 for v in mapping.values() if isinstance(v, str)) direct = sum(1 for v in mapping.values() if isinstance(v, str))
fused = sum(1 for v in mapping.values() if isinstance(v, dict)) fused = sum(1 for v in mapping.values() if isinstance(v, dict))
print(f"[Mapping] Total: {len(mapping)} mapped ({direct} direct, {fused} fused)") print(
f"[Mapping] Total: {len(mapping)} mapped ({direct} direct, {fused} fused)"
)
return mapping return mapping

View file

@ -60,10 +60,13 @@ def wait_for_bridge_config(config_path: str, timeout: int = 60) -> bool:
if os.path.exists(config_path): if os.path.exists(config_path):
try: try:
import json import json
with open(config_path, 'r') as f:
with open(config_path, "r") as f:
config = json.load(f) config = json.load(f)
if config.get('ipc_handles') and len(config['ipc_handles']) > 0: if config.get("ipc_handles") and len(config["ipc_handles"]) > 0:
print(f"[Run] ✓ Bridge config ready with {len(config['ipc_handles'])} IPC handles") print(
f"[Run] ✓ Bridge config ready with {len(config['ipc_handles'])} IPC handles"
)
return True return True
except Exception: except Exception:
pass pass
@ -79,7 +82,7 @@ def main():
args = parser.parse_args() args = parser.parse_args()
# Create log directory # Create log directory
log_dir = getattr(args, 'log_dir', './logs') log_dir = getattr(args, "log_dir", "./logs")
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
# Bridge config path # Bridge config path
@ -91,16 +94,16 @@ def main():
print("[Run] Removed old bridge config") print("[Run] Removed old bridge config")
# === Print Configuration === # === Print Configuration ===
print("\n" + "="*60) print("\n" + "=" * 60)
print("STARTING UNIFIED GRPO TRAINER (shared_vllm mode)") print("STARTING UNIFIED GRPO TRAINER (shared_vllm mode)")
print("="*60) print("=" * 60)
print(f"Model: {args.model_name}") print(f"Model: {args.model_name}")
print(f"vLLM port: {args.vllm_port}") print(f"vLLM port: {args.vllm_port}")
print(f"GPU memory utilization: {args.gpu_memory_utilization}") print(f"GPU memory utilization: {args.gpu_memory_utilization}")
print(f"Training steps: {args.training_steps}") print(f"Training steps: {args.training_steps}")
print(f"Optimizer: {args.optimizer}") print(f"Optimizer: {args.optimizer}")
print(f"GRPO: kl_coef={args.kl_coef}, clip_eps={args.clip_eps}") print(f"GRPO: kl_coef={args.kl_coef}, clip_eps={args.clip_eps}")
print("="*60 + "\n") print("=" * 60 + "\n")
# Get the path to vllm_api_server.py # Get the path to vllm_api_server.py
script_dir = Path(__file__).parent script_dir = Path(__file__).parent
@ -126,12 +129,19 @@ def main():
# Build vLLM command # Build vLLM command
vllm_cmd = [ vllm_cmd = [
sys.executable, "-u", str(vllm_server_script), sys.executable,
"--model", args.model_name, "-u",
"--port", str(args.vllm_port), str(vllm_server_script),
"--dtype", args.dtype, "--model",
"--gpu-memory-utilization", str(args.gpu_memory_utilization), args.model_name,
"--max-model-len", str(args.max_model_len), "--port",
str(args.vllm_port),
"--dtype",
args.dtype,
"--gpu-memory-utilization",
str(args.gpu_memory_utilization),
"--max-model-len",
str(args.max_model_len),
"--enforce-eager", # Required for shared weights "--enforce-eager", # Required for shared weights
] ]
@ -212,6 +222,7 @@ def main():
except Exception as e: except Exception as e:
print(f"\n[Run] ✗ Training failed: {e}") print(f"\n[Run] ✗ Training failed: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
sys.exit(1) sys.exit(1)

View file

@ -138,4 +138,3 @@ if [ -d "$LOG_DIR/checkpoints" ]; then
}' | jq '.text[0]' | tee "${LOG_DIR}/trained_response.txt" }' | jq '.text[0]' | tee "${LOG_DIR}/trained_response.txt"
fi fi
fi fi

View file

@ -143,4 +143,3 @@ curl -s -X POST "http://localhost:${VLLM_PORT}/generate" \
"max_tokens": 100, "max_tokens": 100,
"temperature": 0.1 "temperature": 0.1
}' | jq '.text[0]' | tee "${LOG_DIR}/trained_response.txt" }' | jq '.text[0]' | tee "${LOG_DIR}/trained_response.txt"

View file

@ -25,7 +25,10 @@ class CPUOffloadAdamW(torch.optim.Optimizer):
Full precision (no quantization), but states stay on CPU RAM instead of GPU. Full precision (no quantization), but states stay on CPU RAM instead of GPU.
Trade-off: Slower (~2x) but uses ~0GB GPU memory for optimizer states. Trade-off: Slower (~2x) but uses ~0GB GPU memory for optimizer states.
""" """
def __init__(self, params, lr=1e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
def __init__(
self, params, lr=1e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults) super().__init__(params, defaults)
@ -33,10 +36,10 @@ class CPUOffloadAdamW(torch.optim.Optimizer):
"""Lazily initialize state on CPU.""" """Lazily initialize state on CPU."""
state = self.state[p] state = self.state[p]
if len(state) == 0: if len(state) == 0:
state['step'] = 0 state["step"] = 0
# Store on CPU in FP32 # Store on CPU in FP32
state['exp_avg'] = torch.zeros_like(p, device='cpu', dtype=torch.float32) state["exp_avg"] = torch.zeros_like(p, device="cpu", dtype=torch.float32)
state['exp_avg_sq'] = torch.zeros_like(p, device='cpu', dtype=torch.float32) state["exp_avg_sq"] = torch.zeros_like(p, device="cpu", dtype=torch.float32)
return state return state
@torch.no_grad() @torch.no_grad()
@ -47,41 +50,41 @@ class CPUOffloadAdamW(torch.optim.Optimizer):
loss = closure() loss = closure()
for group in self.param_groups: for group in self.param_groups:
beta1, beta2 = group['betas'] beta1, beta2 = group["betas"]
for p in group['params']: for p in group["params"]:
if p.grad is None: if p.grad is None:
continue continue
grad = p.grad grad = p.grad
state = self._init_state(p) state = self._init_state(p)
state['step'] += 1 state["step"] += 1
# Move states to GPU for computation # Move states to GPU for computation
exp_avg = state['exp_avg'].to(p.device) exp_avg = state["exp_avg"].to(p.device)
exp_avg_sq = state['exp_avg_sq'].to(p.device) exp_avg_sq = state["exp_avg_sq"].to(p.device)
# AdamW update # AdamW update
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
# Bias correction # Bias correction
bias_correction1 = 1 - beta1 ** state['step'] bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state['step'] bias_correction2 = 1 - beta2 ** state["step"]
step_size = group['lr'] / bias_correction1 step_size = group["lr"] / bias_correction1
# Update weights # Update weights
denom = (exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)).add_(group['eps']) denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5)).add_(group["eps"])
p.addcdiv_(exp_avg, denom, value=-step_size) p.addcdiv_(exp_avg, denom, value=-step_size)
# Weight decay # Weight decay
if group['weight_decay'] != 0: if group["weight_decay"] != 0:
p.add_(p, alpha=-group['lr'] * group['weight_decay']) p.add_(p, alpha=-group["lr"] * group["weight_decay"])
# Move states back to CPU (non-blocking for better perf) # Move states back to CPU (non-blocking for better perf)
state['exp_avg'].copy_(exp_avg.cpu()) state["exp_avg"].copy_(exp_avg.cpu())
state['exp_avg_sq'].copy_(exp_avg_sq.cpu()) state["exp_avg_sq"].copy_(exp_avg_sq.cpu())
return loss return loss
@ -99,6 +102,7 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
if config.optimizer == "adamw_8bit": if config.optimizer == "adamw_8bit":
try: try:
import bitsandbytes as bnb import bitsandbytes as bnb
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.lr) optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.lr)
print("[Setup] Using 8-bit AdamW (saves ~24GB optimizer memory)") print("[Setup] Using 8-bit AdamW (saves ~24GB optimizer memory)")
return optimizer return optimizer
@ -108,13 +112,18 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
if config.optimizer == "adamw_cpu": if config.optimizer == "adamw_cpu":
optimizer = CPUOffloadAdamW(model.parameters(), lr=config.lr) optimizer = CPUOffloadAdamW(model.parameters(), lr=config.lr)
print("[Setup] Using AdamW with CPU offload (full precision, ~0GB GPU for states)") print(
print("[Setup] NOTE: ~2x slower due to CPU<->GPU transfers, but no quantization") "[Setup] Using AdamW with CPU offload (full precision, ~0GB GPU for states)"
)
print(
"[Setup] NOTE: ~2x slower due to CPU<->GPU transfers, but no quantization"
)
return optimizer return optimizer
if config.optimizer == "adafactor": if config.optimizer == "adafactor":
try: try:
from transformers.optimization import Adafactor from transformers.optimization import Adafactor
optimizer = Adafactor( optimizer = Adafactor(
model.parameters(), model.parameters(),
lr=config.lr, lr=config.lr,
@ -135,7 +144,7 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
from .checkpointing import save_checkpoint, save_lora_checkpoint # noqa: E402 from .checkpointing import save_checkpoint, save_lora_checkpoint # noqa: E402
from .config import TrainingConfig # noqa: E402 from .config import TrainingConfig # noqa: E402
from .data import get_data # noqa: E402 from .data import get_data # noqa: E402
from .model import load_model_and_tokenizer, PEFT_AVAILABLE # noqa: E402 from .model import PEFT_AVAILABLE, load_model_and_tokenizer # noqa: E402
from .training import ( # noqa: E402 from .training import ( # noqa: E402
finalize_training, finalize_training,
log_metrics, log_metrics,
@ -146,8 +155,8 @@ from .vllm_manager import ( # noqa: E402
check_vllm_health, check_vllm_health,
check_vllm_process_health, check_vllm_process_health,
launch_vllm_server, launch_vllm_server,
terminate_vllm_process,
set_vllm_process, set_vllm_process,
terminate_vllm_process,
) )
@ -171,13 +180,13 @@ def train_legacy(config: TrainingConfig):
model, tokenizer = load_model_and_tokenizer(config) model, tokenizer = load_model_and_tokenizer(config)
optimizer = create_optimizer(model, config) optimizer = create_optimizer(model, config)
print("\n" + "="*60) print("\n" + "=" * 60)
print("LEGACY MODE (checkpoint + vLLM restart)") print("LEGACY MODE (checkpoint + vLLM restart)")
print("="*60) print("=" * 60)
print(f"Training for {config.training_steps} steps on {config.device}") print(f"Training for {config.training_steps} steps on {config.device}")
print(f"vLLM restart interval: every {config.vllm_restart_interval} steps") print(f"vLLM restart interval: every {config.vllm_restart_interval} steps")
print(f"Save path: {config.save_path}") print(f"Save path: {config.save_path}")
print("="*60 + "\n") print("=" * 60 + "\n")
os.makedirs(config.save_path, exist_ok=True) os.makedirs(config.save_path, exist_ok=True)
@ -206,24 +215,36 @@ def train_legacy(config: TrainingConfig):
# Fetch data (with inference logprobs for proper GRPO) # Fetch data (with inference logprobs for proper GRPO)
data_fetch_start = time.time() data_fetch_start = time.time()
if len(batches) == 0: if len(batches) == 0:
batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url, batches, _ = get_data(
extract_inference_logprobs=True) config.batch_size,
config.seq_len,
config.atropos_url,
extract_inference_logprobs=True,
)
batch_data = batches.pop(0) batch_data = batches.pop(0)
token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4] token_batches, label_batches, advantage_batches, temperature_batches = (
batch_data[:4]
)
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
data_fetch_time = time.time() - data_fetch_start data_fetch_time = time.time() - data_fetch_start
benchmark_stats["data_fetch_times"].append(data_fetch_time) benchmark_stats["data_fetch_times"].append(data_fetch_time)
# Check if we should sync (save checkpoint + restart vLLM) # Check if we should sync (save checkpoint + restart vLLM)
should_sync = (step + 1) % config.vllm_restart_interval == 0 or step == config.training_steps - 1 should_sync = (
step + 1
) % config.vllm_restart_interval == 0 or step == config.training_steps - 1
if should_sync: if should_sync:
terminate_vllm_process() terminate_vllm_process()
# Training step (with proper GRPO using inference logprobs) # Training step (with proper GRPO using inference logprobs)
step_start = time.time() step_start = time.time()
metrics = run_training_step( metrics = run_training_step(
model, optimizer, model,
token_batches, label_batches, advantage_batches, temperature_batches, optimizer,
token_batches,
label_batches,
advantage_batches,
temperature_batches,
config, config,
inference_logprob_batches=inference_logprob_batches, inference_logprob_batches=inference_logprob_batches,
) )
@ -231,15 +252,21 @@ def train_legacy(config: TrainingConfig):
benchmark_stats["step_times"].append(step_time) benchmark_stats["step_times"].append(step_time)
# GPU memory tracking # GPU memory tracking
gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0 gpu_mem_gb = (
gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
)
gpu_mem_reserved_gb = (
torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
)
benchmark_stats["gpu_memories"].append(gpu_mem_gb) benchmark_stats["gpu_memories"].append(gpu_mem_gb)
# Sync (checkpoint + restart) # Sync (checkpoint + restart)
sync_time = 0 sync_time = 0
if should_sync: if should_sync:
sync_start = time.time() sync_start = time.time()
checkpoint_path = save_checkpoint(model, tokenizer, config.save_path, step + 1) checkpoint_path = save_checkpoint(
model, tokenizer, config.save_path, step + 1
)
torch.cuda.empty_cache() torch.cuda.empty_cache()
vllm_proc = launch_vllm_server(config, checkpoint_path) vllm_proc = launch_vllm_server(config, checkpoint_path)
set_vllm_process(vllm_proc) set_vllm_process(vllm_proc)
@ -247,20 +274,31 @@ def train_legacy(config: TrainingConfig):
benchmark_stats["sync_times"].append(sync_time) benchmark_stats["sync_times"].append(sync_time)
# Update metrics # Update metrics
metrics.update({ metrics.update(
"step_time": step_time, {
"sync_time": sync_time, "step_time": step_time,
"data_fetch_time": data_fetch_time, "sync_time": sync_time,
"gpu_memory_gb": gpu_mem_gb, "data_fetch_time": data_fetch_time,
"gpu_memory_reserved_gb": gpu_mem_reserved_gb, "gpu_memory_gb": gpu_mem_gb,
}) "gpu_memory_reserved_gb": gpu_mem_reserved_gb,
}
)
log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark) log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)
check_vllm_process_health() check_vllm_process_health()
# === Cleanup === # === Cleanup ===
save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True) save_checkpoint(
finalize_training(use_wandb, training_start_time, "legacy", config.training_steps, benchmark_stats, config.benchmark) model, tokenizer, config.save_path, config.training_steps, is_final=True
)
finalize_training(
use_wandb,
training_start_time,
"legacy",
config.training_steps,
benchmark_stats,
config.benchmark,
)
def train_shared_vllm(config: TrainingConfig): def train_shared_vllm(config: TrainingConfig):
@ -281,13 +319,13 @@ def train_shared_vllm(config: TrainingConfig):
# === Setup === # === Setup ===
use_wandb = setup_wandb(config) use_wandb = setup_wandb(config)
print("\n" + "="*60) print("\n" + "=" * 60)
print("SINGLE-COPY MODE (CUDA IPC)") print("SINGLE-COPY MODE (CUDA IPC)")
print(">>> Trainer uses vLLM's tensors directly!") print(">>> Trainer uses vLLM's tensors directly!")
print("="*60) print("=" * 60)
print(f"Model: {config.model_name}") print(f"Model: {config.model_name}")
print(f"Save path: {config.save_path}") print(f"Save path: {config.save_path}")
print("="*60 + "\n") print("=" * 60 + "\n")
# Attach to vLLM's shared tensors # Attach to vLLM's shared tensors
print("[1/2] Attaching to vLLM's shared tensors...") print("[1/2] Attaching to vLLM's shared tensors...")
@ -331,11 +369,15 @@ def train_shared_vllm(config: TrainingConfig):
data_fetch_start = time.time() data_fetch_start = time.time()
if len(batches) == 0: if len(batches) == 0:
batches, _ = get_data( batches, _ = get_data(
config.batch_size, config.seq_len, config.atropos_url, config.batch_size,
config.seq_len,
config.atropos_url,
extract_inference_logprobs=True, # Enable proper GRPO with reference logprobs extract_inference_logprobs=True, # Enable proper GRPO with reference logprobs
) )
batch_data = batches.pop(0) batch_data = batches.pop(0)
token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4] token_batches, label_batches, advantage_batches, temperature_batches = (
batch_data[:4]
)
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
data_fetch_time = time.time() - data_fetch_start data_fetch_time = time.time() - data_fetch_start
benchmark_stats["data_fetch_times"].append(data_fetch_time) benchmark_stats["data_fetch_times"].append(data_fetch_time)
@ -343,8 +385,12 @@ def train_shared_vllm(config: TrainingConfig):
# Training step with proper GRPO (importance sampling + KL penalty) # Training step with proper GRPO (importance sampling + KL penalty)
step_start = time.time() step_start = time.time()
metrics = run_training_step( metrics = run_training_step(
model, optimizer, model,
token_batches, label_batches, advantage_batches, temperature_batches, optimizer,
token_batches,
label_batches,
advantage_batches,
temperature_batches,
config, config,
inference_logprob_batches=inference_logprob_batches, # Pass for GRPO ratio computation inference_logprob_batches=inference_logprob_batches, # Pass for GRPO ratio computation
) )
@ -352,8 +398,12 @@ def train_shared_vllm(config: TrainingConfig):
benchmark_stats["step_times"].append(step_time) benchmark_stats["step_times"].append(step_time)
# GPU memory tracking # GPU memory tracking
gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0 gpu_mem_gb = (
gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
)
gpu_mem_reserved_gb = (
torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
)
benchmark_stats["gpu_memories"].append(gpu_mem_gb) benchmark_stats["gpu_memories"].append(gpu_mem_gb)
# In single-copy mode, weights are updated in-place (no sync needed!) # In single-copy mode, weights are updated in-place (no sync needed!)
@ -362,23 +412,37 @@ def train_shared_vllm(config: TrainingConfig):
benchmark_stats["sync_times"].append(sync_time) benchmark_stats["sync_times"].append(sync_time)
# Update metrics # Update metrics
metrics.update({ metrics.update(
"step_time": step_time, {
"sync_time": sync_time, "step_time": step_time,
"data_fetch_time": data_fetch_time, "sync_time": sync_time,
"gpu_memory_gb": gpu_mem_gb, "data_fetch_time": data_fetch_time,
"gpu_memory_reserved_gb": gpu_mem_reserved_gb, "gpu_memory_gb": gpu_mem_gb,
}) "gpu_memory_reserved_gb": gpu_mem_reserved_gb,
}
)
log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark) log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)
# Periodic checkpoint (for recovery, not for vLLM sync) # Periodic checkpoint (for recovery, not for vLLM sync)
if config.checkpoint_interval > 0 and (step + 1) % config.checkpoint_interval == 0: if (
config.checkpoint_interval > 0
and (step + 1) % config.checkpoint_interval == 0
):
save_checkpoint(model, tokenizer, config.save_path, step + 1) save_checkpoint(model, tokenizer, config.save_path, step + 1)
# === Cleanup === # === Cleanup ===
save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True) save_checkpoint(
finalize_training(use_wandb, training_start_time, "shared_vllm", config.training_steps, benchmark_stats, config.benchmark) model, tokenizer, config.save_path, config.training_steps, is_final=True
)
finalize_training(
use_wandb,
training_start_time,
"shared_vllm",
config.training_steps,
benchmark_stats,
config.benchmark,
)
def train_lora(config: TrainingConfig): def train_lora(config: TrainingConfig):
@ -399,29 +463,33 @@ def train_lora(config: TrainingConfig):
- External vLLM server running with --enable-lora - External vLLM server running with --enable-lora
""" """
if not PEFT_AVAILABLE: if not PEFT_AVAILABLE:
raise RuntimeError("PEFT library required for LoRA mode. Install with: pip install peft") raise RuntimeError(
"PEFT library required for LoRA mode. Install with: pip install peft"
)
training_start_time = time.time() training_start_time = time.time()
# === Setup === # === Setup ===
use_wandb = setup_wandb(config) use_wandb = setup_wandb(config)
print("\n" + "="*60) print("\n" + "=" * 60)
print("LORA MODE (adapter-only training)") print("LORA MODE (adapter-only training)")
print("="*60) print("=" * 60)
print(f"Base model: {config.model_name}") print(f"Base model: {config.model_name}")
print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}") print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}")
print(f"Save path: {config.save_path}") print(f"Save path: {config.save_path}")
print(f"vLLM port: {config.vllm_port}") print(f"vLLM port: {config.vllm_port}")
print("="*60 + "\n") print("=" * 60 + "\n")
# Check external vLLM server # Check external vLLM server
print("[1/3] Checking external vLLM server...") print("[1/3] Checking external vLLM server...")
if not check_vllm_health(config.vllm_port): if not check_vllm_health(config.vllm_port):
print(f"\nERROR: vLLM server not running on port {config.vllm_port}") print(f"\nERROR: vLLM server not running on port {config.vllm_port}")
print("\nLoRA mode requires an external vLLM server. Start it first:") print("\nLoRA mode requires an external vLLM server. Start it first:")
print(f" python example_trainer/vllm_api_server.py --model {config.model_name} " print(
f"--port {config.vllm_port} --enable-lora --enforce-eager") f" python example_trainer/vllm_api_server.py --model {config.model_name} "
f"--port {config.vllm_port} --enable-lora --enforce-eager"
)
raise RuntimeError(f"External vLLM server required on port {config.vllm_port}") raise RuntimeError(f"External vLLM server required on port {config.vllm_port}")
print(f"vLLM server healthy on port {config.vllm_port}") print(f"vLLM server healthy on port {config.vllm_port}")
@ -459,10 +527,16 @@ def train_lora(config: TrainingConfig):
# Fetch data (with inference logprobs for proper GRPO) # Fetch data (with inference logprobs for proper GRPO)
data_fetch_start = time.time() data_fetch_start = time.time()
if len(batches) == 0: if len(batches) == 0:
batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url, batches, _ = get_data(
extract_inference_logprobs=True) config.batch_size,
config.seq_len,
config.atropos_url,
extract_inference_logprobs=True,
)
batch_data = batches.pop(0) batch_data = batches.pop(0)
token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4] token_batches, label_batches, advantage_batches, temperature_batches = (
batch_data[:4]
)
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
data_fetch_time = time.time() - data_fetch_start data_fetch_time = time.time() - data_fetch_start
benchmark_stats["data_fetch_times"].append(data_fetch_time) benchmark_stats["data_fetch_times"].append(data_fetch_time)
@ -470,8 +544,12 @@ def train_lora(config: TrainingConfig):
# Training step with proper GRPO # Training step with proper GRPO
step_start = time.time() step_start = time.time()
metrics = run_training_step( metrics = run_training_step(
model, optimizer, model,
token_batches, label_batches, advantage_batches, temperature_batches, optimizer,
token_batches,
label_batches,
advantage_batches,
temperature_batches,
config, config,
inference_logprob_batches=inference_logprob_batches, inference_logprob_batches=inference_logprob_batches,
) )
@ -479,8 +557,12 @@ def train_lora(config: TrainingConfig):
benchmark_stats["step_times"].append(step_time) benchmark_stats["step_times"].append(step_time)
# GPU memory tracking # GPU memory tracking
gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0 gpu_mem_gb = (
gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
)
gpu_mem_reserved_gb = (
torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
)
benchmark_stats["gpu_memories"].append(gpu_mem_gb) benchmark_stats["gpu_memories"].append(gpu_mem_gb)
# Periodic adapter save + hot-swap # Periodic adapter save + hot-swap
@ -494,24 +576,35 @@ def train_lora(config: TrainingConfig):
benchmark_stats["sync_times"].append(sync_time) benchmark_stats["sync_times"].append(sync_time)
# Update metrics # Update metrics
metrics.update({ metrics.update(
"step_time": step_time, {
"sync_time": sync_time, "step_time": step_time,
"data_fetch_time": data_fetch_time, "sync_time": sync_time,
"gpu_memory_gb": gpu_mem_gb, "data_fetch_time": data_fetch_time,
"gpu_memory_reserved_gb": gpu_mem_reserved_gb, "gpu_memory_gb": gpu_mem_gb,
}) "gpu_memory_reserved_gb": gpu_mem_reserved_gb,
}
)
log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark) log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)
# === Cleanup === # === Cleanup ===
final_sync_start = time.time() final_sync_start = time.time()
final_adapter_path = save_lora_checkpoint(model, config.save_path, config.training_steps, is_final=True) final_adapter_path = save_lora_checkpoint(
model, config.save_path, config.training_steps, is_final=True
)
_hotswap_lora_adapter(config.vllm_port, final_adapter_path, "final") _hotswap_lora_adapter(config.vllm_port, final_adapter_path, "final")
final_sync_time = time.time() - final_sync_start final_sync_time = time.time() - final_sync_start
benchmark_stats["sync_times"].append(final_sync_time) benchmark_stats["sync_times"].append(final_sync_time)
finalize_training(use_wandb, training_start_time, "lora_only", config.training_steps, benchmark_stats, config.benchmark) finalize_training(
use_wandb,
training_start_time,
"lora_only",
config.training_steps,
benchmark_stats,
config.benchmark,
)
# Save tokenizer # Save tokenizer
tokenizer_path = os.path.join(config.save_path, "tokenizer") tokenizer_path = os.path.join(config.save_path, "tokenizer")
@ -563,4 +656,3 @@ def _hotswap_lora_adapter(
except Exception as e: except Exception as e:
print(f" [LORA] ✗ Hot-swap request failed: {e}") print(f" [LORA] ✗ Hot-swap request failed: {e}")
return False return False

View file

@ -18,10 +18,10 @@ import wandb
from .config import TrainingConfig from .config import TrainingConfig
# Global storage for logprob alignment stats # Global storage for logprob alignment stats
_logprob_alignment_stats: Dict[str, float] = {} _logprob_alignment_stats: Dict[str, float] = {}
def setup_wandb(config: TrainingConfig) -> bool: def setup_wandb(config: TrainingConfig) -> bool:
""" """
Initialize Weights & Biases logging if enabled. Initialize Weights & Biases logging if enabled.
@ -134,7 +134,9 @@ def compute_grpo_loss(
# === GRPO/PPO Loss Computation === # === GRPO/PPO Loss Computation ===
if use_reference_logprobs and inference_logprobs is not None: if use_reference_logprobs and inference_logprobs is not None:
# Move inference logprobs to correct device/dtype # Move inference logprobs to correct device/dtype
ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype) ref_logprobs = inference_logprobs.to(
logp_per_token.device, logp_per_token.dtype
)
# NOTE: inference_logprobs uses 1.0 for masked (prompt) positions, actual negative values for generated # NOTE: inference_logprobs uses 1.0 for masked (prompt) positions, actual negative values for generated
with torch.no_grad(): with torch.no_grad():
@ -156,10 +158,16 @@ def compute_grpo_loss(
# Check if ref logprobs are negative (as they should be for generated tokens) # Check if ref logprobs are negative (as they should be for generated tokens)
# If ref_at_generated is close to 1.0, that means the 1.0 placeholder is being used # If ref_at_generated is close to 1.0, that means the 1.0 placeholder is being used
if ref_at_generated > 0.5: if ref_at_generated > 0.5:
print(f" [WARNING] ref_logprobs avg {ref_at_generated:.3f} (should be negative!)") print(
print(" [WARNING] This suggests inference_logprobs alignment is wrong") f" [WARNING] ref_logprobs avg {ref_at_generated:.3f} (should be negative!)"
)
print(
" [WARNING] This suggests inference_logprobs alignment is wrong"
)
elif abs(ref_at_generated - train_at_generated) > 2.0: elif abs(ref_at_generated - train_at_generated) > 2.0:
print(f" [DEBUG] Logprob gap: ref={ref_at_generated:.3f}, train={train_at_generated:.3f}") print(
f" [DEBUG] Logprob gap: ref={ref_at_generated:.3f}, train={train_at_generated:.3f}"
)
# Compute importance sampling ratio: policy(a|s) / policy_old(a|s) = exp(log policy - log policy_old) # Compute importance sampling ratio: policy(a|s) / policy_old(a|s) = exp(log policy - log policy_old)
log_ratio = logp_per_token - ref_logprobs log_ratio = logp_per_token - ref_logprobs
@ -192,7 +200,9 @@ def compute_grpo_loss(
# = exp(-log_ratio) + log_ratio - 1 # = exp(-log_ratio) + log_ratio - 1
kl_per_token = torch.exp(-log_ratio) + log_ratio - 1.0 kl_per_token = torch.exp(-log_ratio) + log_ratio - 1.0
kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean() kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean()
total_loss = (policy_loss + kl_coef * kl_penalty) / gradient_accumulation_steps total_loss = (
policy_loss + kl_coef * kl_penalty
) / gradient_accumulation_steps
else: else:
kl_penalty = torch.tensor(0.0, device=logp_per_token.device) kl_penalty = torch.tensor(0.0, device=logp_per_token.device)
total_loss = policy_loss / gradient_accumulation_steps total_loss = policy_loss / gradient_accumulation_steps
@ -200,7 +210,9 @@ def compute_grpo_loss(
# Compute metrics for logging # Compute metrics for logging
with torch.no_grad(): with torch.no_grad():
# Fraction of tokens where ratio was clipped # Fraction of tokens where ratio was clipped
clipped_fraction = ((ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)).float() clipped_fraction = (
(ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)
).float()
clipped_fraction = (clipped_fraction * mask).sum() / mask.sum() clipped_fraction = (clipped_fraction * mask).sum() / mask.sum()
# Mean ratio and KL for monitoring (using Schulman's estimator) # Mean ratio and KL for monitoring (using Schulman's estimator)
@ -256,7 +268,11 @@ def compute_grpo_loss(
"kl_penalty": kl_penalty.item() if torch.is_tensor(kl_penalty) else kl_penalty, "kl_penalty": kl_penalty.item() if torch.is_tensor(kl_penalty) else kl_penalty,
"mean_ratio": mean_ratio.item() if torch.is_tensor(mean_ratio) else mean_ratio, "mean_ratio": mean_ratio.item() if torch.is_tensor(mean_ratio) else mean_ratio,
"mean_kl": mean_kl.item() if torch.is_tensor(mean_kl) else mean_kl, "mean_kl": mean_kl.item() if torch.is_tensor(mean_kl) else mean_kl,
"clipped_fraction": clipped_fraction.item() if torch.is_tensor(clipped_fraction) else clipped_fraction, "clipped_fraction": (
clipped_fraction.item()
if torch.is_tensor(clipped_fraction)
else clipped_fraction
),
# Token-level alignment metrics (key for verifying weight sharing) # Token-level alignment metrics (key for verifying weight sharing)
"logprob_diff_mean": logprob_diff_mean, "logprob_diff_mean": logprob_diff_mean,
"logprob_diff_abs_mean": logprob_diff_abs_mean, "logprob_diff_abs_mean": logprob_diff_abs_mean,
@ -315,23 +331,25 @@ def run_training_step(
all_inference_logprobs: List[torch.Tensor] = [] all_inference_logprobs: List[torch.Tensor] = []
# Get GRPO hyperparameters from config # Get GRPO hyperparameters from config
kl_coef = getattr(config, 'kl_coef', 0.1) kl_coef = getattr(config, "kl_coef", 0.1)
clip_eps = getattr(config, 'clip_eps', 0.2) clip_eps = getattr(config, "clip_eps", 0.2)
use_reference_logprobs = getattr(config, 'use_reference_logprobs', True) use_reference_logprobs = getattr(config, "use_reference_logprobs", True)
# Accumulate gradients over micro-batches # Accumulate gradients over micro-batches
num_batches = len(token_batches) if token_batches else 1 num_batches = len(token_batches) if token_batches else 1
for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(zip( for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(
token_batches, label_batches, advantage_batches, temperature_batches zip(token_batches, label_batches, advantage_batches, temperature_batches)
)): ):
tokens = tokens.to(config.device) tokens = tokens.to(config.device)
labels = labels.to(config.device) labels = labels.to(config.device)
advantages = advantages.to(config.device) advantages = advantages.to(config.device)
# Get corresponding inference logprobs batch if available # Get corresponding inference logprobs batch if available
inf_logprobs = None inf_logprobs = None
if inference_logprob_batches is not None and batch_idx < len(inference_logprob_batches): if inference_logprob_batches is not None and batch_idx < len(
inference_logprob_batches
):
inf_logprobs = inference_logprob_batches[batch_idx] inf_logprobs = inference_logprob_batches[batch_idx]
loss, metrics = compute_grpo_loss( loss, metrics = compute_grpo_loss(
@ -363,12 +381,17 @@ def run_training_step(
# Accumulate token-level alignment metrics # Accumulate token-level alignment metrics
total_logprob_diff_mean += metrics.get("logprob_diff_mean", 0.0) total_logprob_diff_mean += metrics.get("logprob_diff_mean", 0.0)
total_logprob_diff_abs_mean += metrics.get("logprob_diff_abs_mean", 0.0) total_logprob_diff_abs_mean += metrics.get("logprob_diff_abs_mean", 0.0)
total_logprob_diff_max = max(total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0)) total_logprob_diff_max = max(
total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0)
)
# Collect logprobs for alignment monitoring # Collect logprobs for alignment monitoring
if "training_logprobs" in metrics and metrics["training_logprobs"] is not None: if "training_logprobs" in metrics and metrics["training_logprobs"] is not None:
all_training_logprobs.append(metrics["training_logprobs"]) all_training_logprobs.append(metrics["training_logprobs"])
if "inference_logprobs" in metrics and metrics["inference_logprobs"] is not None: if (
"inference_logprobs" in metrics
and metrics["inference_logprobs"] is not None
):
all_inference_logprobs.append(metrics["inference_logprobs"]) all_inference_logprobs.append(metrics["inference_logprobs"])
# Gradient clipping and optimizer step # Gradient clipping and optimizer step
@ -387,7 +410,7 @@ def run_training_step(
result = { result = {
"loss": total_loss, "loss": total_loss,
"grad_norm": grad_norm.item() if hasattr(grad_norm, 'item') else grad_norm, "grad_norm": grad_norm.item() if hasattr(grad_norm, "item") else grad_norm,
"pos_logp": total_pos_logp, "pos_logp": total_pos_logp,
"neg_logp": total_neg_logp, "neg_logp": total_neg_logp,
"pos_count": total_pos, "pos_count": total_pos,
@ -404,7 +427,9 @@ def run_training_step(
if all_training_logprobs: if all_training_logprobs:
train_flat = torch.cat(all_training_logprobs) train_flat = torch.cat(all_training_logprobs)
if train_flat.numel() > 0: if train_flat.numel() > 0:
_logprob_alignment_stats["logprobs/training_mean"] = train_flat.mean().item() _logprob_alignment_stats["logprobs/training_mean"] = (
train_flat.mean().item()
)
_logprob_alignment_stats["logprobs/training_std"] = train_flat.std().item() _logprob_alignment_stats["logprobs/training_std"] = train_flat.std().item()
if all_inference_logprobs: if all_inference_logprobs:
@ -415,8 +440,12 @@ def run_training_step(
# Token-level alignment metrics - THE key metric for verifying weight sharing # Token-level alignment metrics - THE key metric for verifying weight sharing
# diff_abs_mean close to 0 = weights are truly shared # diff_abs_mean close to 0 = weights are truly shared
_logprob_alignment_stats["alignment/diff_mean"] = total_logprob_diff_mean / num_batches _logprob_alignment_stats["alignment/diff_mean"] = (
_logprob_alignment_stats["alignment/diff_abs_mean"] = total_logprob_diff_abs_mean / num_batches total_logprob_diff_mean / num_batches
)
_logprob_alignment_stats["alignment/diff_abs_mean"] = (
total_logprob_diff_abs_mean / num_batches
)
_logprob_alignment_stats["alignment/diff_max"] = total_logprob_diff_max _logprob_alignment_stats["alignment/diff_max"] = total_logprob_diff_max
return result return result
@ -495,8 +524,13 @@ def log_metrics(
"grpo/clipped_fraction": clipped_frac, "grpo/clipped_fraction": clipped_frac,
} }
# Add timing metrics if present # Add timing metrics if present
for key in ["step_time", "sync_time", "data_fetch_time", for key in [
"gpu_memory_gb", "gpu_memory_reserved_gb"]: "step_time",
"sync_time",
"data_fetch_time",
"gpu_memory_gb",
"gpu_memory_reserved_gb",
]:
if key in metrics: if key in metrics:
log_dict[f"train/{key}"] = metrics[key] log_dict[f"train/{key}"] = metrics[key]
@ -549,7 +583,9 @@ def finalize_training(
total_step_time = sum(step_times) total_step_time = sum(step_times)
avg_sync_time = sum(sync_times) / len(sync_times) if sync_times else 0 avg_sync_time = sum(sync_times) / len(sync_times) if sync_times else 0
total_sync_time = sum(sync_times) total_sync_time = sum(sync_times)
avg_data_fetch = sum(data_fetch_times) / len(data_fetch_times) if data_fetch_times else 0 avg_data_fetch = (
sum(data_fetch_times) / len(data_fetch_times) if data_fetch_times else 0
)
total_data_fetch = sum(data_fetch_times) total_data_fetch = sum(data_fetch_times)
avg_gpu_mem = sum(gpu_memories) / len(gpu_memories) if gpu_memories else 0 avg_gpu_mem = sum(gpu_memories) / len(gpu_memories) if gpu_memories else 0
@ -557,13 +593,17 @@ def finalize_training(
print(f"\n{'='*70}") print(f"\n{'='*70}")
print(f"BENCHMARK SUMMARY ({mode})") print(f"BENCHMARK SUMMARY ({mode})")
print(f"{'='*70}") print(f"{'='*70}")
print(f" Total training time: {total_time:.2f}s ({total_time/60:.2f} min)") print(
f" Total training time: {total_time:.2f}s ({total_time/60:.2f} min)"
)
print(f" Total steps: {total_steps}") print(f" Total steps: {total_steps}")
print(" ") print(" ")
print(" TIMING BREAKDOWN:") print(" TIMING BREAKDOWN:")
print(f" Avg step time: {avg_step_time:.2f}s") print(f" Avg step time: {avg_step_time:.2f}s")
print(f" Total step time: {total_step_time:.2f}s") print(f" Total step time: {total_step_time:.2f}s")
print(f" Avg sync time: {avg_sync_time:.2f}s (x{len(sync_times)} syncs)") print(
f" Avg sync time: {avg_sync_time:.2f}s (x{len(sync_times)} syncs)"
)
print(f" Total sync time: {total_sync_time:.2f}s") print(f" Total sync time: {total_sync_time:.2f}s")
print(f" Avg data fetch time: {avg_data_fetch:.2f}s") print(f" Avg data fetch time: {avg_data_fetch:.2f}s")
print(f" Total data fetch time: {total_data_fetch:.2f}s") print(f" Total data fetch time: {total_data_fetch:.2f}s")
@ -584,4 +624,3 @@ def finalize_training(
wandb.finish() wandb.finish()
elif use_wandb: elif use_wandb:
wandb.finish() wandb.finish()

View file

@ -53,7 +53,7 @@ os.environ.setdefault("VLLM_USE_V1", "0")
# Set spawn method for multiprocessing (required for CUDA) # Set spawn method for multiprocessing (required for CUDA)
os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn") os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
try: try:
multiprocessing.set_start_method('spawn', force=True) multiprocessing.set_start_method("spawn", force=True)
except RuntimeError: except RuntimeError:
pass # Already set pass # Already set
@ -86,6 +86,7 @@ def _apply_patches_early() -> bool:
try: try:
import sys import sys
from pathlib import Path from pathlib import Path
# Add parent directory to path so we can import vllm_patching # Add parent directory to path so we can import vllm_patching
script_dir = Path(__file__).parent script_dir = Path(__file__).parent
if str(script_dir) not in sys.path: if str(script_dir) not in sys.path:
@ -106,6 +107,7 @@ def _apply_patches_early() -> bool:
except Exception as e: except Exception as e:
print(f"[vLLM Server] Error applying patches: {e}") print(f"[vLLM Server] Error applying patches: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return False return False
@ -145,17 +147,20 @@ except ImportError:
def add_argument(self, *args, **kwargs): def add_argument(self, *args, **kwargs):
# Remove 'deprecated' kwarg if present (not supported before Python 3.13) # Remove 'deprecated' kwarg if present (not supported before Python 3.13)
kwargs.pop('deprecated', None) kwargs.pop("deprecated", None)
return super().add_argument(*args, **kwargs) return super().add_argument(*args, **kwargs)
# set_ulimit might not exist in all vLLM versions # set_ulimit might not exist in all vLLM versions
try: try:
from vllm.utils import set_ulimit from vllm.utils import set_ulimit
except ImportError: except ImportError:
def set_ulimit() -> None: def set_ulimit() -> None:
"""No-op fallback for set_ulimit.""" """No-op fallback for set_ulimit."""
pass pass
from vllm.outputs import RequestOutput # noqa: F401, E402 from vllm.outputs import RequestOutput # noqa: F401, E402
from vllm.version import __version__ as VLLM_VERSION # noqa: E402 from vllm.version import __version__ as VLLM_VERSION # noqa: E402
@ -602,7 +607,9 @@ async def lora_load(request: LoraLoadRequest) -> JSONResponse:
) # vLLM needs unique int ID ) # vLLM needs unique int ID
bridge_state.lora_load_count += 1 bridge_state.lora_load_count += 1
logger.info(f"LoRA adapter loaded: {request.adapter_path} (id={bridge_state.active_lora_id})") logger.info(
f"LoRA adapter loaded: {request.adapter_path} (id={bridge_state.active_lora_id})"
)
return JSONResponse( return JSONResponse(
{ {

View file

@ -17,7 +17,6 @@ import requests
from .config import TrainingConfig from .config import TrainingConfig
# Global variable to keep track of the vLLM process # Global variable to keep track of the vLLM process
_vllm_process: Optional[subprocess.Popen] = None _vllm_process: Optional[subprocess.Popen] = None
@ -25,7 +24,7 @@ _vllm_process: Optional[subprocess.Popen] = None
def is_port_in_use(port: int) -> bool: def is_port_in_use(port: int) -> bool:
"""Check if a port is already in use.""" """Check if a port is already in use."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0 return s.connect_ex(("localhost", port)) == 0
def kill_process_on_port(port: int, timeout: float = 5.0) -> bool: def kill_process_on_port(port: int, timeout: float = 5.0) -> bool:
@ -42,13 +41,10 @@ def kill_process_on_port(port: int, timeout: float = 5.0) -> bool:
try: try:
# Try to find and kill the process using lsof (Linux/Mac) # Try to find and kill the process using lsof (Linux/Mac)
result = subprocess.run( result = subprocess.run(
["lsof", "-t", "-i", f":{port}"], ["lsof", "-t", "-i", f":{port}"], capture_output=True, text=True, timeout=5
capture_output=True,
text=True,
timeout=5
) )
if result.stdout.strip(): if result.stdout.strip():
pids = result.stdout.strip().split('\n') pids = result.stdout.strip().split("\n")
for pid in pids: for pid in pids:
try: try:
os.kill(int(pid), signal.SIGTERM) os.kill(int(pid), signal.SIGTERM)
@ -135,7 +131,9 @@ def launch_vllm_server(
if is_port_in_use(config.vllm_port): if is_port_in_use(config.vllm_port):
print(f" WARNING: Port {config.vllm_port} is already in use!") print(f" WARNING: Port {config.vllm_port} is already in use!")
if not kill_process_on_port(config.vllm_port): if not kill_process_on_port(config.vllm_port):
print(f" ERROR: Could not free port {config.vllm_port}. Please manually kill the process.") print(
f" ERROR: Could not free port {config.vllm_port}. Please manually kill the process."
)
print(f" Try: lsof -i :{config.vllm_port} | grep LISTEN") print(f" Try: lsof -i :{config.vllm_port} | grep LISTEN")
print(f" Or: pkill -f 'vllm.*{config.vllm_port}'") print(f" Or: pkill -f 'vllm.*{config.vllm_port}'")
return None return None
@ -209,7 +207,9 @@ def check_vllm_process_health() -> None:
global _vllm_process global _vllm_process
if _vllm_process is not None and _vllm_process.poll() is not None: if _vllm_process is not None and _vllm_process.poll() is not None:
print(f" WARNING: vLLM terminated unexpectedly (code: {_vllm_process.returncode})") print(
f" WARNING: vLLM terminated unexpectedly (code: {_vllm_process.returncode})"
)
_vllm_process = None _vllm_process = None
@ -299,7 +299,9 @@ def hotswap_lora_adapter(
print(f" [LORA] ✓ Hot-swapped adapter: {adapter_name} ({adapter_path})") print(f" [LORA] ✓ Hot-swapped adapter: {adapter_name} ({adapter_path})")
return True return True
else: else:
print(f" [LORA] ✗ Hot-swap failed: {response.status_code} - {response.text}") print(
f" [LORA] ✗ Hot-swap failed: {response.status_code} - {response.text}"
)
return False return False
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
@ -308,4 +310,3 @@ def hotswap_lora_adapter(
except Exception as e: except Exception as e:
print(f" [LORA] ✗ Error during hot-swap: {e}") print(f" [LORA] ✗ Error during hot-swap: {e}")
return False return False

View file

@ -37,6 +37,7 @@ def _patch_lora_triton_for_blackwell() -> bool:
""" """
try: try:
import vllm import vllm
vllm_path = vllm.__path__[0] vllm_path = vllm.__path__[0]
kernel_utils_path = f"{vllm_path}/lora/ops/triton_ops/kernel_utils.py" kernel_utils_path = f"{vllm_path}/lora/ops/triton_ops/kernel_utils.py"
@ -45,42 +46,42 @@ def _patch_lora_triton_for_blackwell() -> bool:
print("[vLLM Patch] LoRA kernel_utils.py not found, skipping GDC patch") print("[vLLM Patch] LoRA kernel_utils.py not found, skipping GDC patch")
return False return False
with open(kernel_utils_path, 'r') as f: with open(kernel_utils_path, "r") as f:
content = f.read() content = f.read()
# Check if already patched # Check if already patched
if 'PATCHED FOR B200' in content: if "PATCHED FOR B200" in content:
print("[vLLM Patch] LoRA GDC already patched for B200") print("[vLLM Patch] LoRA GDC already patched for B200")
return True return True
modified = False modified = False
# Patch USE_GDC = True -> False # Patch USE_GDC = True -> False
if 'USE_GDC = True' in content: if "USE_GDC = True" in content:
content = content.replace( content = content.replace(
'USE_GDC = True', "USE_GDC = True",
'USE_GDC = False # PATCHED FOR B200 - GDC causes Triton compilation failure' "USE_GDC = False # PATCHED FOR B200 - GDC causes Triton compilation failure",
) )
modified = True modified = True
# Patch USE_GDC: tl.constexpr = True -> False # Patch USE_GDC: tl.constexpr = True -> False
if 'USE_GDC: tl.constexpr = True' in content: if "USE_GDC: tl.constexpr = True" in content:
content = content.replace( content = content.replace(
'USE_GDC: tl.constexpr = True', "USE_GDC: tl.constexpr = True",
'USE_GDC: tl.constexpr = False # PATCHED FOR B200' "USE_GDC: tl.constexpr = False # PATCHED FOR B200",
) )
modified = True modified = True
# Patch the gdc_wait call itself # Patch the gdc_wait call itself
if 'tl.extra.cuda.gdc_wait()' in content: if "tl.extra.cuda.gdc_wait()" in content:
content = content.replace( content = content.replace(
'tl.extra.cuda.gdc_wait()', "tl.extra.cuda.gdc_wait()",
'pass # tl.extra.cuda.gdc_wait() PATCHED FOR B200 - disabled' "pass # tl.extra.cuda.gdc_wait() PATCHED FOR B200 - disabled",
) )
modified = True modified = True
if modified: if modified:
with open(kernel_utils_path, 'w') as f: with open(kernel_utils_path, "w") as f:
f.write(content) f.write(content)
print(f"[vLLM Patch] ✓ Patched LoRA Triton GDC in {kernel_utils_path}") print(f"[vLLM Patch] ✓ Patched LoRA Triton GDC in {kernel_utils_path}")