reasoning-gym/examples/trl/train.sh
Zafir Stojanovski 56ce2e79a7
tutorial(training): Add a minimal example with trl (#473)
* v0

* 2 gpu setup

* improve parsing from yaml

* update yaml dataset example

* remove restriction on flash attn

* more comments

* first version of the readme

* pin torch

* simplify requirements

* just flash attn

* use set env instead

* simpler set env

* readme

* add wandb project to setup

* update template

* update model id

* post init to capture the config and weight

* extract metadata

* update config

* update dataset config

* move env for wandb project

* pre-commit

* remove qwen-math from training

* more instructions

* unused import

* remove trl old

* warmup ratio

* warmup ratio

* change model id

* change model_id

* add info about CUDA_VISIBLE_DEVICES
2025-06-21 00:01:31 +02:00


#!/bin/bash
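# Restrict the run to the first two GPUs; edit this list to match your machine.
export CUDA_VISIBLE_DEVICES=0,1

# Count the GPUs that PyTorch can see under the mask above.
GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")

# Use one training process per visible GPU, keeping one GPU out of training.
NUM_PROCESSES_TRAINING=$((GPU_COUNT - 1))

# With fewer than two visible GPUs no training process would be left, so stop early.
if [ "${NUM_PROCESSES_TRAINING}" -lt 1 ]; then
    echo "Expected at least 2 visible GPUs, found ${GPU_COUNT:-0}." >&2
    exit 1
fi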
echo ""
echo "Number of GPUs: ${GPU_COUNT}"
echo "Number of processes for training: ${NUM_PROCESSES_TRAINING}"
echo ""
# These paths are relative, so run this script from the directory that contains it.
PY_SCRIPT="./grpo.py"
PY_CONFIG="./config/grpo.yaml"
ACCELERATE_DS_CONFIG="./config/ds_zero2.yaml"
echo "START TIME: $(date)"
# Weights & Biases project that the run is logged under.
export WANDB_PROJECT="reasoning-gym-trl"
accelerate launch \
    --config_file "${ACCELERATE_DS_CONFIG}" \
    --main_process_port=29500 \
    --num_processes="${NUM_PROCESSES_TRAINING}" \
    "${PY_SCRIPT}" --config "${PY_CONFIG}"
echo "END TIME: $(date)"
echo "DONE"