reasoning-gym/training/qwen-math/tina/config.py
Zafir Stojanovski 0cda6b1205
qwen math training code (#435)
* qwen math training code

* pre-commit
2025-05-16 13:19:19 +02:00

21 lines
994 B
Python

from dataclasses import dataclass, field
from typing import Literal
# check ./recipes/MODEL_NAME/PT_METHOD/model_DATASET.yaml
@dataclass
class ModelPTConfig:
# //*******Model post-training configs*******//
model_post_train_type: Literal["grpo", "sft"] = field(default="grpo")
model_post_train_dataset_name: str = field(default="curated_deepscaler")
model_post_train_dataset_config: str | None = field(default=None)
rl_post_train_reward_funcs: list[str] = field(default_factory=lambda: ["accuracy", "format"])
rl_post_train_reward_weights: list[str] = field(default_factory=lambda: [2.0, 1.0])
cosine_min_value_wrong: float = field(default=0.0)
cosine_max_value_wrong: float = field(default=-0.5)
cosine_min_value_correct: float = field(default=0.5)
cosine_max_value_correct: float = field(default=1.0)
cosine_max_len: int = field(default=1000)
repetition_n_grams: int = field(default=3)
repetition_max_penalty: float = field(default=-1.0)