[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2026-02-06 06:46:14 +00:00 committed by Jai Suphavadeeprasit
parent d07ab3e3ce
commit 5cfd1929f1
19 changed files with 708 additions and 452 deletions

View file

@ -373,7 +373,9 @@ class GSM8KEvalEnv(BaseEnv):
# Create evaluation tasks
async def eval_task(item):
return await self.rollout_and_score_eval(item, self.server.servers[0].config)
return await self.rollout_and_score_eval(
item, self.server.servers[0].config
)
tasks = [eval_task(item) for item in self.eval_items]

View file

@ -299,6 +299,7 @@ class MathEnv(BaseEnv):
if not self.config.run_evaluation:
return
import time
start_time = time.time()
eval_tasks = []
@ -320,9 +321,7 @@ class MathEnv(BaseEnv):
metrics[f"{subset}_accuracy"] = accuracy
metrics[f"{subset}_total"] = len(scores)
metrics[f"{subset}_correct"] = sum(scores)
self.eval_metrics.append(
(f"eval/{subset}_percent_correct", accuracy)
)
self.eval_metrics.append((f"eval/{subset}_percent_correct", accuracy))
# overall score
all_scores = []
@ -332,9 +331,7 @@ class MathEnv(BaseEnv):
metrics["overall_accuracy"] = overall_accuracy
metrics["overall_total"] = len(all_scores)
metrics["overall_correct"] = sum(all_scores)
self.eval_metrics.append(
("eval/overall_percent_correct", overall_accuracy)
)
self.eval_metrics.append(("eval/overall_percent_correct", overall_accuracy))
end_time = time.time()
@ -342,7 +339,9 @@ class MathEnv(BaseEnv):
print("\n" + "=" * 60)
print("Math Zero Evaluation Results")
print("=" * 60)
print(f"Overall Accuracy: {overall_accuracy:.2%} ({sum(all_scores)}/{len(all_scores)})")
print(
f"Overall Accuracy: {overall_accuracy:.2%} ({sum(all_scores)}/{len(all_scores)})"
)
print("\nPer-subset breakdown:")
for subset, scores in sorted(task_lists.items()):
acc = sum(scores) / len(scores)

View file

@ -542,4 +542,3 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands
| `vllm_api_server.py` | Streamlined vLLM server for training |
| `vllm_manager.py` | vLLM process lifecycle management |
| `checkpointing.py` | Save/load checkpoints and adapters |

View file

@ -20,9 +20,9 @@ Usage:
train_legacy(config)
"""
from .cli import config_from_args, parse_args
from .config import TrainingConfig
from .trainers import train_legacy, train_shared_vllm, train_lora
from .cli import parse_args, config_from_args
from .trainers import train_legacy, train_lora, train_shared_vllm
__all__ = [
"TrainingConfig",

View file

@ -15,7 +15,9 @@ from tenacity import retry, stop_after_attempt, wait_exponential
from .config import TrainingConfig
def check_atropos_api(url: str = "http://localhost:8000", timeout: float = 30.0) -> bool:
def check_atropos_api(
url: str = "http://localhost:8000", timeout: float = 30.0
) -> bool:
"""
Check if the Atropos API server is reachable.
@ -99,4 +101,3 @@ def get_batch(url: str = "http://localhost:8000"):
raise RuntimeError(f"Atropos API error: {data.get('message', 'Unknown error')}")
return data

View file

@ -89,11 +89,14 @@ def save_checkpoint(
# Count how many were non-contiguous (views into fused tensors)
view_count = sum(
1 for name, param in model.named_parameters()
1
for name, param in model.named_parameters()
if not param.is_contiguous() or param.storage_offset() != 0
)
if view_count > 0:
print(f" [Checkpoint] Unfused {view_count} view tensors (qkv/gate_up fusions)")
print(
f" [Checkpoint] Unfused {view_count} view tensors (qkv/gate_up fusions)"
)
# Save state dict manually, then save config separately
torch.save(state_dict, os.path.join(checkpoint_path, "pytorch_model.bin"))
@ -102,6 +105,7 @@ def save_checkpoint(
# CRITICAL: Clean up the copied state_dict to free ~8GB GPU memory!
del state_dict
import gc
gc.collect()
torch.cuda.empty_cache()
else:
@ -151,4 +155,3 @@ def save_lora_checkpoint(
print(" Adapter saved.")
return adapter_path

View file

@ -11,16 +11,17 @@ import torch
from .config import TrainingConfig
# =============================================================================
# Argument Group Builders (modular, reusable)
# =============================================================================
def add_model_args(parser: argparse.ArgumentParser) -> None:
"""Add model-related arguments."""
group = parser.add_argument_group("Model")
group.add_argument(
"--model", "--model-name",
"--model",
"--model-name",
type=str,
required=True,
dest="model_name",
@ -67,7 +68,7 @@ def add_training_args(parser: argparse.ArgumentParser) -> None:
choices=["adamw", "adamw_8bit", "adamw_cpu", "adafactor"],
default="adamw_8bit",
help="Optimizer: 'adamw' (full precision), 'adamw_8bit' (8-bit states), "
"'adamw_cpu' (CPU offload), 'adafactor' (no momentum)",
"'adamw_cpu' (CPU offload), 'adafactor' (no momentum)",
)
group.add_argument(
"--device",
@ -121,7 +122,8 @@ def add_vllm_args(parser: argparse.ArgumentParser) -> None:
help="Port for the vLLM server",
)
group.add_argument(
"--gpu-memory-utilization", "--vllm-gpu-memory-utilization",
"--gpu-memory-utilization",
"--vllm-gpu-memory-utilization",
type=float,
default=0.45,
dest="gpu_memory_utilization",
@ -203,7 +205,9 @@ def add_lora_args(parser: argparse.ArgumentParser) -> None:
"""Add LoRA-specific arguments."""
group = parser.add_argument_group("LoRA Configuration")
group.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
group.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha (scaling factor)")
group.add_argument(
"--lora-alpha", type=int, default=32, help="LoRA alpha (scaling factor)"
)
group.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")
group.add_argument(
"--lora-target-modules",
@ -219,8 +223,12 @@ def add_distributed_args(parser: argparse.ArgumentParser) -> None:
group = parser.add_argument_group("Distributed Training")
group.add_argument("--trainer-rank", type=int, default=0, help="Trainer rank")
group.add_argument("--world-size", type=int, default=1, help="World size")
group.add_argument("--init-method", type=str, default="env://", help="Distributed init method")
group.add_argument("--num-inference-nodes", type=int, default=0, help="Number of inference nodes")
group.add_argument(
"--init-method", type=str, default="env://", help="Distributed init method"
)
group.add_argument(
"--num-inference-nodes", type=int, default=0, help="Number of inference nodes"
)
def add_debug_args(parser: argparse.ArgumentParser) -> None:
@ -248,6 +256,7 @@ def add_debug_args(parser: argparse.ArgumentParser) -> None:
# Parser Builders
# =============================================================================
def create_base_parser(description: str) -> argparse.ArgumentParser:
"""Create a base parser with common formatting."""
return argparse.ArgumentParser(
@ -299,6 +308,7 @@ def create_unified_parser() -> argparse.ArgumentParser:
# Legacy API (backwards compatibility)
# =============================================================================
def parse_args() -> argparse.Namespace:
"""
Parse command-line arguments for the GRPO trainer (grpo.py).

View file

@ -35,9 +35,9 @@ class TrainingConfig(BaseModel):
optimizer: Literal["adamw", "adamw_8bit", "adamw_cpu", "adafactor"] = Field(
"adamw_8bit",
description="Optimizer to use: 'adamw' (full precision, ~32GB GPU), "
"'adamw_8bit' (8-bit states, ~8GB GPU, requires bitsandbytes), "
"'adamw_cpu' (CPU offload, ~0GB GPU, slower), "
"'adafactor' (no momentum, ~8GB GPU)"
"'adamw_8bit' (8-bit states, ~8GB GPU, requires bitsandbytes), "
"'adamw_cpu' (CPU offload, ~0GB GPU, slower), "
"'adafactor' (no momentum, ~8GB GPU)",
)
# === GRPO/PPO Hyperparameters ===
@ -69,12 +69,10 @@ class TrainingConfig(BaseModel):
# === Device & Storage ===
device: str = Field(
"cuda" if torch.cuda.is_available() else "cpu",
description="Device to train on"
"cuda" if torch.cuda.is_available() else "cpu", description="Device to train on"
)
save_path: str = Field(
"trained_model_checkpoints",
description="Base path to save model checkpoints"
"trained_model_checkpoints", description="Base path to save model checkpoints"
)
checkpoint_interval: int = Field(
3,
@ -121,9 +119,7 @@ class TrainingConfig(BaseModel):
trainer_rank: int = Field(
0, description="Rank of this trainer in the distributed group"
)
world_size: int = Field(
1, description="Total processes in the distributed group"
)
world_size: int = Field(1, description="Total processes in the distributed group")
init_method: str = Field(
"env://",
description=(
@ -198,4 +194,3 @@ class TrainingConfig(BaseModel):
"Default is http://localhost:8000. Change for concurrent tests."
),
)

View file

@ -92,28 +92,30 @@ def pad_data_to_good_offset(
# Process each sample in the item
for i in range(len(item["tokens"])):
seq_len = len(item["tokens"][i])
lengths.append(
math.ceil((seq_len - 1) / good_multiple) * good_multiple
)
lengths.append(math.ceil((seq_len - 1) / good_multiple) * good_multiple)
# Create labels with padding (-100 for masked positions)
label_item = np.concatenate([
np.array(item["masks"][i]),
np.full(
max(0, token_setup_len - seq_len),
-100,
dtype=np.int32,
),
])
label_item = np.concatenate(
[
np.array(item["masks"][i]),
np.full(
max(0, token_setup_len - seq_len),
-100,
dtype=np.int32,
),
]
)
# Pad tokens
item["tokens"][i] = np.concatenate([
np.array(item["tokens"][i]),
np.zeros(
max(0, token_setup_len - seq_len),
dtype=np.int32,
),
])
item["tokens"][i] = np.concatenate(
[
np.array(item["tokens"][i]),
np.zeros(
max(0, token_setup_len - seq_len),
dtype=np.int32,
),
]
)
input_ids.append(item["tokens"][i][:-1]) # Remove last for causal
labels.append(label_item[1:]) # Shift by 1 for causal
@ -126,7 +128,9 @@ def pad_data_to_good_offset(
# We just need to pad to match the sequence length
if extract_inference_logprobs and "inference_logprobs" in item:
if i < len(item["inference_logprobs"]):
raw_logprobs = np.array(item["inference_logprobs"][i], dtype=np.float32)
raw_logprobs = np.array(
item["inference_logprobs"][i], dtype=np.float32
)
has_any_logprobs = True
# Create padded logprobs array matching token_setup_len
@ -141,10 +145,14 @@ def pad_data_to_good_offset(
inference_logprobs_padded.append(padded_logprobs[1:])
else:
# No logprobs for this sample, use 1.0
inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32))
inference_logprobs_padded.append(
np.full(token_setup_len - 1, 1.0, dtype=np.float32)
)
elif extract_inference_logprobs:
# No inference_logprobs in item, use 1.0
inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32))
inference_logprobs_padded.append(
np.full(token_setup_len - 1, 1.0, dtype=np.float32)
)
# Extract temperature (priority: override > generation_params > group_overrides > 1.0)
t = 1.0
@ -155,9 +163,13 @@ def pad_data_to_good_offset(
and ("temperature" in item["overrides"][i])
):
t = float(item["overrides"][i]["temperature"])
elif item.get("generation_params") and ("temperature" in item["generation_params"]):
elif item.get("generation_params") and (
"temperature" in item["generation_params"]
):
t = float(item["generation_params"]["temperature"])
elif item.get("group_overrides") and ("temperature" in item["group_overrides"]):
elif item.get("group_overrides") and (
"temperature" in item["group_overrides"]
):
t = float(item["group_overrides"]["temperature"])
temperatures.append(t)
@ -172,19 +184,15 @@ def pad_data_to_good_offset(
start = i * batch_size
end = (i + 1) * batch_size
token_batches.append(
torch.tensor(np.stack(input_ids[start:end], axis=0))
)
label_batches.append(
torch.tensor(np.stack(labels[start:end], axis=0))
)
token_batches.append(torch.tensor(np.stack(input_ids[start:end], axis=0)))
label_batches.append(torch.tensor(np.stack(labels[start:end], axis=0)))
advantage_batches.append(
torch.tensor(np.stack(advantages[start:end], axis=0)).view(-1, 1)
)
temperature_batches.append(
torch.tensor(
np.array(temperatures[start:end], dtype=np.float32)
).view(-1, 1, 1)
torch.tensor(np.array(temperatures[start:end], dtype=np.float32)).view(
-1, 1, 1
)
)
# Batch inference logprobs (same shape as labels)
@ -194,9 +202,19 @@ def pad_data_to_good_offset(
)
# Return inference logprob batches if we have any real logprobs
final_logprob_batches = inference_logprob_batches if (has_any_logprobs and inference_logprob_batches) else None
final_logprob_batches = (
inference_logprob_batches
if (has_any_logprobs and inference_logprob_batches)
else None
)
return token_batches, label_batches, advantage_batches, temperature_batches, final_logprob_batches
return (
token_batches,
label_batches,
advantage_batches,
temperature_batches,
final_logprob_batches,
)
def get_data(
@ -205,13 +223,15 @@ def get_data(
atropos_url: str = "http://localhost:8000",
extract_inference_logprobs: bool = True,
) -> Tuple[
List[Tuple[
List[torch.Tensor], # token_batches
List[torch.Tensor], # label_batches
List[torch.Tensor], # advantage_batches
List[torch.Tensor], # temperature_batches
Optional[List[torch.Tensor]], # inference_logprob_batches
]],
List[
Tuple[
List[torch.Tensor], # token_batches
List[torch.Tensor], # label_batches
List[torch.Tensor], # advantage_batches
List[torch.Tensor], # temperature_batches
Optional[List[torch.Tensor]], # inference_logprob_batches
]
],
None, # Legacy return (no longer used)
]:
"""
@ -241,18 +261,37 @@ def get_data(
if data["batch"] is not None:
# DEBUG: Check if inference_logprobs exists in the data
if not _logged_logprob_warning:
has_logprobs = any("inference_logprobs" in item for item in data["batch"])
has_logprobs = any(
"inference_logprobs" in item for item in data["batch"]
)
if has_logprobs:
# Check if they're non-empty
sample_item = next((item for item in data["batch"] if "inference_logprobs" in item), None)
sample_item = next(
(
item
for item in data["batch"]
if "inference_logprobs" in item
),
None,
)
if sample_item and sample_item.get("inference_logprobs"):
sample_lp = sample_item["inference_logprobs"][0] if sample_item["inference_logprobs"] else []
print(f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})")
sample_lp = (
sample_item["inference_logprobs"][0]
if sample_item["inference_logprobs"]
else []
)
print(
f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})"
)
else:
print(" [Data] ⚠ inference_logprobs key exists but is empty!")
print(
" [Data] ⚠ inference_logprobs key exists but is empty!"
)
else:
print(" [Data] ⚠ NO inference_logprobs in batch data!")
print(f" [Data] Keys in first item: {list(data['batch'][0].keys())}")
print(
f" [Data] Keys in first item: {list(data['batch'][0].keys())}"
)
_logged_logprob_warning = True
# Save batch for debugging
@ -260,11 +299,24 @@ def get_data(
json.dump(data, f)
# Process and accumulate batches (now includes batched inference logprobs)
token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches = \
pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
(
token_batches,
label_batches,
adv_batches,
temp_batches,
inf_logprob_batches,
) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
# Include inference logprob batches in the tuple
batches.append((token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches))
batches.append(
(
token_batches,
label_batches,
adv_batches,
temp_batches,
inf_logprob_batches,
)
)
elif len(batches) > 0:
# Return accumulated batches when no more data
@ -272,4 +324,3 @@ def get_data(
else:
# Wait for data
time.sleep(1)

View file

@ -19,8 +19,8 @@ Usage:
--weight-bridge-mode lora_only --lora-r 16 --lora-alpha 32
"""
from .cli import parse_args, config_from_args
from .trainers import train_legacy, train_shared_vllm, train_lora
from .cli import config_from_args, parse_args
from .trainers import train_legacy, train_lora, train_shared_vllm
def main():
@ -28,9 +28,9 @@ def main():
args = parse_args()
config = config_from_args(args)
print("\n" + "="*60)
print("\n" + "=" * 60)
print("GRPO TRAINER")
print("="*60)
print("=" * 60)
print(f"Model: {config.model_name}")
print(f"Mode: {config.weight_bridge_mode}")
print(f"Training steps: {config.training_steps}")

View file

@ -20,6 +20,7 @@ from .config import TrainingConfig
# Import PEFT for LoRA training
try:
from peft import LoraConfig, TaskType, get_peft_model
PEFT_AVAILABLE = True
except ImportError:
PEFT_AVAILABLE = False
@ -37,6 +38,7 @@ def _get_attention_implementation() -> str:
"""
try:
import flash_attn # noqa: F401
return "flash_attention_2"
except ImportError:
return "sdpa"
@ -61,12 +63,19 @@ def _load_model_with_attention(
"""
# Select the loader function based on mode
# from_config: creates empty shell (meta device), from_pretrained: loads weights
loader = AutoModelForCausalLM.from_config if from_config else AutoModelForCausalLM.from_pretrained
loader = (
AutoModelForCausalLM.from_config
if from_config
else AutoModelForCausalLM.from_pretrained
)
# Try attention implementations in order of preference
for attn_impl in ["flash_attention_2", "sdpa"]:
# Skip flash_attention_2 if not available
if attn_impl == "flash_attention_2" and _get_attention_implementation() != "flash_attention_2":
if (
attn_impl == "flash_attention_2"
and _get_attention_implementation() != "flash_attention_2"
):
continue
try:
@ -86,6 +95,7 @@ def _load_model_with_attention(
# Should never reach here, but just in case
raise RuntimeError("Failed to load model with any attention implementation")
def load_model_and_tokenizer(
config: TrainingConfig,
single_copy: bool = False,
@ -178,7 +188,7 @@ def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module:
Returns:
PEFT model with LoRA adapters applied
"""
if not PEFT_AVAILABLE: # Yeah no PEFT is needed no matter what bless huggingface
if not PEFT_AVAILABLE: # Yeah no PEFT is needed no matter what bless huggingface
raise RuntimeError("PEFT library not available. Install with: pip install peft")
print("[Setup] Loading base model for LoRA mode...")
@ -208,7 +218,9 @@ def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module:
return model
def _setup_gradient_checkpointing(model: torch.nn.Module, config: TrainingConfig) -> None:
def _setup_gradient_checkpointing(
model: torch.nn.Module, config: TrainingConfig
) -> None:
"""Configure gradient checkpointing for the model."""
# Disable KV cache - incompatible with gradient checkpointing
model.config.use_cache = False
@ -297,7 +309,9 @@ def _attach_to_vllm_shared_tensors(
ipc_handles, vllm_to_hf_mapping, config
)
print(f"[Setup] Attached {attached_count} tensors ({fused_count} from fused layers)")
print(
f"[Setup] Attached {attached_count} tensors ({fused_count} from fused layers)"
)
if attached_count == 0:
print("[Setup] Could not attach any tensors, falling back to regular loading")
@ -322,6 +336,7 @@ def _attach_to_vllm_shared_tensors(
def _deserialize_ipc_handles(handles_raw: dict) -> dict:
"""Deserialize base64-encoded bytes in IPC handles."""
def deserialize(handles):
result = {}
for k, v in handles.items():
@ -333,6 +348,7 @@ def _deserialize_ipc_handles(handles_raw: dict) -> dict:
else:
result[k] = v
return result
return deserialize(handles_raw)
@ -387,8 +403,14 @@ def _reconstruct_shared_tensors(
event_sync_required = ipc_info["event_sync_required"]
share_tuple = (
device_index, ipc_handle, storage_size, storage_offset_orig,
ref_counter_handle, ref_counter_offset, event_handle, event_sync_required,
device_index,
ipc_handle,
storage_size,
storage_offset_orig,
ref_counter_handle,
ref_counter_offset,
event_handle,
event_sync_required,
)
storage = torch.UntypedStorage._new_shared_cuda(*share_tuple)
@ -424,7 +446,9 @@ def _reconstruct_shared_tensors(
if slice_dim == 0:
tensor = full_tensor[slice_start:slice_end]
else:
tensor = full_tensor.narrow(slice_dim, slice_start, slice_end - slice_start)
tensor = full_tensor.narrow(
slice_dim, slice_start, slice_end - slice_start
)
tensor.requires_grad_(True)
hf_state_dict[hf_name] = tensor
@ -459,12 +483,16 @@ def _validate_mapping_coverage(
# Note: attached_count may be > param_count because state_dict includes buffers
# while named_parameters only counts trainable params
print(f"[Setup] Mapping coverage: {attached_count} tensors for {hf_param_count} parameters "
f"(>100% is OK - includes buffers)")
print(
f"[Setup] Mapping coverage: {attached_count} tensors for {hf_param_count} parameters "
f"(>100% is OK - includes buffers)"
)
if mapping_coverage < 0.90:
unmapped_params = set(model.state_dict().keys()) - set(hf_state_dict.keys())
warning_msg = f"[Setup] WARNING: Low mapping coverage ({mapping_coverage:.1%})\n"
warning_msg = (
f"[Setup] WARNING: Low mapping coverage ({mapping_coverage:.1%})\n"
)
warning_msg += f"Unmapped parameters ({len(unmapped_params)}):\n"
for name in list(unmapped_params)[:20]:
warning_msg += f" - {name}\n"
@ -484,11 +512,17 @@ def _initialize_meta_tensors(
config: TrainingConfig,
) -> None:
"""Initialize any remaining meta tensors after loading."""
meta_params = [name for name, p in model.named_parameters() if p.device.type == "meta"]
meta_buffers = [name for name, b in model.named_buffers() if b.device.type == "meta"]
meta_params = [
name for name, p in model.named_parameters() if p.device.type == "meta"
]
meta_buffers = [
name for name, b in model.named_buffers() if b.device.type == "meta"
]
if config.debug_loading:
print(f"\n[DIAGNOSTIC] Meta params: {len(meta_params)}, Meta buffers: {len(meta_buffers)}")
print(
f"\n[DIAGNOSTIC] Meta params: {len(meta_params)}, Meta buffers: {len(meta_buffers)}"
)
def get_parent_and_name(model, full_name):
parts = full_name.split(".")
@ -526,11 +560,15 @@ def _initialize_meta_tensors(
dim = buffer.shape[0] * 2
# Get rope_theta from model config (default 10000.0 for LLaMA, but Qwen3 uses 5000000!)
rope_theta = getattr(model.config, "rope_theta", 10000.0)
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
inv_freq = 1.0 / (
rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
)
new_buffer = inv_freq.to(dtype=buffer.dtype, device=device)
print(f"[Setup] Initialized {name} with rope_theta={rope_theta}")
else:
new_buffer = torch.zeros(buffer.shape, dtype=buffer.dtype, device=device)
new_buffer = torch.zeros(
buffer.shape, dtype=buffer.dtype, device=device
)
parent, attr_name = get_parent_and_name(model, name)
parent.register_buffer(attr_name, new_buffer)
@ -544,8 +582,12 @@ def _initialize_meta_tensors(
def _validate_no_meta_tensors(model: torch.nn.Module) -> None:
"""Ensure no parameters or buffers are still on meta device."""
final_meta_params = [name for name, p in model.named_parameters() if p.device.type == "meta"]
final_meta_buffers = [name for name, b in model.named_buffers() if b.device.type == "meta"]
final_meta_params = [
name for name, p in model.named_parameters() if p.device.type == "meta"
]
final_meta_buffers = [
name for name, b in model.named_buffers() if b.device.type == "meta"
]
if final_meta_params or final_meta_buffers:
error_msg = "[Setup] CRITICAL ERROR: Some tensors are still on meta device!\n"
@ -588,7 +630,9 @@ def _create_vllm_to_hf_mapping(
model_config = model.config
hidden_size = getattr(model_config, "hidden_size", 4096)
num_attention_heads = getattr(model_config, "num_attention_heads", 32)
num_key_value_heads = getattr(model_config, "num_key_value_heads", num_attention_heads)
num_key_value_heads = getattr(
model_config, "num_key_value_heads", num_attention_heads
)
intermediate_size = getattr(model_config, "intermediate_size", hidden_size * 4)
# Try to get head_dim from config (some models like Qwen3 have this)
@ -639,8 +683,10 @@ def _create_vllm_to_hf_mapping(
up_size = intermediate_size
# Always print sizes for debugging weight sharing issues
print(f"[Mapping] Model config: hidden={hidden_size}, heads={num_attention_heads}, "
f"kv_heads={num_key_value_heads}, head_dim={head_dim}")
print(
f"[Mapping] Model config: hidden={hidden_size}, heads={num_attention_heads}, "
f"kv_heads={num_key_value_heads}, head_dim={head_dim}"
)
print(f"[Mapping] QKV sizes from HF model: q={q_size}, k={k_size}, v={v_size}")
print(f"[Mapping] Gate/Up sizes from HF model: gate={gate_size}, up={up_size}")
@ -714,7 +760,8 @@ def _create_vllm_to_hf_mapping(
if debug:
direct = sum(1 for v in mapping.values() if isinstance(v, str))
fused = sum(1 for v in mapping.values() if isinstance(v, dict))
print(f"[Mapping] Total: {len(mapping)} mapped ({direct} direct, {fused} fused)")
print(
f"[Mapping] Total: {len(mapping)} mapped ({direct} direct, {fused} fused)"
)
return mapping

View file

@ -60,10 +60,13 @@ def wait_for_bridge_config(config_path: str, timeout: int = 60) -> bool:
if os.path.exists(config_path):
try:
import json
with open(config_path, 'r') as f:
with open(config_path, "r") as f:
config = json.load(f)
if config.get('ipc_handles') and len(config['ipc_handles']) > 0:
print(f"[Run] ✓ Bridge config ready with {len(config['ipc_handles'])} IPC handles")
if config.get("ipc_handles") and len(config["ipc_handles"]) > 0:
print(
f"[Run] ✓ Bridge config ready with {len(config['ipc_handles'])} IPC handles"
)
return True
except Exception:
pass
@ -79,7 +82,7 @@ def main():
args = parser.parse_args()
# Create log directory
log_dir = getattr(args, 'log_dir', './logs')
log_dir = getattr(args, "log_dir", "./logs")
os.makedirs(log_dir, exist_ok=True)
# Bridge config path
@ -91,16 +94,16 @@ def main():
print("[Run] Removed old bridge config")
# === Print Configuration ===
print("\n" + "="*60)
print("\n" + "=" * 60)
print("STARTING UNIFIED GRPO TRAINER (shared_vllm mode)")
print("="*60)
print("=" * 60)
print(f"Model: {args.model_name}")
print(f"vLLM port: {args.vllm_port}")
print(f"GPU memory utilization: {args.gpu_memory_utilization}")
print(f"Training steps: {args.training_steps}")
print(f"Optimizer: {args.optimizer}")
print(f"GRPO: kl_coef={args.kl_coef}, clip_eps={args.clip_eps}")
print("="*60 + "\n")
print("=" * 60 + "\n")
# Get the path to vllm_api_server.py
script_dir = Path(__file__).parent
@ -126,12 +129,19 @@ def main():
# Build vLLM command
vllm_cmd = [
sys.executable, "-u", str(vllm_server_script),
"--model", args.model_name,
"--port", str(args.vllm_port),
"--dtype", args.dtype,
"--gpu-memory-utilization", str(args.gpu_memory_utilization),
"--max-model-len", str(args.max_model_len),
sys.executable,
"-u",
str(vllm_server_script),
"--model",
args.model_name,
"--port",
str(args.vllm_port),
"--dtype",
args.dtype,
"--gpu-memory-utilization",
str(args.gpu_memory_utilization),
"--max-model-len",
str(args.max_model_len),
"--enforce-eager", # Required for shared weights
]
@ -212,6 +222,7 @@ def main():
except Exception as e:
print(f"\n[Run] ✗ Training failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)

View file

@ -138,4 +138,3 @@ if [ -d "$LOG_DIR/checkpoints" ]; then
}' | jq '.text[0]' | tee "${LOG_DIR}/trained_response.txt"
fi
fi

View file

@ -143,4 +143,3 @@ curl -s -X POST "http://localhost:${VLLM_PORT}/generate" \
"max_tokens": 100,
"temperature": 0.1
}' | jq '.text[0]' | tee "${LOG_DIR}/trained_response.txt"

View file

@ -25,7 +25,10 @@ class CPUOffloadAdamW(torch.optim.Optimizer):
Full precision (no quantization), but states stay on CPU RAM instead of GPU.
Trade-off: Slower (~2x) but uses ~0GB GPU memory for optimizer states.
"""
def __init__(self, params, lr=1e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
def __init__(
self, params, lr=1e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults)
@ -33,10 +36,10 @@ class CPUOffloadAdamW(torch.optim.Optimizer):
"""Lazily initialize state on CPU."""
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state["step"] = 0
# Store on CPU in FP32
state['exp_avg'] = torch.zeros_like(p, device='cpu', dtype=torch.float32)
state['exp_avg_sq'] = torch.zeros_like(p, device='cpu', dtype=torch.float32)
state["exp_avg"] = torch.zeros_like(p, device="cpu", dtype=torch.float32)
state["exp_avg_sq"] = torch.zeros_like(p, device="cpu", dtype=torch.float32)
return state
@torch.no_grad()
@ -47,41 +50,41 @@ class CPUOffloadAdamW(torch.optim.Optimizer):
loss = closure()
for group in self.param_groups:
beta1, beta2 = group['betas']
beta1, beta2 = group["betas"]
for p in group['params']:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
state = self._init_state(p)
state['step'] += 1
state["step"] += 1
# Move states to GPU for computation
exp_avg = state['exp_avg'].to(p.device)
exp_avg_sq = state['exp_avg_sq'].to(p.device)
exp_avg = state["exp_avg"].to(p.device)
exp_avg_sq = state["exp_avg_sq"].to(p.device)
# AdamW update
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
# Bias correction
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
step_size = group['lr'] / bias_correction1
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
step_size = group["lr"] / bias_correction1
# Update weights
denom = (exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)).add_(group['eps'])
denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5)).add_(group["eps"])
p.addcdiv_(exp_avg, denom, value=-step_size)
# Weight decay
if group['weight_decay'] != 0:
p.add_(p, alpha=-group['lr'] * group['weight_decay'])
if group["weight_decay"] != 0:
p.add_(p, alpha=-group["lr"] * group["weight_decay"])
# Move states back to CPU (non-blocking for better perf)
state['exp_avg'].copy_(exp_avg.cpu())
state['exp_avg_sq'].copy_(exp_avg_sq.cpu())
state["exp_avg"].copy_(exp_avg.cpu())
state["exp_avg_sq"].copy_(exp_avg_sq.cpu())
return loss
@ -99,6 +102,7 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
if config.optimizer == "adamw_8bit":
try:
import bitsandbytes as bnb
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.lr)
print("[Setup] Using 8-bit AdamW (saves ~24GB optimizer memory)")
return optimizer
@ -108,13 +112,18 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
if config.optimizer == "adamw_cpu":
optimizer = CPUOffloadAdamW(model.parameters(), lr=config.lr)
print("[Setup] Using AdamW with CPU offload (full precision, ~0GB GPU for states)")
print("[Setup] NOTE: ~2x slower due to CPU<->GPU transfers, but no quantization")
print(
"[Setup] Using AdamW with CPU offload (full precision, ~0GB GPU for states)"
)
print(
"[Setup] NOTE: ~2x slower due to CPU<->GPU transfers, but no quantization"
)
return optimizer
if config.optimizer == "adafactor":
try:
from transformers.optimization import Adafactor
optimizer = Adafactor(
model.parameters(),
lr=config.lr,
@ -135,7 +144,7 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
from .checkpointing import save_checkpoint, save_lora_checkpoint # noqa: E402
from .config import TrainingConfig # noqa: E402
from .data import get_data # noqa: E402
from .model import load_model_and_tokenizer, PEFT_AVAILABLE # noqa: E402
from .model import PEFT_AVAILABLE, load_model_and_tokenizer # noqa: E402
from .training import ( # noqa: E402
finalize_training,
log_metrics,
@ -146,8 +155,8 @@ from .vllm_manager import ( # noqa: E402
check_vllm_health,
check_vllm_process_health,
launch_vllm_server,
terminate_vllm_process,
set_vllm_process,
terminate_vllm_process,
)
@ -171,13 +180,13 @@ def train_legacy(config: TrainingConfig):
model, tokenizer = load_model_and_tokenizer(config)
optimizer = create_optimizer(model, config)
print("\n" + "="*60)
print("\n" + "=" * 60)
print("LEGACY MODE (checkpoint + vLLM restart)")
print("="*60)
print("=" * 60)
print(f"Training for {config.training_steps} steps on {config.device}")
print(f"vLLM restart interval: every {config.vllm_restart_interval} steps")
print(f"Save path: {config.save_path}")
print("="*60 + "\n")
print("=" * 60 + "\n")
os.makedirs(config.save_path, exist_ok=True)
@ -206,24 +215,36 @@ def train_legacy(config: TrainingConfig):
# Fetch data (with inference logprobs for proper GRPO)
data_fetch_start = time.time()
if len(batches) == 0:
batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url,
extract_inference_logprobs=True)
batches, _ = get_data(
config.batch_size,
config.seq_len,
config.atropos_url,
extract_inference_logprobs=True,
)
batch_data = batches.pop(0)
token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4]
token_batches, label_batches, advantage_batches, temperature_batches = (
batch_data[:4]
)
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
data_fetch_time = time.time() - data_fetch_start
benchmark_stats["data_fetch_times"].append(data_fetch_time)
# Check if we should sync (save checkpoint + restart vLLM)
should_sync = (step + 1) % config.vllm_restart_interval == 0 or step == config.training_steps - 1
should_sync = (
step + 1
) % config.vllm_restart_interval == 0 or step == config.training_steps - 1
if should_sync:
terminate_vllm_process()
# Training step (with proper GRPO using inference logprobs)
step_start = time.time()
metrics = run_training_step(
model, optimizer,
token_batches, label_batches, advantage_batches, temperature_batches,
model,
optimizer,
token_batches,
label_batches,
advantage_batches,
temperature_batches,
config,
inference_logprob_batches=inference_logprob_batches,
)
@ -231,15 +252,21 @@ def train_legacy(config: TrainingConfig):
benchmark_stats["step_times"].append(step_time)
# GPU memory tracking
gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
gpu_mem_gb = (
torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
)
gpu_mem_reserved_gb = (
torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
)
benchmark_stats["gpu_memories"].append(gpu_mem_gb)
# Sync (checkpoint + restart)
sync_time = 0
if should_sync:
sync_start = time.time()
checkpoint_path = save_checkpoint(model, tokenizer, config.save_path, step + 1)
checkpoint_path = save_checkpoint(
model, tokenizer, config.save_path, step + 1
)
torch.cuda.empty_cache()
vllm_proc = launch_vllm_server(config, checkpoint_path)
set_vllm_process(vllm_proc)
@ -247,20 +274,31 @@ def train_legacy(config: TrainingConfig):
benchmark_stats["sync_times"].append(sync_time)
# Update metrics
metrics.update({
"step_time": step_time,
"sync_time": sync_time,
"data_fetch_time": data_fetch_time,
"gpu_memory_gb": gpu_mem_gb,
"gpu_memory_reserved_gb": gpu_mem_reserved_gb,
})
metrics.update(
{
"step_time": step_time,
"sync_time": sync_time,
"data_fetch_time": data_fetch_time,
"gpu_memory_gb": gpu_mem_gb,
"gpu_memory_reserved_gb": gpu_mem_reserved_gb,
}
)
log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)
check_vllm_process_health()
# === Cleanup ===
save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True)
finalize_training(use_wandb, training_start_time, "legacy", config.training_steps, benchmark_stats, config.benchmark)
save_checkpoint(
model, tokenizer, config.save_path, config.training_steps, is_final=True
)
finalize_training(
use_wandb,
training_start_time,
"legacy",
config.training_steps,
benchmark_stats,
config.benchmark,
)
def train_shared_vllm(config: TrainingConfig):
@ -281,13 +319,13 @@ def train_shared_vllm(config: TrainingConfig):
# === Setup ===
use_wandb = setup_wandb(config)
print("\n" + "="*60)
print("\n" + "=" * 60)
print("SINGLE-COPY MODE (CUDA IPC)")
print(">>> Trainer uses vLLM's tensors directly!")
print("="*60)
print("=" * 60)
print(f"Model: {config.model_name}")
print(f"Save path: {config.save_path}")
print("="*60 + "\n")
print("=" * 60 + "\n")
# Attach to vLLM's shared tensors
print("[1/2] Attaching to vLLM's shared tensors...")
@ -331,11 +369,15 @@ def train_shared_vllm(config: TrainingConfig):
data_fetch_start = time.time()
if len(batches) == 0:
batches, _ = get_data(
config.batch_size, config.seq_len, config.atropos_url,
config.batch_size,
config.seq_len,
config.atropos_url,
extract_inference_logprobs=True, # Enable proper GRPO with reference logprobs
)
batch_data = batches.pop(0)
token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4]
token_batches, label_batches, advantage_batches, temperature_batches = (
batch_data[:4]
)
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
data_fetch_time = time.time() - data_fetch_start
benchmark_stats["data_fetch_times"].append(data_fetch_time)
@ -343,8 +385,12 @@ def train_shared_vllm(config: TrainingConfig):
# Training step with proper GRPO (importance sampling + KL penalty)
step_start = time.time()
metrics = run_training_step(
model, optimizer,
token_batches, label_batches, advantage_batches, temperature_batches,
model,
optimizer,
token_batches,
label_batches,
advantage_batches,
temperature_batches,
config,
inference_logprob_batches=inference_logprob_batches, # Pass for GRPO ratio computation
)
@ -352,8 +398,12 @@ def train_shared_vllm(config: TrainingConfig):
benchmark_stats["step_times"].append(step_time)
# GPU memory tracking
gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
gpu_mem_gb = (
torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
)
gpu_mem_reserved_gb = (
torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
)
benchmark_stats["gpu_memories"].append(gpu_mem_gb)
# In single-copy mode, weights are updated in-place (no sync needed!)
@ -362,23 +412,37 @@ def train_shared_vllm(config: TrainingConfig):
benchmark_stats["sync_times"].append(sync_time)
# Update metrics
metrics.update({
"step_time": step_time,
"sync_time": sync_time,
"data_fetch_time": data_fetch_time,
"gpu_memory_gb": gpu_mem_gb,
"gpu_memory_reserved_gb": gpu_mem_reserved_gb,
})
metrics.update(
{
"step_time": step_time,
"sync_time": sync_time,
"data_fetch_time": data_fetch_time,
"gpu_memory_gb": gpu_mem_gb,
"gpu_memory_reserved_gb": gpu_mem_reserved_gb,
}
)
log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)
# Periodic checkpoint (for recovery, not for vLLM sync)
if config.checkpoint_interval > 0 and (step + 1) % config.checkpoint_interval == 0:
if (
config.checkpoint_interval > 0
and (step + 1) % config.checkpoint_interval == 0
):
save_checkpoint(model, tokenizer, config.save_path, step + 1)
# === Cleanup ===
save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True)
finalize_training(use_wandb, training_start_time, "shared_vllm", config.training_steps, benchmark_stats, config.benchmark)
save_checkpoint(
model, tokenizer, config.save_path, config.training_steps, is_final=True
)
finalize_training(
use_wandb,
training_start_time,
"shared_vllm",
config.training_steps,
benchmark_stats,
config.benchmark,
)
def train_lora(config: TrainingConfig):
@ -399,29 +463,33 @@ def train_lora(config: TrainingConfig):
- External vLLM server running with --enable-lora
"""
if not PEFT_AVAILABLE:
raise RuntimeError("PEFT library required for LoRA mode. Install with: pip install peft")
raise RuntimeError(
"PEFT library required for LoRA mode. Install with: pip install peft"
)
training_start_time = time.time()
# === Setup ===
use_wandb = setup_wandb(config)
print("\n" + "="*60)
print("\n" + "=" * 60)
print("LORA MODE (adapter-only training)")
print("="*60)
print("=" * 60)
print(f"Base model: {config.model_name}")
print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}")
print(f"Save path: {config.save_path}")
print(f"vLLM port: {config.vllm_port}")
print("="*60 + "\n")
print("=" * 60 + "\n")
# Check external vLLM server
print("[1/3] Checking external vLLM server...")
if not check_vllm_health(config.vllm_port):
print(f"\nERROR: vLLM server not running on port {config.vllm_port}")
print("\nLoRA mode requires an external vLLM server. Start it first:")
print(f" python example_trainer/vllm_api_server.py --model {config.model_name} "
f"--port {config.vllm_port} --enable-lora --enforce-eager")
print(
f" python example_trainer/vllm_api_server.py --model {config.model_name} "
f"--port {config.vllm_port} --enable-lora --enforce-eager"
)
raise RuntimeError(f"External vLLM server required on port {config.vllm_port}")
print(f"vLLM server healthy on port {config.vllm_port}")
@ -459,10 +527,16 @@ def train_lora(config: TrainingConfig):
# Fetch data (with inference logprobs for proper GRPO)
data_fetch_start = time.time()
if len(batches) == 0:
batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url,
extract_inference_logprobs=True)
batches, _ = get_data(
config.batch_size,
config.seq_len,
config.atropos_url,
extract_inference_logprobs=True,
)
batch_data = batches.pop(0)
token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4]
token_batches, label_batches, advantage_batches, temperature_batches = (
batch_data[:4]
)
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
data_fetch_time = time.time() - data_fetch_start
benchmark_stats["data_fetch_times"].append(data_fetch_time)
@ -470,8 +544,12 @@ def train_lora(config: TrainingConfig):
# Training step with proper GRPO
step_start = time.time()
metrics = run_training_step(
model, optimizer,
token_batches, label_batches, advantage_batches, temperature_batches,
model,
optimizer,
token_batches,
label_batches,
advantage_batches,
temperature_batches,
config,
inference_logprob_batches=inference_logprob_batches,
)
@ -479,8 +557,12 @@ def train_lora(config: TrainingConfig):
benchmark_stats["step_times"].append(step_time)
# GPU memory tracking
gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
gpu_mem_gb = (
torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
)
gpu_mem_reserved_gb = (
torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
)
benchmark_stats["gpu_memories"].append(gpu_mem_gb)
# Periodic adapter save + hot-swap
@ -494,24 +576,35 @@ def train_lora(config: TrainingConfig):
benchmark_stats["sync_times"].append(sync_time)
# Update metrics
metrics.update({
"step_time": step_time,
"sync_time": sync_time,
"data_fetch_time": data_fetch_time,
"gpu_memory_gb": gpu_mem_gb,
"gpu_memory_reserved_gb": gpu_mem_reserved_gb,
})
metrics.update(
{
"step_time": step_time,
"sync_time": sync_time,
"data_fetch_time": data_fetch_time,
"gpu_memory_gb": gpu_mem_gb,
"gpu_memory_reserved_gb": gpu_mem_reserved_gb,
}
)
log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)
# === Cleanup ===
final_sync_start = time.time()
final_adapter_path = save_lora_checkpoint(model, config.save_path, config.training_steps, is_final=True)
final_adapter_path = save_lora_checkpoint(
model, config.save_path, config.training_steps, is_final=True
)
_hotswap_lora_adapter(config.vllm_port, final_adapter_path, "final")
final_sync_time = time.time() - final_sync_start
benchmark_stats["sync_times"].append(final_sync_time)
finalize_training(use_wandb, training_start_time, "lora_only", config.training_steps, benchmark_stats, config.benchmark)
finalize_training(
use_wandb,
training_start_time,
"lora_only",
config.training_steps,
benchmark_stats,
config.benchmark,
)
# Save tokenizer
tokenizer_path = os.path.join(config.save_path, "tokenizer")
@ -563,4 +656,3 @@ def _hotswap_lora_adapter(
except Exception as e:
print(f" [LORA] ✗ Hot-swap request failed: {e}")
return False

View file

@ -18,10 +18,10 @@ import wandb
from .config import TrainingConfig
# Global storage for logprob alignment stats
_logprob_alignment_stats: Dict[str, float] = {}
def setup_wandb(config: TrainingConfig) -> bool:
"""
Initialize Weights & Biases logging if enabled.
@ -134,7 +134,9 @@ def compute_grpo_loss(
# === GRPO/PPO Loss Computation ===
if use_reference_logprobs and inference_logprobs is not None:
# Move inference logprobs to correct device/dtype
ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype)
ref_logprobs = inference_logprobs.to(
logp_per_token.device, logp_per_token.dtype
)
# NOTE: inference_logprobs uses 1.0 for masked (prompt) positions, actual negative values for generated
with torch.no_grad():
@ -156,10 +158,16 @@ def compute_grpo_loss(
# Check if ref logprobs are negative (as they should be for generated tokens)
# If ref_at_generated is close to 1.0, that means the 1.0 placeholder is being used
if ref_at_generated > 0.5:
print(f" [WARNING] ref_logprobs avg {ref_at_generated:.3f} (should be negative!)")
print(" [WARNING] This suggests inference_logprobs alignment is wrong")
print(
f" [WARNING] ref_logprobs avg {ref_at_generated:.3f} (should be negative!)"
)
print(
" [WARNING] This suggests inference_logprobs alignment is wrong"
)
elif abs(ref_at_generated - train_at_generated) > 2.0:
print(f" [DEBUG] Logprob gap: ref={ref_at_generated:.3f}, train={train_at_generated:.3f}")
print(
f" [DEBUG] Logprob gap: ref={ref_at_generated:.3f}, train={train_at_generated:.3f}"
)
# Compute importance sampling ratio: policy(a|s) / policy_old(a|s) = exp(log policy - log policy_old)
log_ratio = logp_per_token - ref_logprobs
@ -192,7 +200,9 @@ def compute_grpo_loss(
# = exp(-log_ratio) + log_ratio - 1
kl_per_token = torch.exp(-log_ratio) + log_ratio - 1.0
kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean()
total_loss = (policy_loss + kl_coef * kl_penalty) / gradient_accumulation_steps
total_loss = (
policy_loss + kl_coef * kl_penalty
) / gradient_accumulation_steps
else:
kl_penalty = torch.tensor(0.0, device=logp_per_token.device)
total_loss = policy_loss / gradient_accumulation_steps
@ -200,7 +210,9 @@ def compute_grpo_loss(
# Compute metrics for logging
with torch.no_grad():
# Fraction of tokens where ratio was clipped
clipped_fraction = ((ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)).float()
clipped_fraction = (
(ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)
).float()
clipped_fraction = (clipped_fraction * mask).sum() / mask.sum()
# Mean ratio and KL for monitoring (using Schulman's estimator)
@ -256,7 +268,11 @@ def compute_grpo_loss(
"kl_penalty": kl_penalty.item() if torch.is_tensor(kl_penalty) else kl_penalty,
"mean_ratio": mean_ratio.item() if torch.is_tensor(mean_ratio) else mean_ratio,
"mean_kl": mean_kl.item() if torch.is_tensor(mean_kl) else mean_kl,
"clipped_fraction": clipped_fraction.item() if torch.is_tensor(clipped_fraction) else clipped_fraction,
"clipped_fraction": (
clipped_fraction.item()
if torch.is_tensor(clipped_fraction)
else clipped_fraction
),
# Token-level alignment metrics (key for verifying weight sharing)
"logprob_diff_mean": logprob_diff_mean,
"logprob_diff_abs_mean": logprob_diff_abs_mean,
@ -315,23 +331,25 @@ def run_training_step(
all_inference_logprobs: List[torch.Tensor] = []
# Get GRPO hyperparameters from config
kl_coef = getattr(config, 'kl_coef', 0.1)
clip_eps = getattr(config, 'clip_eps', 0.2)
use_reference_logprobs = getattr(config, 'use_reference_logprobs', True)
kl_coef = getattr(config, "kl_coef", 0.1)
clip_eps = getattr(config, "clip_eps", 0.2)
use_reference_logprobs = getattr(config, "use_reference_logprobs", True)
# Accumulate gradients over micro-batches
num_batches = len(token_batches) if token_batches else 1
for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(zip(
token_batches, label_batches, advantage_batches, temperature_batches
)):
for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(
zip(token_batches, label_batches, advantage_batches, temperature_batches)
):
tokens = tokens.to(config.device)
labels = labels.to(config.device)
advantages = advantages.to(config.device)
# Get corresponding inference logprobs batch if available
inf_logprobs = None
if inference_logprob_batches is not None and batch_idx < len(inference_logprob_batches):
if inference_logprob_batches is not None and batch_idx < len(
inference_logprob_batches
):
inf_logprobs = inference_logprob_batches[batch_idx]
loss, metrics = compute_grpo_loss(
@ -363,12 +381,17 @@ def run_training_step(
# Accumulate token-level alignment metrics
total_logprob_diff_mean += metrics.get("logprob_diff_mean", 0.0)
total_logprob_diff_abs_mean += metrics.get("logprob_diff_abs_mean", 0.0)
total_logprob_diff_max = max(total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0))
total_logprob_diff_max = max(
total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0)
)
# Collect logprobs for alignment monitoring
if "training_logprobs" in metrics and metrics["training_logprobs"] is not None:
all_training_logprobs.append(metrics["training_logprobs"])
if "inference_logprobs" in metrics and metrics["inference_logprobs"] is not None:
if (
"inference_logprobs" in metrics
and metrics["inference_logprobs"] is not None
):
all_inference_logprobs.append(metrics["inference_logprobs"])
# Gradient clipping and optimizer step
@ -387,7 +410,7 @@ def run_training_step(
result = {
"loss": total_loss,
"grad_norm": grad_norm.item() if hasattr(grad_norm, 'item') else grad_norm,
"grad_norm": grad_norm.item() if hasattr(grad_norm, "item") else grad_norm,
"pos_logp": total_pos_logp,
"neg_logp": total_neg_logp,
"pos_count": total_pos,
@ -404,7 +427,9 @@ def run_training_step(
if all_training_logprobs:
train_flat = torch.cat(all_training_logprobs)
if train_flat.numel() > 0:
_logprob_alignment_stats["logprobs/training_mean"] = train_flat.mean().item()
_logprob_alignment_stats["logprobs/training_mean"] = (
train_flat.mean().item()
)
_logprob_alignment_stats["logprobs/training_std"] = train_flat.std().item()
if all_inference_logprobs:
@ -415,8 +440,12 @@ def run_training_step(
# Token-level alignment metrics - THE key metric for verifying weight sharing
# diff_abs_mean close to 0 = weights are truly shared
_logprob_alignment_stats["alignment/diff_mean"] = total_logprob_diff_mean / num_batches
_logprob_alignment_stats["alignment/diff_abs_mean"] = total_logprob_diff_abs_mean / num_batches
_logprob_alignment_stats["alignment/diff_mean"] = (
total_logprob_diff_mean / num_batches
)
_logprob_alignment_stats["alignment/diff_abs_mean"] = (
total_logprob_diff_abs_mean / num_batches
)
_logprob_alignment_stats["alignment/diff_max"] = total_logprob_diff_max
return result
@ -495,8 +524,13 @@ def log_metrics(
"grpo/clipped_fraction": clipped_frac,
}
# Add timing metrics if present
for key in ["step_time", "sync_time", "data_fetch_time",
"gpu_memory_gb", "gpu_memory_reserved_gb"]:
for key in [
"step_time",
"sync_time",
"data_fetch_time",
"gpu_memory_gb",
"gpu_memory_reserved_gb",
]:
if key in metrics:
log_dict[f"train/{key}"] = metrics[key]
@ -549,7 +583,9 @@ def finalize_training(
total_step_time = sum(step_times)
avg_sync_time = sum(sync_times) / len(sync_times) if sync_times else 0
total_sync_time = sum(sync_times)
avg_data_fetch = sum(data_fetch_times) / len(data_fetch_times) if data_fetch_times else 0
avg_data_fetch = (
sum(data_fetch_times) / len(data_fetch_times) if data_fetch_times else 0
)
total_data_fetch = sum(data_fetch_times)
avg_gpu_mem = sum(gpu_memories) / len(gpu_memories) if gpu_memories else 0
@ -557,13 +593,17 @@ def finalize_training(
print(f"\n{'='*70}")
print(f"BENCHMARK SUMMARY ({mode})")
print(f"{'='*70}")
print(f" Total training time: {total_time:.2f}s ({total_time/60:.2f} min)")
print(
f" Total training time: {total_time:.2f}s ({total_time/60:.2f} min)"
)
print(f" Total steps: {total_steps}")
print(" ")
print(" TIMING BREAKDOWN:")
print(f" Avg step time: {avg_step_time:.2f}s")
print(f" Total step time: {total_step_time:.2f}s")
print(f" Avg sync time: {avg_sync_time:.2f}s (x{len(sync_times)} syncs)")
print(
f" Avg sync time: {avg_sync_time:.2f}s (x{len(sync_times)} syncs)"
)
print(f" Total sync time: {total_sync_time:.2f}s")
print(f" Avg data fetch time: {avg_data_fetch:.2f}s")
print(f" Total data fetch time: {total_data_fetch:.2f}s")
@ -584,4 +624,3 @@ def finalize_training(
wandb.finish()
elif use_wandb:
wandb.finish()

View file

@ -53,7 +53,7 @@ os.environ.setdefault("VLLM_USE_V1", "0")
# Set spawn method for multiprocessing (required for CUDA)
os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
try:
multiprocessing.set_start_method('spawn', force=True)
multiprocessing.set_start_method("spawn", force=True)
except RuntimeError:
pass # Already set
@ -86,6 +86,7 @@ def _apply_patches_early() -> bool:
try:
import sys
from pathlib import Path
# Add parent directory to path so we can import vllm_patching
script_dir = Path(__file__).parent
if str(script_dir) not in sys.path:
@ -106,6 +107,7 @@ def _apply_patches_early() -> bool:
except Exception as e:
print(f"[vLLM Server] Error applying patches: {e}")
import traceback
traceback.print_exc()
return False
@ -145,17 +147,20 @@ except ImportError:
def add_argument(self, *args, **kwargs):
    """Register an argument after discarding any ``deprecated`` keyword.

    The ``deprecated`` keyword is not supported before Python 3.13, so it is
    dropped here and everything else is forwarded unchanged to the parent
    implementation, whose return value is passed back to the caller.
    """
    if "deprecated" in kwargs:
        del kwargs["deprecated"]
    return super().add_argument(*args, **kwargs)
# Not every vLLM version provides set_ulimit; define a stub when the import fails.
try:
    from vllm.utils import set_ulimit
except ImportError:

    def set_ulimit() -> None:
        """Do nothing; placeholder used when vllm.utils has no set_ulimit."""
        return None
from vllm.outputs import RequestOutput # noqa: F401, E402
from vllm.version import __version__ as VLLM_VERSION # noqa: E402
@ -602,7 +607,9 @@ async def lora_load(request: LoraLoadRequest) -> JSONResponse:
) # vLLM needs unique int ID
bridge_state.lora_load_count += 1
logger.info(f"LoRA adapter loaded: {request.adapter_path} (id={bridge_state.active_lora_id})")
logger.info(
f"LoRA adapter loaded: {request.adapter_path} (id={bridge_state.active_lora_id})"
)
return JSONResponse(
{

View file

@ -17,7 +17,6 @@ import requests
from .config import TrainingConfig
# Global variable to keep track of the vLLM process
_vllm_process: Optional[subprocess.Popen] = None
@ -25,7 +24,7 @@ _vllm_process: Optional[subprocess.Popen] = None
def is_port_in_use(port: int) -> bool:
    """Return True when a TCP connection to localhost:``port`` succeeds.

    A throwaway IPv4 stream socket attempts the connection; ``connect_ex``
    reports failure through its return code, and a code of 0 means the
    connection was accepted.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        status = probe.connect_ex(("localhost", port))
    finally:
        # Always release the probe socket, mirroring the context-manager form.
        probe.close()
    return status == 0
def kill_process_on_port(port: int, timeout: float = 5.0) -> bool:
@ -42,13 +41,10 @@ def kill_process_on_port(port: int, timeout: float = 5.0) -> bool:
try:
# Try to find and kill the process using lsof (Linux/Mac)
result = subprocess.run(
["lsof", "-t", "-i", f":{port}"],
capture_output=True,
text=True,
timeout=5
["lsof", "-t", "-i", f":{port}"], capture_output=True, text=True, timeout=5
)
if result.stdout.strip():
pids = result.stdout.strip().split('\n')
pids = result.stdout.strip().split("\n")
for pid in pids:
try:
os.kill(int(pid), signal.SIGTERM)
@ -135,7 +131,9 @@ def launch_vllm_server(
if is_port_in_use(config.vllm_port):
print(f" WARNING: Port {config.vllm_port} is already in use!")
if not kill_process_on_port(config.vllm_port):
print(f" ERROR: Could not free port {config.vllm_port}. Please manually kill the process.")
print(
f" ERROR: Could not free port {config.vllm_port}. Please manually kill the process."
)
print(f" Try: lsof -i :{config.vllm_port} | grep LISTEN")
print(f" Or: pkill -f 'vllm.*{config.vllm_port}'")
return None
@ -209,7 +207,9 @@ def check_vllm_process_health() -> None:
global _vllm_process
if _vllm_process is not None and _vllm_process.poll() is not None:
print(f" WARNING: vLLM terminated unexpectedly (code: {_vllm_process.returncode})")
print(
f" WARNING: vLLM terminated unexpectedly (code: {_vllm_process.returncode})"
)
_vllm_process = None
@ -299,7 +299,9 @@ def hotswap_lora_adapter(
print(f" [LORA] ✓ Hot-swapped adapter: {adapter_name} ({adapter_path})")
return True
else:
print(f" [LORA] ✗ Hot-swap failed: {response.status_code} - {response.text}")
print(
f" [LORA] ✗ Hot-swap failed: {response.status_code} - {response.text}"
)
return False
except requests.exceptions.ConnectionError:
@ -308,4 +310,3 @@ def hotswap_lora_adapter(
except Exception as e:
print(f" [LORA] ✗ Error during hot-swap: {e}")
return False

View file

@ -37,6 +37,7 @@ def _patch_lora_triton_for_blackwell() -> bool:
"""
try:
import vllm
vllm_path = vllm.__path__[0]
kernel_utils_path = f"{vllm_path}/lora/ops/triton_ops/kernel_utils.py"
@ -45,42 +46,42 @@ def _patch_lora_triton_for_blackwell() -> bool:
print("[vLLM Patch] LoRA kernel_utils.py not found, skipping GDC patch")
return False
with open(kernel_utils_path, 'r') as f:
with open(kernel_utils_path, "r") as f:
content = f.read()
# Check if already patched
if 'PATCHED FOR B200' in content:
if "PATCHED FOR B200" in content:
print("[vLLM Patch] LoRA GDC already patched for B200")
return True
modified = False
# Patch USE_GDC = True -> False
if 'USE_GDC = True' in content:
if "USE_GDC = True" in content:
content = content.replace(
'USE_GDC = True',
'USE_GDC = False # PATCHED FOR B200 - GDC causes Triton compilation failure'
"USE_GDC = True",
"USE_GDC = False # PATCHED FOR B200 - GDC causes Triton compilation failure",
)
modified = True
# Patch USE_GDC: tl.constexpr = True -> False
if 'USE_GDC: tl.constexpr = True' in content:
if "USE_GDC: tl.constexpr = True" in content:
content = content.replace(
'USE_GDC: tl.constexpr = True',
'USE_GDC: tl.constexpr = False # PATCHED FOR B200'
"USE_GDC: tl.constexpr = True",
"USE_GDC: tl.constexpr = False # PATCHED FOR B200",
)
modified = True
# Patch the gdc_wait call itself
if 'tl.extra.cuda.gdc_wait()' in content:
if "tl.extra.cuda.gdc_wait()" in content:
content = content.replace(
'tl.extra.cuda.gdc_wait()',
'pass # tl.extra.cuda.gdc_wait() PATCHED FOR B200 - disabled'
"tl.extra.cuda.gdc_wait()",
"pass # tl.extra.cuda.gdc_wait() PATCHED FOR B200 - disabled",
)
modified = True
if modified:
with open(kernel_utils_path, 'w') as f:
with open(kernel_utils_path, "w") as f:
f.write(content)
print(f"[vLLM Patch] ✓ Patched LoRA Triton GDC in {kernel_utils_path}")