mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
d07ab3e3ce
commit
5cfd1929f1
19 changed files with 708 additions and 452 deletions
|
|
@ -14,7 +14,7 @@ from pydantic import BaseModel, Field
|
|||
class TrainingConfig(BaseModel):
|
||||
"""
|
||||
Training configuration for GRPO trainer.
|
||||
|
||||
|
||||
Supports three training modes:
|
||||
- 'none' (legacy): Periodic checkpoint saves + vLLM restarts
|
||||
- 'shared_vllm': Attach to vLLM's shared memory tensors, update in-place
|
||||
|
|
@ -23,7 +23,7 @@ class TrainingConfig(BaseModel):
|
|||
|
||||
# === Model Configuration ===
|
||||
model_name: str = Field(..., description="Name of the base model to train")
|
||||
|
||||
|
||||
# === Training Hyperparameters ===
|
||||
lr: float = Field(1e-5, description="Learning rate for the optimizer")
|
||||
training_steps: int = Field(10, description="Number of training steps")
|
||||
|
|
@ -35,11 +35,11 @@ class TrainingConfig(BaseModel):
|
|||
optimizer: Literal["adamw", "adamw_8bit", "adamw_cpu", "adafactor"] = Field(
|
||||
"adamw_8bit",
|
||||
description="Optimizer to use: 'adamw' (full precision, ~32GB GPU), "
|
||||
"'adamw_8bit' (8-bit states, ~8GB GPU, requires bitsandbytes), "
|
||||
"'adamw_cpu' (CPU offload, ~0GB GPU, slower), "
|
||||
"'adafactor' (no momentum, ~8GB GPU)"
|
||||
"'adamw_8bit' (8-bit states, ~8GB GPU, requires bitsandbytes), "
|
||||
"'adamw_cpu' (CPU offload, ~0GB GPU, slower), "
|
||||
"'adafactor' (no momentum, ~8GB GPU)",
|
||||
)
|
||||
|
||||
|
||||
# === GRPO/PPO Hyperparameters ===
|
||||
kl_coef: float = Field(
|
||||
0.1,
|
||||
|
|
@ -66,15 +66,13 @@ class TrainingConfig(BaseModel):
|
|||
"When False, falls back to REINFORCE-style updates (not recommended)."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# === Device & Storage ===
|
||||
device: str = Field(
|
||||
"cuda" if torch.cuda.is_available() else "cpu",
|
||||
description="Device to train on"
|
||||
"cuda" if torch.cuda.is_available() else "cpu", description="Device to train on"
|
||||
)
|
||||
save_path: str = Field(
|
||||
"trained_model_checkpoints",
|
||||
description="Base path to save model checkpoints"
|
||||
"trained_model_checkpoints", description="Base path to save model checkpoints"
|
||||
)
|
||||
checkpoint_interval: int = Field(
|
||||
3,
|
||||
|
|
@ -83,7 +81,7 @@ class TrainingConfig(BaseModel):
|
|||
"Set to 0 to only save final checkpoint."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# === vLLM Server Configuration ===
|
||||
vllm_restart_interval: int = Field(
|
||||
3, description="Restart vLLM every N training steps (legacy mode)"
|
||||
|
|
@ -116,14 +114,12 @@ class TrainingConfig(BaseModel):
|
|||
"'none': legacy mode, restart vLLM with new checkpoint files."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# === Distributed Training Configuration ===
|
||||
trainer_rank: int = Field(
|
||||
0, description="Rank of this trainer in the distributed group"
|
||||
)
|
||||
world_size: int = Field(
|
||||
1, description="Total processes in the distributed group"
|
||||
)
|
||||
world_size: int = Field(1, description="Total processes in the distributed group")
|
||||
init_method: str = Field(
|
||||
"env://",
|
||||
description=(
|
||||
|
|
@ -189,7 +185,7 @@ class TrainingConfig(BaseModel):
|
|||
"data fetch time, and GPU memory usage per step."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# === Atropos API Configuration ===
|
||||
atropos_url: str = Field(
|
||||
"http://localhost:8000",
|
||||
|
|
@ -198,4 +194,3 @@ class TrainingConfig(BaseModel):
|
|||
"Default is http://localhost:8000. Change for concurrent tests."
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue