tutorial(training): Add a minimal example with trl (#473)

* v0

* 2 gpu setup

* improve parsing from yaml

* update yaml dataset example

* remove restriction on flash attn

* more comments

* first version of the readme

* pin torch

* simplify requirements

* just flash attn

* use set env instead

* simpler set env

* readme

* add wandb project to setup

* update template

* update model id

* post init to capture the config and weight

* extract metadata

* update config

* update dataset config

* move env for wandb project

* pre-commit

* remove qwen-math from training

* more instructions

* unused import

* remove trl old

* warmup ratio

* warmup ratio

* change model id

* change model_id

* add info about CUDA_VISIBLE_DEVICES
This commit is contained in:
Zafir Stojanovski 2025-06-21 00:01:31 +02:00 committed by GitHub
parent 49f3821098
commit 56ce2e79a7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
59 changed files with 382 additions and 155340 deletions

View file

@@ -1,37 +1,52 @@
#Model arguments
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
# Reasoning Gym configs
dataset_size: 20000
developer_prompt: DeepSeekZero
developer_role: system
datasets:
simple_equations:
weight: 1
complex_arithmetic:
weight: 1
config:
min_real: -20
max_real: 20
#script arguments
dataset_name: chain_sum
# Model configs from trl
model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct
attn_implementation: flash_attention_2
#training arguments
# GRPO trainer configs from trl
bf16: true
gradient_accumulation_steps: 16
use_vllm: true
vllm_device: cuda:1
vllm_gpu_memory_utilization: 0.9
log_level: info
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id:
seed: 42
eval_seed: 101
log_level: info
logging_steps: 10
use_reentrant: false
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
learning_rate: 2.0e-05
learning_rate: 1e-06
lr_scheduler_type: constant_with_warmup
lr_scheduler_kwargs:
num_warmup_steps: 10
max_prompt_length: 512
max_completion_length: 1024
max_completion_length: 2048
max_steps: 100
num_generations: 8
per_device_train_batch_size: 1
per_device_eval_batch_size: 1
overwrite_output_dir: true
output_dir: data/Qwen-1.5B-GRPO
train_size: 1000
eval_size: 100
num_train_epochs: 1
max_steps: -1
push_to_hub: true
report_to: ['wandb']
#do_eval: true
#eval_strategy: steps
#eval_steps: 100
overwrite_output_dir: true
per_device_train_batch_size: 8
report_to:
- wandb
save_strategy: steps
save_steps: 50
save_total_limit: 5
seed: 42
temperature: 0.6
warmup_ratio: 0.1