mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-05-02 17:45:58 +00:00
tutorial(training): Add a minimal example with trl (#473)
* v0 * 2 gpu setup * improve parsing from yaml * update yaml dataset example * remove restriction on flash attn * more comments * first version of the readme * pin torch * simplify requirements * just flash attn * use set env instead * simpler set env * readme * add wandb project to setup * update template * update model id * post init to capture the config and weight * extract metadata * update config * update dataset config * move env for wandb project * pre-commit * remove qwen-math from training * more instructions * unused import * remove trl old * warmup ratio * warmup ratio * change model id * change model_id * add info about CUDA_VISIBLE_DEVICES
This commit is contained in:
parent
49f3821098
commit
56ce2e79a7
59 changed files with 382 additions and 155340 deletions
26
examples/trl/train.sh
Executable file
26
examples/trl/train.sh
Executable file
|
|
@ -0,0 +1,26 @@
|
|||
#!/bin/bash
|
||||
|
||||
export CUDA_VISIBLE_DEVICES=0,1
|
||||
GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
|
||||
NUM_PROCESSES_TRAINING=$((GPU_COUNT - 1))
|
||||
|
||||
echo ""
|
||||
echo "Number of GPUs: ${GPU_COUNT}"
|
||||
echo "Number of processes for training: ${NUM_PROCESSES_TRAINING}"
|
||||
echo ""
|
||||
|
||||
PY_SCRIPT="./grpo.py"
|
||||
PY_CONFIG="./config/grpo.yaml"
|
||||
ACCELERATE_DS_CONFIG="./config/ds_zero2.yaml"
|
||||
|
||||
echo "START TIME: $(date)"
|
||||
|
||||
export WANDB_PROJECT="reasoning-gym-trl"
|
||||
|
||||
accelerate launch \
|
||||
--config_file "${ACCELERATE_DS_CONFIG}" \
|
||||
--main_process_port=29500 \
|
||||
--num_processes="${NUM_PROCESSES_TRAINING}" "${PY_SCRIPT}" --config "${PY_CONFIG}"
|
||||
|
||||
echo "END TIME: $(date)"
|
||||
echo "DONE"
|
||||
Loading…
Add table
Add a link
Reference in a new issue