[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2026-02-06 06:46:14 +00:00 committed by Jai Suphavadeeprasit
parent d07ab3e3ce
commit 5cfd1929f1
19 changed files with 708 additions and 452 deletions

View file

@ -14,7 +14,7 @@ from pydantic import BaseModel, Field
class TrainingConfig(BaseModel):
"""
Training configuration for GRPO trainer.
Supports three training modes:
- 'none' (legacy): Periodic checkpoint saves + vLLM restarts
- 'shared_vllm': Attach to vLLM's shared memory tensors, update in-place
@ -23,7 +23,7 @@ class TrainingConfig(BaseModel):
# === Model Configuration ===
model_name: str = Field(..., description="Name of the base model to train")
# === Training Hyperparameters ===
lr: float = Field(1e-5, description="Learning rate for the optimizer")
training_steps: int = Field(10, description="Number of training steps")
@ -35,11 +35,11 @@ class TrainingConfig(BaseModel):
optimizer: Literal["adamw", "adamw_8bit", "adamw_cpu", "adafactor"] = Field(
"adamw_8bit",
description="Optimizer to use: 'adamw' (full precision, ~32GB GPU), "
"'adamw_8bit' (8-bit states, ~8GB GPU, requires bitsandbytes), "
"'adamw_cpu' (CPU offload, ~0GB GPU, slower), "
"'adafactor' (no momentum, ~8GB GPU)"
"'adamw_8bit' (8-bit states, ~8GB GPU, requires bitsandbytes), "
"'adamw_cpu' (CPU offload, ~0GB GPU, slower), "
"'adafactor' (no momentum, ~8GB GPU)",
)
# === GRPO/PPO Hyperparameters ===
kl_coef: float = Field(
0.1,
@ -66,15 +66,13 @@ class TrainingConfig(BaseModel):
"When False, falls back to REINFORCE-style updates (not recommended)."
),
)
# === Device & Storage ===
device: str = Field(
"cuda" if torch.cuda.is_available() else "cpu",
description="Device to train on"
"cuda" if torch.cuda.is_available() else "cpu", description="Device to train on"
)
save_path: str = Field(
"trained_model_checkpoints",
description="Base path to save model checkpoints"
"trained_model_checkpoints", description="Base path to save model checkpoints"
)
checkpoint_interval: int = Field(
3,
@ -83,7 +81,7 @@ class TrainingConfig(BaseModel):
"Set to 0 to only save final checkpoint."
),
)
# === vLLM Server Configuration ===
vllm_restart_interval: int = Field(
3, description="Restart vLLM every N training steps (legacy mode)"
@ -116,14 +114,12 @@ class TrainingConfig(BaseModel):
"'none': legacy mode, restart vLLM with new checkpoint files."
),
)
# === Distributed Training Configuration ===
trainer_rank: int = Field(
0, description="Rank of this trainer in the distributed group"
)
world_size: int = Field(
1, description="Total processes in the distributed group"
)
world_size: int = Field(1, description="Total processes in the distributed group")
init_method: str = Field(
"env://",
description=(
@ -189,7 +185,7 @@ class TrainingConfig(BaseModel):
"data fetch time, and GPU memory usage per step."
),
)
# === Atropos API Configuration ===
atropos_url: str = Field(
"http://localhost:8000",
@ -198,4 +194,3 @@ class TrainingConfig(BaseModel):
"Default is http://localhost:8000. Change for concurrent tests."
),
)