mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-27 17:23:08 +00:00
model layer stuff
This commit is contained in:
parent
16ac332880
commit
fa22bf58d1
4 changed files with 131 additions and 1 deletions
|
|
@ -6,6 +6,7 @@ This is the SINGLE SOURCE OF TRUTH for all CLI arguments.
|
|||
"""
|
||||
|
||||
import argparse
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
|
@ -16,6 +17,53 @@ from .config import TrainingConfig
|
|||
# =============================================================================
|
||||
|
||||
|
||||
def _parse_lora_layer_indices(value: str) -> Optional[List[int]]:
|
||||
"""
|
||||
Parse LoRA layer indices from comma/range syntax.
|
||||
|
||||
Supported formats:
|
||||
- "20-31"
|
||||
- "0,1,2,28,29,30,31"
|
||||
- "0-3,28-31"
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
raw = value.strip()
|
||||
if not raw:
|
||||
return None
|
||||
|
||||
indices: List[int] = []
|
||||
parts = [part.strip() for part in raw.split(",") if part.strip()]
|
||||
|
||||
try:
|
||||
for part in parts:
|
||||
if "-" in part:
|
||||
start_s, end_s = part.split("-", 1)
|
||||
start = int(start_s.strip())
|
||||
end = int(end_s.strip())
|
||||
if start > end:
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"Invalid range '{part}': start must be <= end"
|
||||
)
|
||||
indices.extend(range(start, end + 1))
|
||||
else:
|
||||
indices.append(int(part))
|
||||
except ValueError as e:
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"Invalid --lora-layer-indices value '{value}': {e}"
|
||||
) from e
|
||||
|
||||
if not indices:
|
||||
return None
|
||||
if any(idx < 0 for idx in indices):
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"Invalid --lora-layer-indices '{value}': indices must be >= 0"
|
||||
)
|
||||
|
||||
return sorted(set(indices))
|
||||
|
||||
|
||||
def add_model_args(parser: argparse.ArgumentParser) -> None:
|
||||
"""Add model-related arguments."""
|
||||
group = parser.add_argument_group("Model")
|
||||
|
|
@ -225,6 +273,15 @@ def add_lora_args(parser: argparse.ArgumentParser) -> None:
|
|||
default=None,
|
||||
help="Module names to apply LoRA to (default: q_proj v_proj)",
|
||||
)
|
||||
group.add_argument(
|
||||
"--lora-layer-indices",
|
||||
type=_parse_lora_layer_indices,
|
||||
default=None,
|
||||
help=(
|
||||
"Optional layer indices to apply LoRA to, e.g. '20-31' or "
|
||||
"'0-3,28-31'. If omitted, applies to all matching layers."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def add_distributed_args(parser: argparse.ArgumentParser) -> None:
|
||||
|
|
@ -373,6 +430,7 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
|
|||
lora_alpha=getattr(args, "lora_alpha", 32),
|
||||
lora_dropout=getattr(args, "lora_dropout", 0.05),
|
||||
lora_target_modules=getattr(args, "lora_target_modules", None),
|
||||
lora_layer_indices=getattr(args, "lora_layer_indices", None),
|
||||
vllm_config_path=getattr(args, "vllm_config_path", None),
|
||||
debug_loading=getattr(args, "debug_loading", False),
|
||||
benchmark=getattr(args, "benchmark", False),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue