remove training code

This commit is contained in:
Jai Suphavadeeprasit 2026-03-13 12:52:52 -04:00
parent 862cd3667d
commit 148a4fd5eb
6 changed files with 38 additions and 329 deletions

View file

@ -170,8 +170,6 @@ def train_legacy(config: TrainingConfig):
batch_data[:4]
)
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None
distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None
data_fetch_time = time.time() - data_fetch_start
benchmark_stats["data_fetch_times"].append(data_fetch_time)
@ -194,8 +192,6 @@ def train_legacy(config: TrainingConfig):
config,
step_idx=step,
inference_logprob_batches=inference_logprob_batches,
distill_token_id_batches=distill_token_id_batches,
distill_logprob_batches=distill_logprob_batches,
)
step_time = time.time() - step_start
benchmark_stats["step_times"].append(step_time)
@ -317,30 +313,17 @@ def train_shared_vllm(config: TrainingConfig):
# Fetch data (with inference logprobs for proper GRPO loss)
data_fetch_start = time.time()
if len(batches) == 0:
print(" [Trainer] requesting data from Atropos API...", flush=True)
batches, _ = get_data(
config.batch_size,
config.seq_len,
config.atropos_url,
extract_inference_logprobs=True, # Enable proper GRPO with reference logprobs
)
print(
f" [Trainer] get_data returned {len(batches)} trainer batch tuple(s)",
flush=True,
)
batch_data = batches.pop(0)
token_batches, label_batches, advantage_batches, temperature_batches = (
batch_data[:4]
)
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None
distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None
token_shapes = [tuple(tb.shape) for tb in token_batches]
print(
" [Trainer] selected trainer batch: "
f"micro_batches={len(token_batches)} token_batch_shapes={token_shapes}",
flush=True,
)
data_fetch_time = time.time() - data_fetch_start
benchmark_stats["data_fetch_times"].append(data_fetch_time)
@ -356,8 +339,6 @@ def train_shared_vllm(config: TrainingConfig):
config,
step_idx=step,
inference_logprob_batches=inference_logprob_batches, # Pass for GRPO ratio computation
distill_token_id_batches=distill_token_id_batches,
distill_logprob_batches=distill_logprob_batches,
)
step_time = time.time() - step_start
benchmark_stats["step_times"].append(step_time)
@ -503,8 +484,6 @@ def train_lora(config: TrainingConfig):
batch_data[:4]
)
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None
distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None
data_fetch_time = time.time() - data_fetch_start
benchmark_stats["data_fetch_times"].append(data_fetch_time)
@ -520,8 +499,6 @@ def train_lora(config: TrainingConfig):
config,
step_idx=step,
inference_logprob_batches=inference_logprob_batches,
distill_token_id_batches=distill_token_id_batches,
distill_logprob_batches=distill_logprob_batches,
)
step_time = time.time() - step_start
benchmark_stats["step_times"].append(step_time)
@ -729,8 +706,6 @@ def train_lora_restart(config: TrainingConfig):
batch_data[:4]
)
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None
distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None
data_fetch_time = time.time() - data_fetch_start
benchmark_stats["data_fetch_times"].append(data_fetch_time)
@ -746,8 +721,6 @@ def train_lora_restart(config: TrainingConfig):
config,
step_idx=step,
inference_logprob_batches=inference_logprob_batches,
distill_token_id_batches=distill_token_id_batches,
distill_logprob_batches=distill_logprob_batches,
)
step_time = time.time() - step_start
benchmark_stats["step_times"].append(step_time)