logprob wandb

This commit is contained in:
Jai Suphavadeeprasit 2026-01-27 13:25:37 -05:00
parent 210726c3d9
commit 947ab19a8e
2 changed files with 74 additions and 6 deletions

View file

@ -90,7 +90,8 @@ def train_legacy(config: TrainingConfig):
# Fetch data
data_fetch_start = time.time()
if len(batches) == 0:
batches = get_data(config.batch_size, config.seq_len, config.atropos_url)
batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url,
extract_inference_logprobs=False)
token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0)
data_fetch_time = time.time() - data_fetch_start
benchmark_stats["data_fetch_times"].append(data_fetch_time)
@ -206,26 +207,34 @@ def train_shared_vllm(config: TrainingConfig):
# === Training Loop ===
batches = []
inference_logprobs = None
for step in range(config.training_steps):
print(f"\nStep {step+1}/{config.training_steps}")
# Fetch data
# Fetch data (with inference logprobs for alignment check)
data_fetch_start = time.time()
if len(batches) == 0:
batches = get_data(config.batch_size, config.seq_len, config.atropos_url)
batches, inference_logprobs = get_data(
config.batch_size, config.seq_len, config.atropos_url,
extract_inference_logprobs=True, # Enable logprob alignment check
)
token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0)
data_fetch_time = time.time() - data_fetch_start
benchmark_stats["data_fetch_times"].append(data_fetch_time)
# Training step
# Training step (with logprob alignment check)
step_start = time.time()
metrics = run_training_step(
model, optimizer,
token_batches, label_batches, advantage_batches, temperature_batches,
config,
inference_logprobs=inference_logprobs, # Pass for alignment validation
)
step_time = time.time() - step_start
benchmark_stats["step_times"].append(step_time)
# Clear inference logprobs after use (will be refreshed with new data)
inference_logprobs = None
# GPU memory tracking
gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
@ -335,7 +344,8 @@ def train_lora(config: TrainingConfig):
# Fetch data
data_fetch_start = time.time()
if len(batches) == 0:
batches = get_data(config.batch_size, config.seq_len, config.atropos_url)
batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url,
extract_inference_logprobs=False)
token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0)
data_fetch_time = time.time() - data_fetch_start
benchmark_stats["data_fetch_times"].append(data_fetch_time)