model layer stuff

This commit is contained in:
Jai Suphavadeeprasit 2026-02-18 10:52:20 -05:00
parent 16ac332880
commit fa22bf58d1
4 changed files with 131 additions and 1 deletions

View file

@ -6,6 +6,7 @@ This is the SINGLE SOURCE OF TRUTH for all CLI arguments.
"""
import argparse
from typing import List, Optional
import torch
@ -16,6 +17,53 @@ from .config import TrainingConfig
# =============================================================================
def _parse_lora_layer_indices(value: str) -> Optional[List[int]]:
    """
    Parse LoRA layer indices from comma/range syntax.

    Supported formats:
    - "20-31"
    - "0,1,2,28,29,30,31"
    - "0-3,28-31"

    Ranges are inclusive on both ends. Whitespace around items is ignored
    and empty items (such as the one produced by a trailing comma) are
    skipped.

    Args:
        value: Raw option string, or None.

    Returns:
        A sorted list of unique indices, or None when ``value`` is None,
        blank, or contains no items.

    Raises:
        argparse.ArgumentTypeError: If an item is neither an integer nor a
            ``start-end`` range, if a range has start > end, or if any
            index is negative.
    """
    if value is None:
        return None

    raw = value.strip()
    if not raw:
        return None

    indices: List[int] = []
    parts = [part.strip() for part in raw.split(",") if part.strip()]

    try:
        for part in parts:
            # A '-' in the first position is a sign, not a range separator.
            # Only look for the separator after it, so an input like "-1"
            # is parsed as a (negative) integer and rejected below with the
            # dedicated ">= 0" message instead of failing on int("").
            if "-" in part[1:]:
                start_s, end_s = part.split("-", 1)
                start = int(start_s.strip())
                end = int(end_s.strip())
                if start > end:
                    raise argparse.ArgumentTypeError(
                        f"Invalid range '{part}': start must be <= end"
                    )
                indices.extend(range(start, end + 1))
            else:
                indices.append(int(part))
    except ValueError as e:
        # Surface int() failures as argparse errors that name the option.
        raise argparse.ArgumentTypeError(
            f"Invalid --lora-layer-indices value '{value}': {e}"
        ) from e

    # Reached when the input held only separators, e.g. ",".
    if not indices:
        return None

    if any(idx < 0 for idx in indices):
        raise argparse.ArgumentTypeError(
            f"Invalid --lora-layer-indices '{value}': indices must be >= 0"
        )

    # De-duplicate overlapping items and return them in ascending order.
    return sorted(set(indices))
def add_model_args(parser: argparse.ArgumentParser) -> None:
"""Add model-related arguments."""
group = parser.add_argument_group("Model")
@ -225,6 +273,15 @@ def add_lora_args(parser: argparse.ArgumentParser) -> None:
default=None,
help="Module names to apply LoRA to (default: q_proj v_proj)",
)
group.add_argument(
"--lora-layer-indices",
type=_parse_lora_layer_indices,
default=None,
help=(
"Optional layer indices to apply LoRA to, e.g. '20-31' or "
"'0-3,28-31'. If omitted, applies to all matching layers."
),
)
def add_distributed_args(parser: argparse.ArgumentParser) -> None:
@ -373,6 +430,7 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
lora_alpha=getattr(args, "lora_alpha", 32),
lora_dropout=getattr(args, "lora_dropout", 0.05),
lora_target_modules=getattr(args, "lora_target_modules", None),
lora_layer_indices=getattr(args, "lora_layer_indices", None),
vllm_config_path=getattr(args, "vllm_config_path", None),
debug_loading=getattr(args, "debug_loading", False),
benchmark=getattr(args, "benchmark", False),