mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-28 17:29:30 +00:00
remove training code
This commit is contained in:
parent
862cd3667d
commit
148a4fd5eb
6 changed files with 38 additions and 329 deletions
|
|
@@ -170,8 +170,6 @@ def train_legacy(config: TrainingConfig):
|
|||
batch_data[:4]
|
||||
)
|
||||
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
|
||||
distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None
|
||||
distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None
|
||||
data_fetch_time = time.time() - data_fetch_start
|
||||
benchmark_stats["data_fetch_times"].append(data_fetch_time)
|
||||
|
||||
|
|
@@ -194,8 +192,6 @@ def train_legacy(config: TrainingConfig):
|
|||
config,
|
||||
step_idx=step,
|
||||
inference_logprob_batches=inference_logprob_batches,
|
||||
distill_token_id_batches=distill_token_id_batches,
|
||||
distill_logprob_batches=distill_logprob_batches,
|
||||
)
|
||||
step_time = time.time() - step_start
|
||||
benchmark_stats["step_times"].append(step_time)
|
||||
|
|
@@ -317,30 +313,17 @@ def train_shared_vllm(config: TrainingConfig):
|
|||
# Fetch data (with inference logprobs for proper GRPO loss)
|
||||
data_fetch_start = time.time()
|
||||
if len(batches) == 0:
|
||||
print(" [Trainer] requesting data from Atropos API...", flush=True)
|
||||
batches, _ = get_data(
|
||||
config.batch_size,
|
||||
config.seq_len,
|
||||
config.atropos_url,
|
||||
extract_inference_logprobs=True, # Enable proper GRPO with reference logprobs
|
||||
)
|
||||
print(
|
||||
f" [Trainer] get_data returned {len(batches)} trainer batch tuple(s)",
|
||||
flush=True,
|
||||
)
|
||||
batch_data = batches.pop(0)
|
||||
token_batches, label_batches, advantage_batches, temperature_batches = (
|
||||
batch_data[:4]
|
||||
)
|
||||
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
|
||||
distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None
|
||||
distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None
|
||||
token_shapes = [tuple(tb.shape) for tb in token_batches]
|
||||
print(
|
||||
" [Trainer] selected trainer batch: "
|
||||
f"micro_batches={len(token_batches)} token_batch_shapes={token_shapes}",
|
||||
flush=True,
|
||||
)
|
||||
data_fetch_time = time.time() - data_fetch_start
|
||||
benchmark_stats["data_fetch_times"].append(data_fetch_time)
|
||||
|
||||
|
|
@@ -356,8 +339,6 @@ def train_shared_vllm(config: TrainingConfig):
|
|||
config,
|
||||
step_idx=step,
|
||||
inference_logprob_batches=inference_logprob_batches, # Pass for GRPO ratio computation
|
||||
distill_token_id_batches=distill_token_id_batches,
|
||||
distill_logprob_batches=distill_logprob_batches,
|
||||
)
|
||||
step_time = time.time() - step_start
|
||||
benchmark_stats["step_times"].append(step_time)
|
||||
|
|
@@ -503,8 +484,6 @@ def train_lora(config: TrainingConfig):
|
|||
batch_data[:4]
|
||||
)
|
||||
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
|
||||
distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None
|
||||
distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None
|
||||
data_fetch_time = time.time() - data_fetch_start
|
||||
benchmark_stats["data_fetch_times"].append(data_fetch_time)
|
||||
|
||||
|
|
@@ -520,8 +499,6 @@ def train_lora(config: TrainingConfig):
|
|||
config,
|
||||
step_idx=step,
|
||||
inference_logprob_batches=inference_logprob_batches,
|
||||
distill_token_id_batches=distill_token_id_batches,
|
||||
distill_logprob_batches=distill_logprob_batches,
|
||||
)
|
||||
step_time = time.time() - step_start
|
||||
benchmark_stats["step_times"].append(step_time)
|
||||
|
|
@@ -729,8 +706,6 @@ def train_lora_restart(config: TrainingConfig):
|
|||
batch_data[:4]
|
||||
)
|
||||
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
|
||||
distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None
|
||||
distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None
|
||||
data_fetch_time = time.time() - data_fetch_start
|
||||
benchmark_stats["data_fetch_times"].append(data_fetch_time)
|
||||
|
||||
|
|
@@ -746,8 +721,6 @@ def train_lora_restart(config: TrainingConfig):
|
|||
config,
|
||||
step_idx=step,
|
||||
inference_logprob_batches=inference_logprob_batches,
|
||||
distill_token_id_batches=distill_token_id_batches,
|
||||
distill_logprob_batches=distill_logprob_batches,
|
||||
)
|
||||
step_time = time.time() - step_start
|
||||
benchmark_stats["step_times"].append(step_time)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue