mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
* v0 * 2 gpu setup * improve parsing from yaml * update yaml dataset example * remove restriction on flash attn * more comments * first version of the readme * pin torch * simplify requirements * just flash attn * use set env instead * simpler set env * readme * add wandb project to setup * update template * update model id * post init to capture the config and weight * extract metadata * update config * update dataset config * move env for wandb project * pre-commit * remove qwen-math from training * more instructions * unused import * remove trl old * warmup ratio * warmup ratio * change model id * change model_id * add info about CUDA_VISIBLE_DEVICES
26 lines
673 B
Bash
Executable file
26 lines
673 B
Bash
Executable file
#!/usr/bin/env bash
#
# Launch GRPO training (./grpo.py) with Hugging Face Accelerate using the
# DeepSpeed config ./config/ds_zero2.yaml.
#
# One visible GPU is deliberately kept out of the training process group:
# --num_processes is set to (visible GPU count - 1). Presumably grpo.py uses
# the spare GPU for generation (e.g. vLLM) — confirm against
# ./config/grpo.yaml.
#
# Environment (all optional):
#   CUDA_VISIBLE_DEVICES  GPUs to use; defaults to "0,1". At least two GPUs
#                         must be visible.
#   WANDB_PROJECT         Weights & Biases project name; defaults to
#                         "reasoning-gym-trl".
#
# All paths are relative to the current working directory, so run this script
# from the directory that contains grpo.py.

set -euo pipefail

# Print an error message to stderr and abort.
die() {
  printf 'ERROR: %s\n' "$*" >&2
  exit 1
}

# Respect a GPU selection already made by the caller (e.g. a job scheduler);
# otherwise fall back to the first two GPUs, as before.
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1}"

# Count GPUs as torch sees them, i.e. after CUDA_VISIBLE_DEVICES is applied.
GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())") \
  || die "could not query the GPU count via torch"

# An empty/non-numeric result or fewer than two GPUs would make the training
# process count below 0 or 1, which accelerate cannot use.
if [[ ! "${GPU_COUNT}" =~ ^[0-9]+$ ]] || (( GPU_COUNT < 2 )); then
  die "need at least 2 visible GPUs (CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}), torch reports '${GPU_COUNT}'"
fi

NUM_PROCESSES_TRAINING=$((GPU_COUNT - 1))

echo ""
echo "Number of GPUs: ${GPU_COUNT}"
echo "Number of processes for training: ${NUM_PROCESSES_TRAINING}"
echo ""

readonly PY_SCRIPT="./grpo.py"
readonly PY_CONFIG="./config/grpo.yaml"
readonly ACCELERATE_DS_CONFIG="./config/ds_zero2.yaml"

# Fail fast (before the expensive launch) if we are in the wrong directory.
for required_file in "${PY_SCRIPT}" "${PY_CONFIG}" "${ACCELERATE_DS_CONFIG}"; do
  [[ -f "${required_file}" ]] \
    || die "missing ${required_file} (run this script from the directory containing grpo.py)"
done

command -v accelerate >/dev/null || die "accelerate not found on PATH"

echo "START TIME: $(date)"

export WANDB_PROJECT="${WANDB_PROJECT:-reasoning-gym-trl}"

# Capture the exit status rather than letting `set -e` abort immediately, so
# END TIME is always printed; a failure is then reported and propagated.
status=0
accelerate launch \
  --config_file "${ACCELERATE_DS_CONFIG}" \
  --main_process_port=29500 \
  --num_processes="${NUM_PROCESSES_TRAINING}" \
  "${PY_SCRIPT}" --config "${PY_CONFIG}" \
  || status=$?

echo "END TIME: $(date)"

if (( status != 0 )); then
  printf 'ERROR: accelerate launch exited with status %d\n' "${status}" >&2
  exit "${status}"
fi

echo "DONE"