qwen math training code (#435)

* qwen math training code

* pre-commit
This commit is contained in:
Zafir Stojanovski 2025-05-16 13:19:19 +02:00 committed by GitHub
parent 47303211b3
commit 0cda6b1205
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
51 changed files with 155089 additions and 0 deletions

View file

@ -0,0 +1,54 @@
#!/bin/bash
#SBATCH --job-name=grpo_multinode
#SBATCH -D .
#SBATCH --partition=TODO
#SBATCH --account=TODO
#SBATCH --output=output-%x.%j
#SBATCH --error=error-%x.%j
#SBATCH --nodes=2 # number of nodes
#SBATCH --ntasks-per-node=1 # number of MP tasks
#SBATCH --gres=gpu:2 # number of GPUs per node
#SBATCH --cpus-per-task=8 # number of cores per tasks
#SBATCH --mem=128G
#SBATCH --time=48:00:00 # maximum execution time (HH:MM:SS)
#SBATCH --comment "Key=Monitoring,Value=ON"
#SBATCH --exclusive
######################
### Set environment ##
######################
ulimit -s unlimited
MAMBA_ENV="tina"
eval "$(mamba shell hook --shell bash)" && mamba activate "${MAMBA_ENV}"
echo "START TIME: $(date)"
echo "PYTHON ENV: $(which python)"
source "./scripts/set/set_vars.sh"
export GPUS_PER_NODE=2
######################
######################
#### Set network #####
######################
head_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
######################
export LAUNCHER="accelerate launch \
--num_processes $((SLURM_NNODES * GPUS_PER_NODE)) \
--num_machines $SLURM_NNODES \
--machine_rank $SLURM_NODEID \
--rdzv_backend c10d \
--main_process_ip $head_node_ip \
--main_process_port 29500 \
"
PY_SCRIPT="./tina/post_train_hf/grpo.py"
PY_CONFIG="./recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/model_curated_deepscaler.yaml"
# This step is necessary because accelerate launch does not handle multiline arguments properly
export CMD="$LAUNCHER $PY_SCRIPT --config $PY_CONFIG"
srun $CMD

View file

@ -0,0 +1,37 @@
#!/bin/bash
# use by running `bash sbatch_launch.sh <script.slurm>`
cleanup() {
echo "Script interrupted. Cleaning up..."
scancel "$job_id" 2>/dev/null
echo "Job $job_id has been canceled."
exit 1
}
trap cleanup SIGINT
# launch the slurm script
SLURM_FILE=$1
echo "Launching $SLURM_FILE ..."
job_id=$(sbatch $SLURM_FILE | awk '{print $4}')
echo "Submitted job with ID: $job_id"
# Wait until the job is running
while true; do
job_status=$(squeue -j "$job_id" -h -o "%T")
if [ "$job_status" == "RUNNING" ]; then
echo "Job $job_id is now running."
sleep 5
break
elif [ -z "$job_status" ]; then
echo "Job $job_id has finished or failed before reaching running state."
exit 1
else
echo "Job $job_id is still in $job_status state. Checking again in 10 seconds..."
sleep 10
fi
done
# Plot the real-time output
output_file=$(scontrol show job "$job_id" | awk -F= '/StdOut/ {print $2}' | sed "s/%A/${job_id}/g" | sed "s/%a/1/g")
echo "Tailing output file: $output_file"
tail -f "$output_file"