qwen math training code (#435)

* qwen math training code

* pre-commit
This commit is contained in:
Zafir Stojanovski 2025-05-16 13:19:19 +02:00 committed by GitHub
parent 47303211b3
commit 0cda6b1205
51 changed files with 155089 additions and 0 deletions

View file

@@ -0,0 +1,20 @@
#!/bin/bash
MAMBA_ENV="tina"
eval "$(mamba shell hook --shell bash)" && mamba activate "${MAMBA_ENV}"
echo "START TIME: $(date)"
echo "PYTHON ENV: $(which python)"
source "./scripts/set/set_vars.sh"
PY_SCRIPT="./scripts/set/run_download_model.py"
echo ""
echo "Running script: ${PY_SCRIPT}"
echo ""
python "${PY_SCRIPT}"
echo "END TIME: $(date)"
echo "DONE"

View file

@@ -0,0 +1,9 @@
import os
from huggingface_hub import snapshot_download
if __name__ == "__main__":
    CKPT_DIR = os.environ["CKPT_DIR"]  # exported by scripts/set/set_vars.sh
    print("Downloading model ...")
    snapshot_download(repo_id="Qwen/Qwen2.5-3B-Instruct", local_dir=f"{CKPT_DIR}/models/Qwen2.5-3B-Instruct/base")

View file

@@ -0,0 +1,39 @@
#!/bin/bash
# python 3.10 + cuda 11.8.0
# the execution order of the following commands matters
export MKL_NUM_THREADS=1
export NUMEXPR_NUM_THREADS=1
export OPENBLAS_NUM_THREADS=1
export OMP_NUM_THREADS=1
conda clean -a -y
mamba clean -a -y
pip install --upgrade pip
pip cache purge
# cuda, gcc/g++, torch
# mamba install cuda -c nvidia/label/cuda-11.8.0 -y
# mamba install gcc gxx -c conda-forge -y
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu118
# xformers
pip install xformers==0.0.28.post3 --index-url https://download.pytorch.org/whl/cu118
# vLLM pre-compiled with CUDA 11.8
pip install https://github.com/vllm-project/vllm/releases/download/v0.7.2/vllm-0.7.2+cu118-cp38-abi3-manylinux1_x86_64.whl
pip install deepspeed
pip install flash-attn==2.7.3 --no-build-isolation
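# --no-build-isolation above makes pip compile flash-attn against the torch already installed in this environment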
pip install peft
pip install "trl==0.15.2"
pip install latex2sympy2_extended
pip install "math_verify==0.5.2"
pip install word2number
pip install scipy
pip install wandb
pip install plotly
pip install matplotlib
pip install seaborn

View file

@@ -0,0 +1,46 @@
#!/bin/bash
# python 3.11 & cuda 11.8
export MKL_NUM_THREADS=1
export NUMEXPR_NUM_THREADS=1
export OPENBLAS_NUM_THREADS=1
export OMP_NUM_THREADS=1
conda clean -a -y
mamba clean -a -y
pip install --upgrade pip
pip cache purge
# mamba install cuda -c nvidia/label/cuda-11.8.0 -y
# mamba install gcc gxx -c conda-forge -y
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install xformers --index-url https://download.pytorch.org/whl/cu118
pip install vllm
pip install flash-attn --no-build-isolation
pip install accelerate
pip install datasets
pip install deepspeed
pip install distilabel[vllm,ray,openai]
pip install e2b-code-interpreter
pip install einops
pip install flake8
pip install huggingface_hub
pip install hf_transfer
pip install isort
pip install langdetect
pip install latex2sympy2_extended
pip install liger_kernel
pip install "math_verify==0.5.2"
pip install packaging
pip install parameterized
pip install peft
pip install pytest
pip install python-dotenv
pip install ruff
pip install safetensors
pip install sentencepiece
pip install transformers
pip install trl@git+https://github.com/huggingface/trl.git
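# note: trl above is installed from git main and unpinned; pin a specific commit here for reproducible environments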
pip install wandb

View file

@@ -0,0 +1,53 @@
#!/bin/bash
export CUDA_LAUNCH_BLOCKING=1
export DS_LOG_LEVEL=error
export TOKENIZERS_PARALLELISM=false
export NCCL_P2P_DISABLE=1
export NCCL_SHM_DISABLE=1
export NCCL_IB_DISABLE=1
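# NCCL P2P/SHM/IB are disabled above for stability (e.g. on hosts without P2P or InfiniBand support); re-enable them for multi-node performance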
export MKL_THREADING_LAYER=GNU
export MKL_NUM_THREADS=1
export NUMEXPR_NUM_THREADS=1
export OPENBLAS_NUM_THREADS=1
export OMP_NUM_THREADS=1
## basic setup for the env
export PROJECT_PREFIX="/root/projects" # e.g. /home/username/projects
export SCRATCH_PREFIX="/root/scratch" # e.g. /home/username/scratch
mkdir -p "${PROJECT_PREFIX}" "${SCRATCH_PREFIX}"
export PROJECT_NAME="rg-math"
export CORE_POSTFIX="tina"
export PROJECT_DIR="${PROJECT_PREFIX}/${PROJECT_NAME}"
export PYTHONPATH="${PROJECT_DIR}":$PYTHONPATH
export PYTHONPATH="${PROJECT_DIR}/${CORE_POSTFIX}":$PYTHONPATH
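# both the project root and the tina package are now importable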
mkdir -p "${PROJECT_PREFIX}/${PROJECT_NAME}"
export CKPT_DIR="${PROJECT_DIR}/ckpts"
export DATA_DIR="${PROJECT_DIR}/datasets"
export OUTPUT_DIR="${PROJECT_DIR}/outputs"
export LOGGING_DIR="${PROJECT_DIR}/logs"
mkdir -p "${CKPT_DIR}" "${DATA_DIR}" "${OUTPUT_DIR}" "${LOGGING_DIR}"
## wandb setup
# export WANDB_API_KEY="TODO"
export WANDB_PROJECT="${PROJECT_NAME}"
export WANDB_DIR="${OUTPUT_DIR}"
wandb login $WANDB_API_KEY
export CACHE_DIR="${PROJECT_DIR}/.cache"
export WANDB_CACHE_DIR="${CACHE_DIR}"
export TRITON_CACHE_DIR="${CACHE_DIR}/triton_cache"
## huggingface setup
# export HF_TOKEN="TODO"
git config --global credential.helper store
huggingface-cli login --token $HF_TOKEN --add-to-git-credential
export HF_HOME="${CACHE_DIR}/huggingface"
export HUGGINGFACE_HUB_CACHE="${HF_HOME}/hub"
export HF_DATASETS_CACHE="${HF_HOME}/datasets"

View file

@@ -0,0 +1,54 @@
#!/bin/bash
#SBATCH --job-name=grpo_multinode
#SBATCH -D .
#SBATCH --partition=TODO
#SBATCH --account=TODO
#SBATCH --output=output-%x.%j
#SBATCH --error=error-%x.%j
#SBATCH --nodes=2 # number of nodes
#SBATCH --ntasks-per-node=1 # number of MP tasks
#SBATCH --gres=gpu:2 # number of GPUs per node
#SBATCH --cpus-per-task=8 # number of cores per task
#SBATCH --mem=128G
#SBATCH --time=48:00:00 # maximum execution time (HH:MM:SS)
#SBATCH --comment "Key=Monitoring,Value=ON"
#SBATCH --exclusive
######################
### Set environment ##
######################
ulimit -s unlimited
MAMBA_ENV="tina"
eval "$(mamba shell hook --shell bash)" && mamba activate "${MAMBA_ENV}"
echo "START TIME: $(date)"
echo "PYTHON ENV: $(which python)"
source "./scripts/set/set_vars.sh"
export GPUS_PER_NODE=2
######################
######################
#### Set network #####
######################
head_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
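# the first hostname in the Slurm allocation acts as the c10d rendezvous host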
######################
export LAUNCHER="accelerate launch \
--num_processes $((SLURM_NNODES * GPUS_PER_NODE)) \
--num_machines $SLURM_NNODES \
--machine_rank $SLURM_NODEID \
--rdzv_backend c10d \
--main_process_ip $head_node_ip \
--main_process_port 29500 \
"
PY_SCRIPT="./tina/post_train_hf/grpo.py"
PY_CONFIG="./recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/model_curated_deepscaler.yaml"
# This step is necessary because accelerate launch does not handle multiline arguments properly
export CMD="$LAUNCHER $PY_SCRIPT --config $PY_CONFIG"
srun $CMD

View file

@@ -0,0 +1,37 @@
#!/bin/bash
# usage: bash sbatch_launch.sh <script.slurm>
cleanup() {
echo "Script interrupted. Cleaning up..."
scancel "$job_id" 2>/dev/null
echo "Job $job_id has been canceled."
exit 1
}
trap cleanup SIGINT
# launch the slurm script
SLURM_FILE=$1
echo "Launching $SLURM_FILE ..."
job_id=$(sbatch $SLURM_FILE | awk '{print $4}')
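# sbatch prints "Submitted batch job <id>"; the fourth field is the job ID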
echo "Submitted job with ID: $job_id"
# Wait until the job is running
while true; do
job_status=$(squeue -j "$job_id" -h -o "%T")
if [ "$job_status" == "RUNNING" ]; then
echo "Job $job_id is now running."
sleep 5
break
elif [ -z "$job_status" ]; then
echo "Job $job_id has finished or failed before reaching running state."
exit 1
else
echo "Job $job_id is still in $job_status state. Checking again in 10 seconds..."
sleep 10
fi
done
# Tail the real-time output
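# scontrol reports the StdOut path; the sed calls substitute the %A/%a filename placeholders with concrete values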
output_file=$(scontrol show job "$job_id" | awk -F= '/StdOut/ {print $2}' | sed "s/%A/${job_id}/g" | sed "s/%a/1/g")
echo "Tailing output file: $output_file"
tail -f "$output_file"

View file

@@ -0,0 +1,32 @@
#!/bin/bash
MAMBA_ENV="tina_eval"
eval "$(mamba shell hook --shell bash)" && mamba activate "${MAMBA_ENV}"
echo "START TIME: $(date)"
echo "PYTHON ENV: $(which python)"
source "./scripts/set/set_vars.sh"
export CUDA_VISIBLE_DEVICES="0" # NOTE: update this if you have more than 1 GPU
GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
echo ""
echo "GPU_COUNT: $GPU_COUNT"
echo ""
# MODEL_LIST=("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" "starzmustdie/DeepSeek-R1-Distill-Qwen-1.5B-rg-math" "agentica-org/DeepScaleR-1.5B-Preview" "knoveleng/Open-RS3" "RUC-AIBOX/STILL-3-1.5B-preview")
MODEL_LIST=("Qwen/Qwen2.5-3B-Instruct" "starzmustdie/Qwen2.5-3B-Instruct")
TASKS=("aime24" "aime25" "amc23" "minerva" "math_500")
for MODEL_NAME in "${MODEL_LIST[@]}"; do
for TASK in "${TASKS[@]}"; do
MODEL_ARGS="model_name=$MODEL_NAME,dtype=bfloat16,data_parallel_size=$GPU_COUNT,max_model_length=32768,gpu_memory_utilization=0.7,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
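# data_parallel_size launches one vLLM replica per visible GPU for data-parallel evaluation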
lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
--custom-tasks ./scripts/training/run_post_train_eval.py \
--use-chat-template \
--output-dir "${OUTPUT_DIR}/${TASK}"
done
done
echo "END TIME: $(date)"
echo "DONE"

View file

@@ -0,0 +1,73 @@
#!/bin/bash
MAMBA_ENV="tina_eval"
eval "$(mamba shell hook --shell bash)" && mamba activate "${MAMBA_ENV}"
echo "START TIME: $(date)"
echo "PYTHON ENV: $(which python)"
source "./scripts/set/set_vars.sh"
export CUDA_VISIBLE_DEVICES=0,1 # make sure all evaluations run on 2 GPUs
GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
echo ""
echo "GPU_COUNT: $GPU_COUNT, make sure using 2 GPUs."
echo ""
MODEL_NAME="DeepSeek-R1-Distill-Qwen-1.5B"
PT_TYPE="grpo"
## Main datasets
DATASET_NAME="curated_deepscaler"
#DATASET_NAME="curated_still"
#DATASET_NAME="curated_open_rs3"
#DATASET_NAME="curated_open_rs2"
#DATASET_NAME="curated_open_rs1"
## Extra datasets
#DATASET_NAME="curated_limr"
#DATASET_NAME="curated_open_r1"
#DATASET_NAME="curated_thoughts"
## Ablation
#DATASET_NAME="curated_limr_large_lr_ablation"
#DATASET_NAME="curated_limr_small_lr_ablation"
#DATASET_NAME="curated_limr_large_rank_ablation"
#DATASET_NAME="curated_limr_medium_rank_ablation"
#DATASET_NAME="curated_limr_small_rank_ablation"
#DATASET_NAME="curated_limr_tiny_rank_ablation"
#DATASET_NAME="curated_open_rs3_drgrpo_ablation"
# collect all saved checkpoints as a bash array so the loop below iterates once per checkpoint
CKPT_LIST=($(ls "${CKPT_DIR}/models/${MODEL_NAME}/${PT_TYPE}_${DATASET_NAME}" | grep -E "^checkpoint-[0-9]+$"))
#CKPT_LIST=("checkpoint-XXX")
# loop over all the checkpoints in the list
for CKPT in "${CKPT_LIST[@]}"; do
echo "Running model post train merging base and adapter for checkpoint: ${CKPT}"
python ./scripts/training/run_post_train_merge.py \
--model_name "${MODEL_NAME}" \
--adapter_type "${PT_TYPE}_${DATASET_NAME}" \
--ckpt "${CKPT}"
MODEL_PATH="${CKPT_DIR}/models/${MODEL_NAME}/${PT_TYPE}_${DATASET_NAME}/${CKPT}-merged"
# Set model arguments (ensure that MODEL_PATH, GPU_COUNT, OUTPUT_DIR, and MODEL_NAME are defined)
MODEL_ARGS="pretrained=$MODEL_PATH,dtype=bfloat16,data_parallel_size=$GPU_COUNT,max_model_length=32768,gpu_memory_utilization=0.5,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
# Define an array of tasks to evaluate
tasks=("aime24" "math_500" "gpqa:diamond" "aime25" "amc23" "minerva")
# Loop over each task and evaluate
for TASK in "${tasks[@]}"; do
echo "Evaluating task: $TASK"
lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
--custom-tasks ./scripts/training/run_post_train_eval.py \
--use-chat-template \
--output-dir "${OUTPUT_DIR}/${MODEL_NAME}/${TASK}"
done
done
echo "END TIME: $(date)"
echo "DONE"

View file

@ -0,0 +1,68 @@
#!/bin/bash
MAMBA_ENV="tina"
eval "$(mamba shell hook --shell bash)" && mamba activate "${MAMBA_ENV}"
echo "START TIME: $(date)"
echo "PYTHON ENV: $(which python)"
source "./scripts/set/set_vars.sh"
export CUDA_VISIBLE_DEVICES=0,1,2 # Set the GPUs you want to use
GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
NUM_PROCESSES_TRAINING=$((GPU_COUNT - 1))
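# one GPU is held out of training, presumably reserved for vLLM rollout generation during GRPO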
echo ""
echo "Number of GPUs: ${GPU_COUNT}"
echo "Number of processes for training: ${NUM_PROCESSES_TRAINING}"
echo ""
MODEL_NAME="Qwen2.5-3B-Instruct"
## Main datasets
#DATASET_NAME="curated_deepscaler"
#DATASET_NAME="curated_still"
#DATASET_NAME="curated_open_rs3"
#DATASET_NAME="curated_open_rs2"
#DATASET_NAME="curated_open_rs1"
## Extra datasets
#DATASET_NAME="curated_limr"
#DATASET_NAME="curated_open_r1"
#DATASET_NAME="curated_thoughts"
## Ablation
#DATASET_NAME="curated_limr_large_lr_ablation"
#DATASET_NAME="curated_limr_small_lr_ablation"
#DATASET_NAME="curated_limr_large_rank_ablation"
#DATASET_NAME="curated_limr_medium_rank_ablation"
#DATASET_NAME="curated_limr_small_rank_ablation"
#DATASET_NAME="curated_limr_tiny_rank_ablation"
#DATASET_NAME="curated_open_rs3_drgrpo_ablation"
## Reasoning Gym
DATASET_NAME="curated_rg_math"
PY_SCRIPT="./tina/post_train_hf/grpo.py"
PY_CONFIG="./recipes/${MODEL_NAME}/grpo/model_${DATASET_NAME}.yaml"
ACCELERATE_DS_CONFIG="./recipes/accelerate_ds_cfgs/ds_zero2.yaml"
echo ""
echo "Running ${PY_SCRIPT} on model ${MODEL_NAME} with dataset ${DATASET_NAME}"
echo ""
if [[ "${DATASET_NAME}" == "curated_thoughts" || "${DATASET_NAME}" == "curated_open_r1" || "${DATASET_NAME}" == "curated_open_rs3" || "${DATASET_NAME}" == "curated_open_rs3_drgrpo_ablation" ]]; then
ACCELERATE_LOG_LEVEL=info accelerate launch \
--config_file "${ACCELERATE_DS_CONFIG}" \
--main_process_port=29500 \
--num_processes="${NUM_PROCESSES_TRAINING}" "${PY_SCRIPT}" --config "${PY_CONFIG}" --cosine_max_len 3584
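# (--cosine_max_len above presumably sets the max length used by the cosine length-based reward for these datasets)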
else
ACCELERATE_LOG_LEVEL=info accelerate launch \
--config_file "${ACCELERATE_DS_CONFIG}" \
--main_process_port=29500 \
--num_processes="${NUM_PROCESSES_TRAINING}" "${PY_SCRIPT}" --config "${PY_CONFIG}"
fi
echo "END TIME: $(date)"
echo "DONE"

View file

@@ -0,0 +1,229 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom evaluation tasks for LightEval."""
import numpy as np
from lighteval.metrics.dynamic_metrics import (
    ExprExtractionConfig,
    LatexExtractionConfig,
    compare_gold_target,
    extract_target_from_pred,
    get_extraction_regexes,
)
from lighteval.metrics.metrics import Metrics
from lighteval.metrics.metrics_sample import PassAtK
from lighteval.metrics.utils.metric_utils import MetricCategory, MetricUseCase, SampleLevelMetric
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
from lighteval.utils.language import Language
# Prompt template adapted from
# - simple-evals: https://github.com/openai/simple-evals/blob/6e84f4e2aed6b60f6a0c7b8f06bbbf4bfde72e58/math_eval.py#L17
# - Llama 3: https://huggingface.co/datasets/meta-llama/Llama-3.2-1B-Instruct-evals/viewer/Llama-3.2-1B-Instruct-evals__math__details?views%5B%5D=llama_32_1b_instruct_evals__math__details
# Note that it is important to have the final answer in a box for math-verify to work correctly
MATH_QUERY_TEMPLATE = """
Solve the following math problem efficiently and clearly. The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering.
{Question}
""".strip()
math_pass_at_1_2n = SampleLevelMetric(
    metric_name="math_pass@1:2_samples",
    sample_level_fn=PassAtK(
        k=1,
        n=2,
        strip_strings=True,
        # Extracting mathematical expressions and latex expressions
        normalize_gold=lambda k: extract_target_from_pred(
            k,
            get_extraction_regexes(
                formatted_doc=None,
                target_types=[ExprExtractionConfig(), LatexExtractionConfig()],
                language=Language.ENGLISH,
            ),
        ),
        # Extracting mathematical expressions and latex expressions
        normalize_pred=lambda k: extract_target_from_pred(
            k,
            get_extraction_regexes(
                formatted_doc=None,
                target_types=[ExprExtractionConfig(), LatexExtractionConfig()],
                language=Language.ENGLISH,
            ),
        ),
        # Uses sympy for comparison
        sample_scoring_function=compare_gold_target,
    ).compute,
    category=MetricCategory.GENERATIVE_SAMPLING,
    use_case=MetricUseCase.REASONING,
    corpus_level_fn=np.mean,
    higher_is_better=True,
)
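# pass@1 here is estimated from n=2 samples per problem; the AIME/AMC tasks below use
# Metrics.math_pass_at_1_32n (32 samples) and minerva uses the 4-sample variant.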
def math_prompt_fn(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]),
        choices=[line["solution"]],
        gold_index=0,
    )
def aime_prompt_fn(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]),
        choices=[line["answer"]],
        gold_index=0,
    )
def amc_prompt_fn(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]),
        choices=[line["answer"]],
        gold_index=0,
    )
def minerva_prompt_fn(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]),
        choices=[line["solution"]],
        gold_index=0,
    )
# Define tasks
aime24 = LightevalTaskConfig(
    name="aime24",
    suite=["custom"],
    prompt_function=aime_prompt_fn,
    hf_repo="HuggingFaceH4/aime_2024",
    hf_subset="default",
    hf_avail_splits=["train"],
    evaluation_splits=["train"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,
    metric=[
        # Metrics.math_pass_at_1_1n,
        # math_pass_at_1_2n,
        # Metrics.math_pass_at_1_4n,
        # Metrics.math_pass_at_1_16n,
        Metrics.math_pass_at_1_32n,
        # Metrics.math_pass_at_1_64n,
    ],
    version=1,
)
aime25 = LightevalTaskConfig(
    name="aime25",
    suite=["custom"],
    prompt_function=aime_prompt_fn,
    hf_repo="yentinglin/aime_2025",
    hf_subset="default",
    hf_avail_splits=["train"],
    evaluation_splits=["train"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,
    metric=[
        # Metrics.math_pass_at_1_1n,
        # math_pass_at_1_2n,
        # Metrics.math_pass_at_1_4n,
        # Metrics.math_pass_at_1_16n,
        Metrics.math_pass_at_1_32n,
        # Metrics.math_pass_at_1_64n,
    ],
    version=1,
)
amc23 = LightevalTaskConfig(
    name="amc23",
    suite=["custom"],
    prompt_function=amc_prompt_fn,
    hf_repo="knoveleng/AMC-23",
    hf_subset="default",
    hf_avail_splits=["train"],
    evaluation_splits=["train"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,
    metric=[
        # Metrics.math_pass_at_1_1n,
        # math_pass_at_1_2n,
        # Metrics.math_pass_at_1_4n,
        # Metrics.math_pass_at_1_16n,
        Metrics.math_pass_at_1_32n,
        # Metrics.math_pass_at_1_64n,
    ],
    version=1,
)
math_500 = LightevalTaskConfig(
    name="math_500",
    suite=["custom"],
    prompt_function=math_prompt_fn,
    hf_repo="HuggingFaceH4/MATH-500",
    hf_subset="default",
    hf_avail_splits=["test"],
    evaluation_splits=["test"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,
    metric=[
        # Metrics.math_pass_at_1_1n,
        math_pass_at_1_2n,
    ],
    version=1,
)
minerva = LightevalTaskConfig(
    name="minerva",
    suite=["custom"],
    prompt_function=minerva_prompt_fn,
    hf_repo="knoveleng/Minerva-Math",
    hf_subset="default",
    hf_avail_splits=["train"],
    evaluation_splits=["train"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,
    metric=[
        # Metrics.math_pass_at_1_1n,
        # math_pass_at_1_2n,
        Metrics.math_pass_at_1_4n,
    ],
    version=1,
)
# Add tasks to the table
TASKS_TABLE = []
TASKS_TABLE.append(aime24)
TASKS_TABLE.append(aime25)
TASKS_TABLE.append(amc23)
TASKS_TABLE.append(math_500)
TASKS_TABLE.append(minerva)
# MODULE LOGIC
if __name__ == "__main__":
    print([t.name for t in TASKS_TABLE])  # TASKS_TABLE holds LightevalTaskConfig objects, so use attribute access
    print(len(TASKS_TABLE))

View file

@@ -0,0 +1,40 @@
import argparse
import os
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
def argparser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="DeepSeek-R1-Distill-Qwen-1.5B")
    parser.add_argument("--adapter_type", type=str, default="grpo_curated_open_r1")
    parser.add_argument("--ckpt", type=str, default="checkpoint-500")
    return parser.parse_args()
if __name__ == "__main__":
    args = argparser()
    ckpt_dir = os.environ["CKPT_DIR"]
    ckpt = args.ckpt
    adapter_type = args.adapter_type
    model_name = args.model_name
    base_model_name_or_path = f"{ckpt_dir}/models/{model_name}/base"
    adapter_model_name_or_path = f"{ckpt_dir}/models/{model_name}/{adapter_type}/{ckpt}"
    merged_model_name_or_path = f"{ckpt_dir}/models/{model_name}/{adapter_type}/{ckpt}-merged"
    print("Merged model will be saved to: ", merged_model_name_or_path)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name_or_path, torch_dtype=torch.bfloat16, device_map="auto"
    )  # Automatically distributes across available GPUs
    model = PeftModel.from_pretrained(base_model, adapter_model_name_or_path)
    model = model.merge_and_unload()  # fold the LoRA adapter weights into the base model
    model.save_pretrained(merged_model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)
    tokenizer.save_pretrained(merged_model_name_or_path)