mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
testing set up
This commit is contained in:
parent
f44eb810bf
commit
530fed2877
8 changed files with 599 additions and 2 deletions
|
|
@ -163,6 +163,23 @@ def add_grpo_args(parser: argparse.ArgumentParser) -> None:
|
|||
default=0.2,
|
||||
help="PPO-style clipping epsilon. Clips ratio to [1-eps, 1+eps].",
|
||||
)
|
||||
group.add_argument(
|
||||
"--distill-enabled",
|
||||
action="store_true",
|
||||
help="Enable teacher distillation loss (requires distill payload in Atropos batch).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--distill-coef",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Coefficient for distillation loss term.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--distill-temperature",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Temperature for teacher top-k distribution in distillation loss.",
|
||||
)
|
||||
|
||||
|
||||
def add_vllm_args(parser: argparse.ArgumentParser) -> None:
|
||||
|
|
@ -424,6 +441,9 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
|
|||
checkpoint_interval=getattr(args, "checkpoint_interval", 3),
|
||||
# GRPO/PPO hyperparameters
|
||||
clip_eps=getattr(args, "clip_eps", 0.2),
|
||||
distill_enabled=getattr(args, "distill_enabled", False),
|
||||
distill_coef=getattr(args, "distill_coef", 0.0),
|
||||
distill_temperature=getattr(args, "distill_temperature", 1.0),
|
||||
adafactor_scale_parameter=getattr(args, "adafactor_scale_parameter", False),
|
||||
adafactor_relative_step=getattr(args, "adafactor_relative_step", False),
|
||||
# vLLM settings
|
||||
|
|
|
|||
|
|
@ -69,6 +69,18 @@ class TrainingConfig(BaseModel):
|
|||
"Prevents large policy updates that could destabilize training."
|
||||
),
|
||||
)
|
||||
distill_enabled: bool = Field(
|
||||
False,
|
||||
description="Enable teacher distillation loss when distill tensors are present.",
|
||||
)
|
||||
distill_coef: float = Field(
|
||||
0.0,
|
||||
description="Weight for distillation loss in total loss.",
|
||||
)
|
||||
distill_temperature: float = Field(
|
||||
1.0,
|
||||
description="Temperature applied when converting teacher top-k logprobs.",
|
||||
)
|
||||
# === Device & Storage ===
|
||||
device: str = Field(
|
||||
"cuda" if torch.cuda.is_available() else "cpu", description="Device to train on"
|
||||
|
|
|
|||
|
|
@ -29,6 +29,8 @@ def pad_data_to_good_offset(
|
|||
List[torch.Tensor], # advantage_batches
|
||||
List[torch.Tensor], # temperature_batches
|
||||
Optional[List[torch.Tensor]], # inference_logprob_batches (aligned with labels)
|
||||
Optional[List[torch.Tensor]], # distill_token_id_batches [batch, seq, k]
|
||||
Optional[List[torch.Tensor]], # distill_logprob_batches [batch, seq, k]
|
||||
]:
|
||||
"""
|
||||
Pad and batch data from the Atropos API.
|
||||
|
|
@ -45,7 +47,8 @@ def pad_data_to_good_offset(
|
|||
extract_inference_logprobs: Whether to extract inference logprobs
|
||||
|
||||
Returns:
|
||||
Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches)
|
||||
Tuple of (token_batches, label_batches, advantage_batches, temperature_batches,
|
||||
inference_logprob_batches, distill_token_id_batches, distill_logprob_batches)
|
||||
inference_logprob_batches is None if extract_inference_logprobs=False or no logprobs in data
|
||||
|
||||
Note:
|
||||
|
|
@ -73,6 +76,10 @@ def pad_data_to_good_offset(
|
|||
temperatures = []
|
||||
inference_logprobs_padded: List[np.ndarray] = [] # Padded to match labels shape
|
||||
has_any_logprobs = False
|
||||
distill_token_ids_padded: List[np.ndarray] = []
|
||||
distill_logprobs_padded: List[np.ndarray] = []
|
||||
has_any_distill = False
|
||||
max_distill_k = 1
|
||||
|
||||
for item in data["batch"]:
|
||||
# Normalize advantage scores
|
||||
|
|
@ -153,6 +160,77 @@ def pad_data_to_good_offset(
|
|||
np.full(token_setup_len - 1, 1.0, dtype=np.float32)
|
||||
)
|
||||
|
||||
# Extract teacher distillation top-k arrays if available.
|
||||
# Expected shape in incoming payload: [sequence][position][k].
|
||||
if "distill_token_ids" in item and "distill_logprobs" in item:
|
||||
seq_token_ids = item["distill_token_ids"]
|
||||
seq_logprobs = item["distill_logprobs"]
|
||||
if (
|
||||
isinstance(seq_token_ids, list)
|
||||
and isinstance(seq_logprobs, list)
|
||||
and i < len(seq_token_ids)
|
||||
and i < len(seq_logprobs)
|
||||
and seq_token_ids[i] is not None
|
||||
and seq_logprobs[i] is not None
|
||||
):
|
||||
per_pos_token_ids = seq_token_ids[i]
|
||||
per_pos_logprobs = seq_logprobs[i]
|
||||
if (
|
||||
isinstance(per_pos_token_ids, list)
|
||||
and isinstance(per_pos_logprobs, list)
|
||||
and len(per_pos_token_ids) == len(per_pos_logprobs)
|
||||
):
|
||||
local_k = 1
|
||||
for row_ids in per_pos_token_ids:
|
||||
if isinstance(row_ids, list):
|
||||
local_k = max(local_k, len(row_ids))
|
||||
max_distill_k = max(max_distill_k, local_k)
|
||||
has_any_distill = True
|
||||
|
||||
rows = max(0, token_setup_len - 1)
|
||||
token_mat = np.full((rows, local_k), -1, dtype=np.int64)
|
||||
logprob_mat = np.full(
|
||||
(rows, local_k), -1e9, dtype=np.float32
|
||||
)
|
||||
|
||||
# Shift by one to align with causal labels like inference_logprobs.
|
||||
copy_positions = min(
|
||||
len(per_pos_token_ids), len(per_pos_logprobs), token_setup_len
|
||||
)
|
||||
for pos in range(1, copy_positions):
|
||||
src_ids = per_pos_token_ids[pos]
|
||||
src_lps = per_pos_logprobs[pos]
|
||||
if not isinstance(src_ids, list) or not isinstance(src_lps, list):
|
||||
continue
|
||||
topk = min(local_k, len(src_ids), len(src_lps))
|
||||
if topk <= 0:
|
||||
continue
|
||||
token_mat[pos - 1, :topk] = np.array(src_ids[:topk], dtype=np.int64)
|
||||
logprob_mat[pos - 1, :topk] = np.array(
|
||||
src_lps[:topk], dtype=np.float32
|
||||
)
|
||||
|
||||
distill_token_ids_padded.append(token_mat)
|
||||
distill_logprobs_padded.append(logprob_mat)
|
||||
else:
|
||||
rows = max(0, token_setup_len - 1)
|
||||
distill_token_ids_padded.append(
|
||||
np.full((rows, 1), -1, dtype=np.int64)
|
||||
)
|
||||
distill_logprobs_padded.append(
|
||||
np.full((rows, 1), -1e9, dtype=np.float32)
|
||||
)
|
||||
else:
|
||||
rows = max(0, token_setup_len - 1)
|
||||
distill_token_ids_padded.append(np.full((rows, 1), -1, dtype=np.int64))
|
||||
distill_logprobs_padded.append(
|
||||
np.full((rows, 1), -1e9, dtype=np.float32)
|
||||
)
|
||||
else:
|
||||
rows = max(0, token_setup_len - 1)
|
||||
distill_token_ids_padded.append(np.full((rows, 1), -1, dtype=np.int64))
|
||||
distill_logprobs_padded.append(np.full((rows, 1), -1e9, dtype=np.float32))
|
||||
|
||||
# Extract temperature (priority: override > generation_params > group_overrides > 1.0)
|
||||
t = 1.0
|
||||
if (
|
||||
|
|
@ -178,6 +256,8 @@ def pad_data_to_good_offset(
|
|||
advantage_batches = []
|
||||
temperature_batches = []
|
||||
inference_logprob_batches = []
|
||||
distill_token_id_batches = []
|
||||
distill_logprob_batches = []
|
||||
|
||||
for start in range(0, len(input_ids), batch_size):
|
||||
end = min(start + batch_size, len(input_ids))
|
||||
|
|
@ -199,12 +279,42 @@ def pad_data_to_good_offset(
|
|||
torch.tensor(np.stack(inference_logprobs_padded[start:end], axis=0))
|
||||
)
|
||||
|
||||
if distill_token_ids_padded and distill_logprobs_padded:
|
||||
seq_slice_ids = distill_token_ids_padded[start:end]
|
||||
seq_slice_lps = distill_logprobs_padded[start:end]
|
||||
normalized_ids = []
|
||||
normalized_lps = []
|
||||
for ids_mat, lps_mat in zip(seq_slice_ids, seq_slice_lps):
|
||||
if ids_mat.shape[1] < max_distill_k:
|
||||
pad_cols = max_distill_k - ids_mat.shape[1]
|
||||
ids_mat = np.pad(
|
||||
ids_mat, ((0, 0), (0, pad_cols)), constant_values=-1
|
||||
)
|
||||
lps_mat = np.pad(
|
||||
lps_mat, ((0, 0), (0, pad_cols)), constant_values=-1e9
|
||||
)
|
||||
normalized_ids.append(ids_mat)
|
||||
normalized_lps.append(lps_mat)
|
||||
|
||||
distill_token_id_batches.append(
|
||||
torch.tensor(np.stack(normalized_ids, axis=0), dtype=torch.long)
|
||||
)
|
||||
distill_logprob_batches.append(
|
||||
torch.tensor(np.stack(normalized_lps, axis=0), dtype=torch.float32)
|
||||
)
|
||||
|
||||
# Return inference logprob batches if we have any real logprobs
|
||||
final_logprob_batches = (
|
||||
inference_logprob_batches
|
||||
if (has_any_logprobs and inference_logprob_batches)
|
||||
else None
|
||||
)
|
||||
final_distill_token_id_batches = (
|
||||
distill_token_id_batches if (has_any_distill and distill_token_id_batches) else None
|
||||
)
|
||||
final_distill_logprob_batches = (
|
||||
distill_logprob_batches if (has_any_distill and distill_logprob_batches) else None
|
||||
)
|
||||
|
||||
return (
|
||||
token_batches,
|
||||
|
|
@ -212,6 +322,8 @@ def pad_data_to_good_offset(
|
|||
advantage_batches,
|
||||
temperature_batches,
|
||||
final_logprob_batches,
|
||||
final_distill_token_id_batches,
|
||||
final_distill_logprob_batches,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -228,6 +340,8 @@ def get_data(
|
|||
List[torch.Tensor], # advantage_batches
|
||||
List[torch.Tensor], # temperature_batches
|
||||
Optional[List[torch.Tensor]], # inference_logprob_batches
|
||||
Optional[List[torch.Tensor]], # distill_token_id_batches
|
||||
Optional[List[torch.Tensor]], # distill_logprob_batches
|
||||
]
|
||||
],
|
||||
None, # Legacy return (no longer used)
|
||||
|
|
@ -299,6 +413,8 @@ def get_data(
|
|||
adv_batches,
|
||||
temp_batches,
|
||||
inf_logprob_batches,
|
||||
distill_token_id_batches,
|
||||
distill_logprob_batches,
|
||||
) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
|
||||
|
||||
# Include inference logprob batches in the tuple
|
||||
|
|
@ -309,6 +425,8 @@ def get_data(
|
|||
adv_batches,
|
||||
temp_batches,
|
||||
inf_logprob_batches,
|
||||
distill_token_id_batches,
|
||||
distill_logprob_batches,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -201,6 +201,9 @@ def main():
|
|||
checkpoint_interval=args.checkpoint_interval,
|
||||
# GRPO hyperparameters
|
||||
clip_eps=args.clip_eps,
|
||||
distill_enabled=getattr(args, "distill_enabled", False),
|
||||
distill_coef=getattr(args, "distill_coef", 0.0),
|
||||
distill_temperature=getattr(args, "distill_temperature", 1.0),
|
||||
# vLLM settings
|
||||
vllm_port=args.vllm_port,
|
||||
vllm_gpu_memory_utilization=args.gpu_memory_utilization,
|
||||
|
|
|
|||
267
example_trainer/run_gsm8k_teacher_distill_single_terminal.sh
Executable file
267
example_trainer/run_gsm8k_teacher_distill_single_terminal.sh
Executable file
|
|
@ -0,0 +1,267 @@
|
|||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Single-terminal teacher-distillation runner.
|
||||
# Starts everything in the background from ONE shell that has GPU access:
|
||||
# 1) Atropos API
|
||||
# 2) Student vLLM server
|
||||
# 3) Teacher vLLM server
|
||||
# 4) GSM8K teacher-distill environment
|
||||
# 5) Example trainer (foreground)
|
||||
#
|
||||
# Usage:
|
||||
# chmod +x example_trainer/run_gsm8k_teacher_distill_single_terminal.sh
|
||||
# ./example_trainer/run_gsm8k_teacher_distill_single_terminal.sh
|
||||
#
|
||||
# Optional overrides:
|
||||
# STUDENT_MODEL="Qwen/Qwen3-4B-Instruct-2507-FP8"
|
||||
# TEACHER_MODEL="Qwen/Qwen3-30B-A3B-Instruct-2507"
|
||||
# STUDENT_GPUS="0,1"
|
||||
# TEACHER_GPUS="4,5,6,7"
|
||||
# TRAINER_GPU="2"
|
||||
# STUDENT_TP=2
|
||||
# TEACHER_TP=4
|
||||
# API_PORT=8002
|
||||
# STUDENT_PORT=9001
|
||||
# TEACHER_PORT=9003
|
||||
# TRAINING_STEPS=100
|
||||
# DISTILL_COEF=0.2
|
||||
# DISTILL_TEMPERATURE=1.0
|
||||
# TEACHER_TOP_K=8
|
||||
# DRY_RUN=1
|
||||
|
||||
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||
LAUNCH_DIR="$PWD"
|
||||
cd "$ROOT_DIR"
|
||||
|
||||
STUDENT_MODEL="${STUDENT_MODEL:-Qwen/Qwen3-4B-Instruct-2507-FP8}"
|
||||
TEACHER_MODEL="${TEACHER_MODEL:-Qwen/Qwen3-30B-A3B-Instruct-2507}"
|
||||
|
||||
STUDENT_GPUS="${STUDENT_GPUS:-0,1}"
|
||||
TEACHER_GPUS="${TEACHER_GPUS:-4,5,6,7}"
|
||||
TRAINER_GPU="${TRAINER_GPU:-2}"
|
||||
|
||||
STUDENT_TP="${STUDENT_TP:-2}"
|
||||
TEACHER_TP="${TEACHER_TP:-4}"
|
||||
|
||||
API_PORT="${API_PORT:-8002}"
|
||||
STUDENT_PORT="${STUDENT_PORT:-9001}"
|
||||
TEACHER_PORT="${TEACHER_PORT:-9003}"
|
||||
|
||||
TRAINING_STEPS="${TRAINING_STEPS:-100}"
|
||||
BATCH_SIZE="${BATCH_SIZE:-2}"
|
||||
GRAD_ACCUM="${GRAD_ACCUM:-8}"
|
||||
LR="${LR:-1e-5}"
|
||||
WARMUP_STEPS="${WARMUP_STEPS:-0}"
|
||||
CLIP_EPS="${CLIP_EPS:-0.2}"
|
||||
MAX_MODEL_LEN="${MAX_MODEL_LEN:-16384}"
|
||||
DISTILL_COEF="${DISTILL_COEF:-0.2}"
|
||||
DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}"
|
||||
TEACHER_TOP_K="${TEACHER_TOP_K:-8}"
|
||||
|
||||
STUDENT_GPU_MEMORY_UTILIZATION="${STUDENT_GPU_MEMORY_UTILIZATION:-0.90}"
|
||||
TEACHER_GPU_MEMORY_UTILIZATION="${TEACHER_GPU_MEMORY_UTILIZATION:-0.92}"
|
||||
DTYPE="${DTYPE:-bfloat16}"
|
||||
SAVE_DIR="${SAVE_DIR:-${LAUNCH_DIR}/saves/gsm8k_teacher_distill}"
|
||||
LOG_DIR="${LOG_DIR:-${LAUNCH_DIR}/logs/gsm8k_teacher_distill}"
|
||||
DRY_RUN="${DRY_RUN:-0}"
|
||||
|
||||
ENV_GROUP_SIZE="${ENV_GROUP_SIZE:-4}"
|
||||
ENV_BATCH_SIZE="${ENV_BATCH_SIZE:-16}"
|
||||
ENV_TOTAL_STEPS="${ENV_TOTAL_STEPS:-200}"
|
||||
ENV_STEPS_PER_EVAL="${ENV_STEPS_PER_EVAL:-50}"
|
||||
ENV_MAX_WORKERS_PER_NODE="${ENV_MAX_WORKERS_PER_NODE:-8}"
|
||||
|
||||
RUN_PIDS=()
|
||||
RUN_PORTS=()
|
||||
|
||||
# Print a message to stdout prefixed with the current wall-clock time (HH:MM:SS).
# All arguments are joined with spaces into a single message line.
log() {
  local stamp
  stamp="$(date '+%H:%M:%S')"
  printf '[%s] %s\n' "$stamp" "$*"
}
|
||||
|
||||
# Force-kill whatever process is LISTENING on the given TCP port.
#
# Arguments:
#   $1  TCP port number.
#
# In DRY_RUN mode nothing is killed; the skipped action is only logged.
# Always returns 0 (best-effort cleanup).
kill_port() {
  local port="$1"
  if [[ -z "$port" ]]; then
    return 0
  fi
  if [[ "$DRY_RUN" == "1" ]]; then
    log "[DRY RUN] skip port cleanup for :${port}"
    return 0
  fi
  # Restrict the PID lookup to listeners. A bare `lsof -ti :PORT` also matches
  # processes that merely hold a client connection to the port, and those
  # must not be SIGKILLed.
  local pids
  pids="$(lsof -t -iTCP:"${port}" -sTCP:LISTEN 2>/dev/null || true)"
  if [[ -n "$pids" ]]; then
    # Intentionally unquoted: lsof -t prints one PID per line.
    # shellcheck disable=SC2086
    kill -9 $pids >/dev/null 2>&1 || true
  fi
}
|
||||
|
||||
# Poll an HTTP endpoint until it answers successfully or a deadline passes.
#
# Arguments:
#   $1  URL to probe.
#   $2  Overall timeout in seconds (default: 240).
#   $3  Human-readable name used in log messages (default: "endpoint").
#
# Returns 0 once curl gets a successful response, 1 after the timeout.
wait_for_http() {
  local url="$1"
  local timeout="${2:-240}"
  local name="${3:-endpoint}"
  local start
  start="$(date +%s)"
  while true; do
    # Bound every probe: without a per-request limit a server that accepts
    # the connection but never replies would stall this loop indefinitely,
    # defeating the overall timeout below.
    if curl -fsS --connect-timeout 2 --max-time 5 "$url" >/dev/null 2>&1; then
      log "Ready: ${name} (${url})"
      return 0
    fi
    if (( "$(date +%s)" - start > timeout )); then
      log "Timeout waiting for ${name}: ${url}"
      return 1
    fi
    sleep 2
  done
}
|
||||
|
||||
# Launch a command in the background with stdout+stderr redirected to a log
# file, and record its PID in RUN_PIDS.
#
# Arguments:
#   $1    Label used in log messages.
#   $2    Path of the log file receiving the command's output.
#   $3..  The command and its arguments.
#
# In DRY_RUN mode the command is only printed (shell-quoted), not executed.
start_process() {
  local label="$1" out="$2"
  shift 2

  if [[ "$DRY_RUN" != "1" ]]; then
    log "Starting ${label} (log: ${out})"
    "$@" >"$out" 2>&1 &
    local child=$!
    RUN_PIDS+=("$child")
    log "${label} PID=${child}"
    return 0
  fi

  log "[DRY RUN] start ${label} (log: ${out})"
  printf ' '
  printf '%q ' "$@"
  printf '\n'
  return 0
}
|
||||
|
||||
# Stop every background process recorded in RUN_PIDS, then free every port
# recorded in RUN_PORTS. Best-effort: individual failures are ignored.
cleanup_all() {
  log "Cleaning up processes..."
  local pid port

  # ${ARR[@]+"${ARR[@]}"} expands to nothing for an empty array (and is safe
  # under `set -u`), so no empty placeholder is ever passed to kill/kill_port.
  for pid in ${RUN_PIDS[@]+"${RUN_PIDS[@]}"}; do
    [[ -n "$pid" ]] || continue
    kill "$pid" >/dev/null 2>&1 || true
  done

  # Short grace period for a clean shutdown on SIGTERM.
  sleep 1

  for pid in ${RUN_PIDS[@]+"${RUN_PIDS[@]}"}; do
    [[ -n "$pid" ]] || continue
    # Escalate to SIGKILL only for processes that are still alive.
    if kill -0 "$pid" >/dev/null 2>&1; then
      kill -9 "$pid" >/dev/null 2>&1 || true
    fi
  done

  for port in ${RUN_PORTS[@]+"${RUN_PORTS[@]}"}; do
    [[ -n "$port" ]] || continue
    kill_port "$port"
  done
}
|
||||
|
||||
# Run cleanup exactly once, from the EXIT trap. INT/TERM only turn the signal
# into a normal exit (conventional 128+signo status); otherwise the script
# would resume after the handler and clean up a second time at exit.
trap cleanup_all EXIT
trap 'exit 130' INT
trap 'exit 143' TERM
|
||||
|
||||
mkdir -p "$LOG_DIR" "$SAVE_DIR"
|
||||
RUN_PORTS+=("$API_PORT" "$STUDENT_PORT" "$TEACHER_PORT")
|
||||
kill_port "$API_PORT"
|
||||
kill_port "$STUDENT_PORT"
|
||||
kill_port "$TEACHER_PORT"
|
||||
|
||||
log "Config:"
|
||||
log " student=${STUDENT_MODEL}"
|
||||
log " teacher=${TEACHER_MODEL}"
|
||||
log " gpus student=${STUDENT_GPUS}, teacher=${TEACHER_GPUS}, trainer=${TRAINER_GPU}"
|
||||
log " ports api=${API_PORT}, student=${STUDENT_PORT}, teacher=${TEACHER_PORT}"
|
||||
log " logs=${LOG_DIR}"
|
||||
log " saves=${SAVE_DIR}"
|
||||
|
||||
# 1) Atropos API
|
||||
start_process "run_api" "${LOG_DIR}/run_api.log" \
|
||||
uv run python -m atroposlib.cli.run_api --port "$API_PORT"
|
||||
if [[ "$DRY_RUN" == "0" ]]; then
|
||||
wait_for_http "http://localhost:${API_PORT}/info" 60 "run-api"
|
||||
fi
|
||||
|
||||
# 2) Student vLLM server
|
||||
start_process "student_vllm" "${LOG_DIR}/student_vllm.log" \
|
||||
env CUDA_VISIBLE_DEVICES="$STUDENT_GPUS" \
|
||||
uv run python -m example_trainer.vllm_api_server \
|
||||
--model "$STUDENT_MODEL" \
|
||||
--port "$STUDENT_PORT" \
|
||||
--tensor-parallel-size "$STUDENT_TP" \
|
||||
--gpu-memory-utilization "$STUDENT_GPU_MEMORY_UTILIZATION" \
|
||||
--max-model-len "$MAX_MODEL_LEN" \
|
||||
--dtype "$DTYPE"
|
||||
if [[ "$DRY_RUN" == "0" ]]; then
|
||||
wait_for_http "http://localhost:${STUDENT_PORT}/health" 420 "student vLLM"
|
||||
fi
|
||||
|
||||
# 3) Teacher vLLM server
|
||||
start_process "teacher_vllm" "${LOG_DIR}/teacher_vllm.log" \
|
||||
env CUDA_VISIBLE_DEVICES="$TEACHER_GPUS" \
|
||||
uv run python -m example_trainer.vllm_api_server \
|
||||
--model "$TEACHER_MODEL" \
|
||||
--port "$TEACHER_PORT" \
|
||||
--tensor-parallel-size "$TEACHER_TP" \
|
||||
--gpu-memory-utilization "$TEACHER_GPU_MEMORY_UTILIZATION" \
|
||||
--max-model-len "$MAX_MODEL_LEN" \
|
||||
--dtype "$DTYPE"
|
||||
if [[ "$DRY_RUN" == "0" ]]; then
|
||||
wait_for_http "http://localhost:${TEACHER_PORT}/health" 600 "teacher vLLM"
|
||||
fi
|
||||
|
||||
# 4) Teacher-distill GSM8K env
|
||||
start_process "gsm8k_teacher_env" "${LOG_DIR}/env.log" \
|
||||
uv run python environments/gsm8k_server_teacher_distill.py serve \
|
||||
--env.group_size "$ENV_GROUP_SIZE" \
|
||||
--env.batch_size "$ENV_BATCH_SIZE" \
|
||||
--env.total_steps "$ENV_TOTAL_STEPS" \
|
||||
--env.steps_per_eval "$ENV_STEPS_PER_EVAL" \
|
||||
--env.max_num_workers_per_node "$ENV_MAX_WORKERS_PER_NODE" \
|
||||
--env.max_token_length "$MAX_MODEL_LEN" \
|
||||
--env.rollout_server_url "http://localhost:${API_PORT}" \
|
||||
--env.use_wandb true \
|
||||
--env.wandb_name "gsm8k-teacher-distill" \
|
||||
--env.distillation_enabled true \
|
||||
--env.teacher_enabled true \
|
||||
--env.teacher_base_url "http://localhost:${TEACHER_PORT}/v1" \
|
||||
--env.teacher_model_name "$TEACHER_MODEL" \
|
||||
--env.teacher_top_k "$TEACHER_TOP_K" \
|
||||
--openai.api_key "dummy" \
|
||||
--openai.base_url "http://localhost:${STUDENT_PORT}/v1" \
|
||||
--openai.model_name "$STUDENT_MODEL" \
|
||||
--openai.tokenizer_name "$STUDENT_MODEL" \
|
||||
--openai.server_type vllm
|
||||
|
||||
log "All services launched."
|
||||
log "Run logs:"
|
||||
log " ${LOG_DIR}/run_api.log"
|
||||
log " ${LOG_DIR}/student_vllm.log"
|
||||
log " ${LOG_DIR}/teacher_vllm.log"
|
||||
log " ${LOG_DIR}/env.log"
|
||||
|
||||
# 5) Trainer (foreground, primary output)
# The command is defined once so the dry-run printout and the real invocation
# cannot drift apart.
TRAINER_CMD=(
  env CUDA_VISIBLE_DEVICES="$TRAINER_GPU"
  uv run python -m example_trainer.grpo
  --model-name "$STUDENT_MODEL"
  --weight-bridge-mode none
  --device cuda:0
  --save-path "$SAVE_DIR"
  --atropos-url "http://localhost:${API_PORT}"
  --training-steps "$TRAINING_STEPS"
  --batch-size "$BATCH_SIZE"
  --gradient-accumulation-steps "$GRAD_ACCUM"
  --warmup-steps "$WARMUP_STEPS"
  --lr "$LR"
  --clip-eps "$CLIP_EPS"
  --distill-enabled
  --distill-coef "$DISTILL_COEF"
  --distill-temperature "$DISTILL_TEMPERATURE"
)

if [[ "$DRY_RUN" == "1" ]]; then
  log "[DRY RUN] trainer command:"
  printf ' '
  printf '%q ' "${TRAINER_CMD[@]}"
  printf '\n'
  exit 0
fi

log "Starting trainer in foreground..."
# Merge stderr into the pipe so tracebacks and warnings land in trainer.log,
# matching the other services whose logs capture both streams.
"${TRAINER_CMD[@]}" 2>&1 | tee "${LOG_DIR}/trainer.log"

log "Training finished."
|
||||
|
|
@ -170,6 +170,8 @@ def train_legacy(config: TrainingConfig):
|
|||
batch_data[:4]
|
||||
)
|
||||
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
|
||||
distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None
|
||||
distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None
|
||||
data_fetch_time = time.time() - data_fetch_start
|
||||
benchmark_stats["data_fetch_times"].append(data_fetch_time)
|
||||
|
||||
|
|
@ -192,6 +194,8 @@ def train_legacy(config: TrainingConfig):
|
|||
config,
|
||||
step_idx=step,
|
||||
inference_logprob_batches=inference_logprob_batches,
|
||||
distill_token_id_batches=distill_token_id_batches,
|
||||
distill_logprob_batches=distill_logprob_batches,
|
||||
)
|
||||
step_time = time.time() - step_start
|
||||
benchmark_stats["step_times"].append(step_time)
|
||||
|
|
@ -324,6 +328,8 @@ def train_shared_vllm(config: TrainingConfig):
|
|||
batch_data[:4]
|
||||
)
|
||||
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
|
||||
distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None
|
||||
distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None
|
||||
data_fetch_time = time.time() - data_fetch_start
|
||||
benchmark_stats["data_fetch_times"].append(data_fetch_time)
|
||||
|
||||
|
|
@ -339,6 +345,8 @@ def train_shared_vllm(config: TrainingConfig):
|
|||
config,
|
||||
step_idx=step,
|
||||
inference_logprob_batches=inference_logprob_batches, # Pass for GRPO ratio computation
|
||||
distill_token_id_batches=distill_token_id_batches,
|
||||
distill_logprob_batches=distill_logprob_batches,
|
||||
)
|
||||
step_time = time.time() - step_start
|
||||
benchmark_stats["step_times"].append(step_time)
|
||||
|
|
@ -484,6 +492,8 @@ def train_lora(config: TrainingConfig):
|
|||
batch_data[:4]
|
||||
)
|
||||
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
|
||||
distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None
|
||||
distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None
|
||||
data_fetch_time = time.time() - data_fetch_start
|
||||
benchmark_stats["data_fetch_times"].append(data_fetch_time)
|
||||
|
||||
|
|
@ -499,6 +509,8 @@ def train_lora(config: TrainingConfig):
|
|||
config,
|
||||
step_idx=step,
|
||||
inference_logprob_batches=inference_logprob_batches,
|
||||
distill_token_id_batches=distill_token_id_batches,
|
||||
distill_logprob_batches=distill_logprob_batches,
|
||||
)
|
||||
step_time = time.time() - step_start
|
||||
benchmark_stats["step_times"].append(step_time)
|
||||
|
|
@ -706,6 +718,8 @@ def train_lora_restart(config: TrainingConfig):
|
|||
batch_data[:4]
|
||||
)
|
||||
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
|
||||
distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None
|
||||
distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None
|
||||
data_fetch_time = time.time() - data_fetch_start
|
||||
benchmark_stats["data_fetch_times"].append(data_fetch_time)
|
||||
|
||||
|
|
@ -721,6 +735,8 @@ def train_lora_restart(config: TrainingConfig):
|
|||
config,
|
||||
step_idx=step,
|
||||
inference_logprob_batches=inference_logprob_batches,
|
||||
distill_token_id_batches=distill_token_id_batches,
|
||||
distill_logprob_batches=distill_logprob_batches,
|
||||
)
|
||||
step_time = time.time() - step_start
|
||||
benchmark_stats["step_times"].append(step_time)
|
||||
|
|
|
|||
|
|
@ -70,6 +70,11 @@ def compute_grpo_loss(
|
|||
gradient_accumulation_steps: int,
|
||||
inference_logprobs: Optional[torch.Tensor] = None,
|
||||
clip_eps: float = 0.2,
|
||||
distill_token_ids: Optional[torch.Tensor] = None,
|
||||
distill_logprobs: Optional[torch.Tensor] = None,
|
||||
distill_enabled: bool = False,
|
||||
distill_coef: float = 0.0,
|
||||
distill_temperature: float = 1.0,
|
||||
) -> Tuple[torch.Tensor, dict]:
|
||||
"""
|
||||
Compute GRPO (Group Relative Policy Optimization) loss for a single micro-batch.
|
||||
|
|
@ -125,6 +130,9 @@ def compute_grpo_loss(
|
|||
logprob_diff_abs_mean = 0.0
|
||||
logprob_diff_max = 0.0
|
||||
|
||||
distill_loss_value = torch.tensor(0.0, device=logp_per_token.device)
|
||||
distill_token_count = 0.0
|
||||
|
||||
# === GRPO/PPO Loss Computation ===
|
||||
if inference_logprobs is not None:
|
||||
# Move inference logprobs to correct device/dtype
|
||||
|
|
@ -187,7 +195,23 @@ def compute_grpo_loss(
|
|||
# Average over tokens, then over batch
|
||||
policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean()
|
||||
|
||||
total_loss = policy_loss / gradient_accumulation_steps
|
||||
if (
|
||||
distill_enabled
|
||||
and distill_coef > 0
|
||||
and distill_token_ids is not None
|
||||
and distill_logprobs is not None
|
||||
):
|
||||
distill_loss_value, distill_token_count = compute_distillation_loss(
|
||||
logits=scaled_logits,
|
||||
labels=labels,
|
||||
distill_token_ids=distill_token_ids.to(logits.device),
|
||||
distill_logprobs=distill_logprobs.to(logits.device, logits.dtype),
|
||||
temperature=max(1e-6, float(distill_temperature)),
|
||||
)
|
||||
|
||||
total_loss = (policy_loss + distill_coef * distill_loss_value) / (
|
||||
gradient_accumulation_steps
|
||||
)
|
||||
|
||||
# Compute metrics for logging
|
||||
with torch.no_grad():
|
||||
|
|
@ -253,11 +277,66 @@ def compute_grpo_loss(
|
|||
"logprob_diff_mean": logprob_diff_mean,
|
||||
"logprob_diff_abs_mean": logprob_diff_abs_mean,
|
||||
"logprob_diff_max": logprob_diff_max,
|
||||
"distill_loss": (
|
||||
distill_loss_value.item()
|
||||
if torch.is_tensor(distill_loss_value)
|
||||
else float(distill_loss_value)
|
||||
),
|
||||
"distill_token_count": distill_token_count,
|
||||
}
|
||||
|
||||
return total_loss, metrics
|
||||
|
||||
|
||||
def compute_distillation_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    distill_token_ids: torch.Tensor,
    distill_logprobs: torch.Tensor,
    temperature: float = 1.0,
) -> Tuple[torch.Tensor, float]:
    """
    Compute token-level distillation loss from teacher top-k logprobs.

    The teacher distribution is re-normalised over the provided top-k entries
    (softmax of the teacher logprobs divided by the temperature) and compared
    with the student's temperature-scaled log-probabilities of the same token
    ids via cross-entropy. The loss is averaged over positions that are both
    unmasked in ``labels`` and have at least one valid teacher token id.

    Args:
        logits: Student logits [batch, seq_len, vocab]
        labels: Labels [batch, seq_len], -100 for masked positions
        distill_token_ids: Teacher top-k token IDs [batch, seq_len, k], -1 padded
        distill_logprobs: Teacher top-k logprobs [batch, seq_len, k], very negative padded
        temperature: Distillation temperature (clamped to at least 1e-6)

    Returns:
        Tuple of (distillation loss scalar in ``logits.dtype``, valid token count).
        Returns ``(0.0, 0.0)`` when the distill tensors are not 3-D, their
        shapes do not line up with ``labels``, or no position is valid.
    """
    zero_loss = torch.tensor(0.0, device=logits.device, dtype=logits.dtype)

    if distill_token_ids.dim() != 3 or distill_logprobs.dim() != 3:
        return zero_loss, 0.0

    if (
        distill_token_ids.shape[:2] != labels.shape
        or distill_logprobs.shape != distill_token_ids.shape
    ):
        return zero_loss, 0.0

    temp = max(1e-6, float(temperature))

    valid_ids = distill_token_ids >= 0
    valid_pos = (labels != -100) & valid_ids.any(dim=-1)
    token_count = int(valid_pos.sum().item())
    if token_count == 0:
        return zero_loss, 0.0

    # The vocab-sized log-softmax stays in the caller's dtype to avoid
    # doubling memory; only the small [batch, seq, k] slices are upcast.
    student_log_probs = F.log_softmax(logits / temp, dim=-1)
    gather_ids = distill_token_ids.clamp_min(0).long()
    student_logp_topk = torch.gather(
        student_log_probs, dim=-1, index=gather_ids
    ).float()
    # Padded slots were gathered at token 0; zero them so a -inf student
    # log-prob there cannot turn `0 * -inf` into NaN.
    student_logp_topk = student_logp_topk.masked_fill(~valid_ids, 0.0)

    # Do the teacher softmax in float32: the -1e9 sentinel is not representable
    # in fp16, and clamping guards against -inf produced by a lossy cast.
    teacher_logprobs = (
        distill_logprobs.float().clamp_min(-1e9).masked_fill(~valid_ids, -1e9)
    )
    teacher_probs = F.softmax(teacher_logprobs / temp, dim=-1)
    # Rows with no valid id would otherwise come out uniform; give padding zero mass.
    teacher_probs = teacher_probs.masked_fill(~valid_ids, 0.0)

    per_token_loss = -(teacher_probs * student_logp_topk).sum(dim=-1)
    per_token_loss = per_token_loss.masked_fill(~valid_pos, 0.0)

    loss = per_token_loss.sum() / float(token_count)
    return loss.to(logits.dtype), float(token_count)
|
||||
|
||||
|
||||
def run_training_step(
|
||||
model: torch.nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
|
|
@ -268,6 +347,8 @@ def run_training_step(
|
|||
config: TrainingConfig,
|
||||
step_idx: int,
|
||||
inference_logprob_batches: Optional[List[torch.Tensor]] = None,
|
||||
distill_token_id_batches: Optional[List[torch.Tensor]] = None,
|
||||
distill_logprob_batches: Optional[List[torch.Tensor]] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Run a single training step with gradient accumulation.
|
||||
|
|
@ -302,6 +383,8 @@ def run_training_step(
|
|||
total_logprob_diff_mean = 0.0
|
||||
total_logprob_diff_abs_mean = 0.0
|
||||
total_logprob_diff_max = 0.0
|
||||
total_distill_loss = 0.0
|
||||
total_distill_tokens = 0.0
|
||||
grad_norm = 0.0
|
||||
all_training_logprobs: List[torch.Tensor] = []
|
||||
all_inference_logprobs: List[torch.Tensor] = []
|
||||
|
|
@ -335,6 +418,16 @@ def run_training_step(
|
|||
inference_logprob_batches
|
||||
):
|
||||
inf_logprobs = inference_logprob_batches[batch_idx]
|
||||
distill_ids = None
|
||||
if distill_token_id_batches is not None and batch_idx < len(
|
||||
distill_token_id_batches
|
||||
):
|
||||
distill_ids = distill_token_id_batches[batch_idx]
|
||||
distill_lps = None
|
||||
if distill_logprob_batches is not None and batch_idx < len(
|
||||
distill_logprob_batches
|
||||
):
|
||||
distill_lps = distill_logprob_batches[batch_idx]
|
||||
|
||||
loss, metrics = compute_grpo_loss(
|
||||
model,
|
||||
|
|
@ -345,6 +438,11 @@ def run_training_step(
|
|||
config.gradient_accumulation_steps,
|
||||
inference_logprobs=inf_logprobs,
|
||||
clip_eps=clip_eps,
|
||||
distill_token_ids=distill_ids,
|
||||
distill_logprobs=distill_lps,
|
||||
distill_enabled=bool(getattr(config, "distill_enabled", False)),
|
||||
distill_coef=float(getattr(config, "distill_coef", 0.0)),
|
||||
distill_temperature=float(getattr(config, "distill_temperature", 1.0)),
|
||||
)
|
||||
|
||||
loss.backward()
|
||||
|
|
@ -364,6 +462,8 @@ def run_training_step(
|
|||
total_logprob_diff_max = max(
|
||||
total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0)
|
||||
)
|
||||
total_distill_loss += metrics.get("distill_loss", 0.0)
|
||||
total_distill_tokens += metrics.get("distill_token_count", 0.0)
|
||||
|
||||
# Collect logprobs for alignment monitoring
|
||||
if "training_logprobs" in metrics and metrics["training_logprobs"] is not None:
|
||||
|
|
@ -399,6 +499,8 @@ def run_training_step(
|
|||
# GRPO-specific metrics (averaged over batches)
|
||||
"mean_ratio": total_mean_ratio / num_batches,
|
||||
"clipped_fraction": total_clipped_fraction / num_batches,
|
||||
"distill_loss": total_distill_loss / num_batches,
|
||||
"distill_token_count": total_distill_tokens,
|
||||
}
|
||||
|
||||
# Compute logprob alignment stats for monitoring
|
||||
|
|
@ -472,6 +574,12 @@ def log_metrics(
|
|||
clipped_frac = metrics.get("clipped_fraction", 0)
|
||||
|
||||
print(f" GRPO: ratio={mean_ratio:.3f}, clipped={clipped_frac*100:.1f}%")
|
||||
if metrics.get("distill_token_count", 0) > 0:
|
||||
print(
|
||||
" Distill: "
|
||||
f"loss={metrics.get('distill_loss', 0.0):.4f}, "
|
||||
f"tokens={int(metrics.get('distill_token_count', 0))}"
|
||||
)
|
||||
|
||||
# Advantage distribution
|
||||
if "pos_count" in metrics or "neg_count" in metrics:
|
||||
|
|
@ -494,6 +602,8 @@ def log_metrics(
|
|||
# GRPO-specific metrics
|
||||
"grpo/mean_ratio": mean_ratio,
|
||||
"grpo/clipped_fraction": clipped_frac,
|
||||
"distill/loss": metrics.get("distill_loss", 0.0),
|
||||
"distill/token_count": metrics.get("distill_token_count", 0.0),
|
||||
}
|
||||
# Add timing metrics if present
|
||||
for key in [
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue