mirror of
https://github.com/NousResearch/atropos.git
synced 2026-05-03 17:53:17 +00:00
KL
This commit is contained in:
parent
13def6bdab
commit
3e7705c17d
2 changed files with 39 additions and 26 deletions
|
|
@ -203,12 +203,14 @@ def train_legacy(config: TrainingConfig):
|
|||
for step in range(config.training_steps):
|
||||
print(f"\nStep {step+1}/{config.training_steps}")
|
||||
|
||||
# Fetch data
|
||||
# Fetch data (with inference logprobs for proper GRPO)
|
||||
data_fetch_start = time.time()
|
||||
if len(batches) == 0:
|
||||
batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url,
|
||||
extract_inference_logprobs=False)
|
||||
token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0)
|
||||
extract_inference_logprobs=True)
|
||||
batch_data = batches.pop(0)
|
||||
token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4]
|
||||
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
|
||||
data_fetch_time = time.time() - data_fetch_start
|
||||
benchmark_stats["data_fetch_times"].append(data_fetch_time)
|
||||
|
||||
|
|
@ -217,12 +219,13 @@ def train_legacy(config: TrainingConfig):
|
|||
if should_sync:
|
||||
terminate_vllm_process()
|
||||
|
||||
# Training step
|
||||
# Training step (with proper GRPO using inference logprobs)
|
||||
step_start = time.time()
|
||||
metrics = run_training_step(
|
||||
model, optimizer,
|
||||
token_batches, label_batches, advantage_batches, temperature_batches,
|
||||
config,
|
||||
inference_logprob_batches=inference_logprob_batches,
|
||||
)
|
||||
step_time = time.time() - step_start
|
||||
benchmark_stats["step_times"].append(step_time)
|
||||
|
|
@ -518,34 +521,32 @@ def train_shared_vllm(config: TrainingConfig):
|
|||
|
||||
# === Training Loop ===
|
||||
batches = []
|
||||
inference_logprobs = None
|
||||
for step in range(config.training_steps):
|
||||
print(f"\nStep {step+1}/{config.training_steps}")
|
||||
|
||||
# Fetch data (with inference logprobs for alignment check)
|
||||
# Fetch data (with inference logprobs for proper GRPO loss)
|
||||
data_fetch_start = time.time()
|
||||
if len(batches) == 0:
|
||||
batches, inference_logprobs = get_data(
|
||||
batches, _ = get_data(
|
||||
config.batch_size, config.seq_len, config.atropos_url,
|
||||
extract_inference_logprobs=True, # Enable logprob alignment check
|
||||
extract_inference_logprobs=True, # Enable proper GRPO with reference logprobs
|
||||
)
|
||||
token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0)
|
||||
batch_data = batches.pop(0)
|
||||
token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4]
|
||||
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
|
||||
data_fetch_time = time.time() - data_fetch_start
|
||||
benchmark_stats["data_fetch_times"].append(data_fetch_time)
|
||||
|
||||
# Training step (with logprob alignment check)
|
||||
# Training step with proper GRPO (importance sampling + KL penalty)
|
||||
step_start = time.time()
|
||||
metrics = run_training_step(
|
||||
model, optimizer,
|
||||
token_batches, label_batches, advantage_batches, temperature_batches,
|
||||
config,
|
||||
inference_logprobs=inference_logprobs, # Pass for alignment validation
|
||||
inference_logprob_batches=inference_logprob_batches, # Pass for GRPO ratio computation
|
||||
)
|
||||
step_time = time.time() - step_start
|
||||
benchmark_stats["step_times"].append(step_time)
|
||||
|
||||
# Clear inference logprobs after use (will be refreshed with new data)
|
||||
inference_logprobs = None
|
||||
|
||||
# GPU memory tracking
|
||||
gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
|
||||
|
|
@ -652,21 +653,24 @@ def train_lora(config: TrainingConfig):
|
|||
for step in range(config.training_steps):
|
||||
print(f"\nStep {step+1}/{config.training_steps}")
|
||||
|
||||
# Fetch data
|
||||
# Fetch data (with inference logprobs for proper GRPO)
|
||||
data_fetch_start = time.time()
|
||||
if len(batches) == 0:
|
||||
batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url,
|
||||
extract_inference_logprobs=False)
|
||||
token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0)
|
||||
extract_inference_logprobs=True)
|
||||
batch_data = batches.pop(0)
|
||||
token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4]
|
||||
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
|
||||
data_fetch_time = time.time() - data_fetch_start
|
||||
benchmark_stats["data_fetch_times"].append(data_fetch_time)
|
||||
|
||||
# Training step
|
||||
# Training step with proper GRPO
|
||||
step_start = time.time()
|
||||
metrics = run_training_step(
|
||||
model, optimizer,
|
||||
token_batches, label_batches, advantage_batches, temperature_batches,
|
||||
config,
|
||||
inference_logprob_batches=inference_logprob_batches,
|
||||
)
|
||||
step_time = time.time() - step_start
|
||||
benchmark_stats["step_times"].append(step_time)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue