BLEUBERI/training/create_training_data.py
2025-06-04 20:36:43 +00:00

727 lines
34 KiB
Python

import argparse
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import json
import logging
import csv
from tqdm import tqdm
from typing import List, Optional, Union, Tuple, Dict, Any
from datasets import load_from_disk, Dataset, DatasetDict, load_dataset
from evaluate import load
from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer
from chat_templates import QWEN_CHAT_TEMPLATE, LLAMA_CHAT_TEMPLATE, OLMO_CHAT_TEMPLATE
from utils import *
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Reference-based metrics. Every entry shares the signature (pred, refs, prompt)
# so callers can invoke any metric the same way; the prompt argument is accepted
# but ignored (score_dataset passes None here and handles "rm" separately).
# The reward functions are resolved when a lambda is called, not at import time.
METRIC_FUNCTIONS = dict(
    bleu=lambda pred, refs, prompt: bleu_reward(pred, refs),
    rouge=lambda pred, refs, prompt: rouge_reward(pred, refs),
    bertscore=lambda pred, refs, prompt: bertscore_reward(pred, refs),
    bleu_rouge_f1=lambda pred, refs, prompt: bleu_rouge_f1_reward(pred, refs),
)
def score_dataset(data, model_outputs_dict, metric):
    """Score each example's model output with ``metric`` and attach the score.

    Args:
        data: Iterable of example dicts. Each needs ``"id"``; reference-based
            metrics read ``"references"`` (a non-empty list of non-empty strings)
            and the reward model (``metric == "rm"``) reads ``"prompt"``.
        model_outputs_dict: Maps ``str(example["id"])`` to the model's response.
        metric: ``"rm"`` or a key of ``METRIC_FUNCTIONS``; any other value gives
            every example a score of 0.0.

    Returns:
        ``(scored_data, scores, score_field)``: the examples that were scored
        (each mutated in place with ``example[score_field] = score``), a dict
        mapping string ids to ``{"score": float, "metric": metric}``, and
        ``f"{metric}_score"``. Examples without a model output, or (for ``"rm"``)
        without a prompt, are left out.
    """
    score_field = f"{metric}_score"
    scored_data = []
    scores = {}

    # Single pass: pair every scorable example with its response (and prompt for
    # the reward model), so each skipped example is reported exactly once.
    valid_examples = []
    valid_responses = []
    valid_prompts_for_rm = []  # Only populated if metric is 'rm'
    for example in data:
        example_id_str = str(example["id"])
        if example_id_str not in model_outputs_dict:
            print(f"Skipping example {example_id_str}: No model output found")
            continue
        if metric == "rm":
            prompt = example.get("prompt")
            if prompt is None:
                print(f"Warning: Skipping example {example_id_str} for RM scoring due to missing 'prompt'.")
                continue
            valid_prompts_for_rm.append(prompt)
        valid_examples.append(example)
        valid_responses.append(model_outputs_dict[example_id_str])

    if not valid_examples:
        print("No valid examples found for scoring")
        return scored_data, scores, score_field

    if metric == "rm":
        # Load the reward model only once we know there is something to score.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        rm_model, rm_tokenizer = get_reward_model(device)
        # list() so the padding below can always extend, whatever container rm_reward returns.
        batch_scores = list(rm_reward(valid_responses, valid_prompts_for_rm, rm_model=rm_model, rm_tokenizer=rm_tokenizer, device=device))
    elif metric in METRIC_FUNCTIONS:
        metric_func = METRIC_FUNCTIONS[metric]
        batch_scores = []
        desc = f"Processing {metric.upper()} scores"
        for response, example in tqdm(zip(valid_responses, valid_examples), desc=desc, total=len(valid_responses)):
            references_for_metric = example.get("references")
            if not isinstance(references_for_metric, list) or not references_for_metric or not all(isinstance(r, str) and r for r in references_for_metric):
                print(f"Warning: Invalid or empty references (expected a non-empty list of non-empty strings) found for example {example['id']} for metric {metric}. Skipping score calculation.")
                batch_scores.append(0.0)
                continue
            batch_scores.append(metric_func(response, references_for_metric, None))
    else:
        print(f"Unsupported metric: {metric}")
        batch_scores = [0.0] * len(valid_examples)

    # Guard against a scorer returning fewer scores than examples.
    if len(batch_scores) != len(valid_examples):
        print(f"Warning: Mismatch in number of scores ({len(batch_scores)}) and examples ({len(valid_examples)}) for metric {metric}. Padding with 0.0.")
        num_missing = len(valid_examples) - len(batch_scores)
        batch_scores.extend([0.0] * num_missing)

    for example, score in zip(valid_examples, batch_scores):
        example[score_field] = score
        scores[str(example["id"])] = {  # string ids keep cache keys consistent
            "score": float(score),
            "metric": metric,
        }
        scored_data.append(example)
    print(f"Scored {len(scored_data)} examples using {metric} metric")
    return scored_data, scores, score_field
def _get_user_prompt_from_messages(messages: List[Dict[str, str]], example_id: Optional[Any] = None) -> Optional[str]:
"""Extracts the user prompt from a list of messages."""
if not messages:
if example_id:
print(f"Warning: Messages list is empty for example {example_id}.")
return None
for item in messages:
if item.get("role") == "user":
return item.get("content")
if example_id:
print(f"Warning: No user prompt found in messages for example {example_id}.")
return None
def _get_assistant_response_from_messages(messages: List[Dict[str, str]], example_id: Optional[Any] = None) -> Optional[str]:
"""Extracts the assistant response from a list of messages."""
if not messages:
if example_id:
print(f"Warning: Messages list is empty for example {example_id}.")
return None
for item in messages:
if item.get("role") == "assistant":
return item.get("content")
if example_id:
print(f"Warning: No assistant response found in messages for example {example_id}.")
return None
def _load_or_compute_scores(args: argparse.Namespace,
                            current_aggregated_data: List[Dict[str, Any]],
                            model_outputs_dict: Dict[str, str],
                            score_cache_dir: str) -> Tuple[List[Dict[str, Any]], Dict[str, Any], Optional[str]]:
    """Attach ``{metric}_score`` to each example, reusing a cached score file when valid.

    The cache path comes from ``build_score_path`` (dataset, metric, model, refs).
    A cache is used only if its first entry is a dict whose ``"metric"`` equals
    ``args.metric``; examples missing from a usable cache are skipped. On any
    invalid/unreadable cache the scores are recomputed with ``score_dataset`` and
    written back via ``save_scores``.

    Returns:
        ``(scored_data, scores_dict, score_field)``; ``([], {}, None)`` when
        ``current_aggregated_data`` is empty.
    """
    if not current_aggregated_data:
        print("No aggregated data provided to _load_or_compute_scores. Returning empty.")
        return [], {}, None

    nrefs, ref_models_str = get_ref_models_str(args.ref_models)
    score_path = build_score_path(
        score_cache_dir,
        args.hf_dataset_path,
        args.metric,
        args.model,
        nrefs,
        ref_models_str,
    )

    if os.path.exists(score_path):
        print(f"INFO: Found existing score file at {score_path}")
        # Everything below, including applying the scores, is guarded: any failure
        # falls through to recomputation.
        try:
            with open(score_path, 'r') as handle:
                cache = json.load(handle)
            if cache:
                first_entry = next(iter(cache.values()))
                cache_is_usable = (isinstance(first_entry, dict)
                                   and 'metric' in first_entry
                                   and first_entry['metric'] == args.metric)
            else:
                cache_is_usable = False

            if not cache_is_usable:
                print(f"Warning: Loaded score file {score_path} seems empty, invalid, or for a different metric. Recomputing scores.")
            else:
                print(f"INFO: Using existing scores from {score_path}")
                field = f"{args.metric}_score"
                kept = []
                missing = 0
                for example in current_aggregated_data:
                    key = str(example["id"])
                    if key not in cache:
                        print(f"Warning: No score found for example {key} in {score_path}. Skipping.")
                        missing += 1
                        continue
                    example[field] = cache[key]["score"]
                    kept.append(example)
                if missing > 0:
                    print(f"Warning: Skipped {missing} examples due to missing scores in the loaded file.")
                print(f"Applied existing {field} scores to {len(kept)} examples")
                return kept, cache, field
        except json.JSONDecodeError:
            print(f"Warning: Could not decode JSON from {score_path}. Recomputing scores.")
        except Exception as e:
            print(f"Warning: Error reading score file {score_path}: {e}. Recomputing scores.")

    # No usable cache: score from scratch and persist the result.
    print(f"Scoring {len(current_aggregated_data)} examples with model {args.model} using {args.metric} metric")
    scored, computed_scores, field = score_dataset(
        current_aggregated_data, model_outputs_dict, args.metric
    )
    os.makedirs(score_cache_dir, exist_ok=True)
    save_scores(computed_scores, score_path)
    print(f"Saved computed scores to {score_path}")
    return scored, computed_scores, field
def _load_grpo_source_dataset(dataset_path: str):
    """Load the ``train`` split from a Hub dataset id or a local ``save_to_disk`` directory."""
    # save_to_disk writes dataset_dict.json (DatasetDict) or state.json (Dataset).
    # load_dataset rejects such directories, yet almost every local path contains
    # "/", so detect them explicitly before applying the Hub-id heuristic.
    is_saved_dataset = any(
        os.path.isfile(os.path.join(dataset_path, marker))
        for marker in ("dataset_dict.json", "state.json")
    )
    if "/" in dataset_path and not is_saved_dataset:
        return load_dataset(dataset_path, split="train")
    loaded_ds = load_from_disk(dataset_path)
    return loaded_ds["train"] if isinstance(loaded_ds, DatasetDict) else loaded_ds


def _load_model_outputs(inference_results_path: str) -> Dict[str, Any]:
    """Read an inference-results CSV into a ``{str(id): response}`` dict."""
    # keep_default_na=False keeps empty responses and literal texts such as "None"
    # or "NA" as strings instead of turning them into float NaN.
    df_inference = pd.read_csv(inference_results_path, keep_default_na=False)
    if "id" not in df_inference.columns or "response" not in df_inference.columns:
        return {}
    # String ids match the str(example["id"]) lookups used everywhere else.
    return {
        str(example_id): response
        for example_id, response in zip(df_inference["id"], df_inference["response"])
    }


def _select_by_score(scored_data: List[Dict[str, Any]], score_field: Optional[str],
                     selection_mode: str, num_examples: Optional[int]) -> List[Dict[str, Any]]:
    """Pick examples for easy/medium/hard modes and return them highest score first."""
    if not scored_data:
        print("No examples left after scoring/filtering to select from.")
        return scored_data
    if not score_field:
        print("Warning: Scored data exists but score_field is not set. Cannot sort or select by score.")
        return scored_data

    def score_of(example):
        return example[score_field]

    if num_examples and len(scored_data) > num_examples:
        if selection_mode == "easy":
            print(f"Selecting the {num_examples} highest scoring examples")
            selected = sorted(scored_data, key=score_of, reverse=True)[:num_examples]
        else:
            ascending = sorted(scored_data, key=score_of)
            if selection_mode == "medium":
                print(f"Selecting {num_examples} examples from the middle of the distribution")
                start_idx = (len(ascending) - num_examples) // 2
                selected = ascending[start_idx:start_idx + num_examples]
            else:  # "hard"
                print(f"Selecting the {num_examples} lowest scoring examples")
                selected = ascending[:num_examples]
            # Present every mode's output highest score first for consistency.
            selected = sorted(selected, key=score_of, reverse=True)
            print("Re-sorted selected examples from highest to lowest score")
    else:
        # Nothing to drop: keep everything, sorted highest first.
        selected = sorted(scored_data, key=score_of, reverse=True)
        # Strip the "_score" suffix only; split('_')[0] would report "bleu_rouge_f1" as "bleu".
        metric_label = score_field[:-len("_score")] if score_field.endswith("_score") else score_field
        if num_examples is None:
            print(f"Using all {len(selected)} examples, sorted by {metric_label} score (highest first)")
        print(f"Sorted {len(selected)} examples by {metric_label} score (highest first)")

    if selected:
        print(f"Highest {score_field} score: {selected[0][score_field]}, "
              f"Lowest {score_field} score: {selected[-1][score_field]}")
    return selected


def make_grpo_data(args):
    """Build, select and save the GRPO training dataset.

    Steps: load the source train split, aggregate references
    (``aggregate_references_for_grpo``), then either sample ``args.num_examples``
    uniformly at random (``selection_mode == "random"``) or score the policy
    model's outputs with ``args.metric`` (running inference first if the CSV is
    missing) and keep the highest ("easy"), middle ("medium") or lowest ("hard")
    scoring examples.

    Args:
        args: Namespace from the ``grpo`` sub-command.

    Returns:
        Path of the saved dataset directory (or the existing one when
        ``args.output_dataset_name`` already exists).

    Raises:
        ValueError: random mode without ``--num_examples``, or a score-based mode
            without ``--model``/``--metric``.
        FileNotFoundError: on-the-fly inference did not produce its CSV.
    """
    print("Creating GRPO dataset...")
    print(f"Received ref_models as: {args.ref_models}")
    # Short-circuit when an explicitly named output already exists.
    if args.output_dataset_name:
        potential_output_path = os.path.join(args.output_dir, "data_grpo", args.output_dataset_name)
        if os.path.isdir(potential_output_path):
            print(f"Final dataset directory already exists at {potential_output_path}. Skipping generation.")
            return potential_output_path

    print(f"Loading source dataset from HuggingFace: {args.hf_dataset_path} for GRPO aggregation")
    source_ds_for_grpo = _load_grpo_source_dataset(args.hf_dataset_path)
    print(f"Loaded source dataset with {len(source_ds_for_grpo)} examples for GRPO aggregation.")
    aggregated_data = aggregate_references_for_grpo(args, source_ds_for_grpo)

    score_field = None
    if args.selection_mode == "random":
        if not args.num_examples:
            raise ValueError("When selection_mode is 'random', --num_examples must be provided")
        scored_data = aggregated_data
        if len(aggregated_data) > args.num_examples:
            print(f"Randomly sampling {args.num_examples} examples using seed {args.seed}")
            # Global seeding is kept so selections stay reproducible across versions.
            np.random.seed(args.seed)
            random.seed(args.seed)
            indices = np.random.choice(len(aggregated_data), args.num_examples, replace=False)
            scored_data = [aggregated_data[i] for i in indices]
            print(f"Randomly selected {args.num_examples} examples")
        else:
            print(f"Using all {len(aggregated_data)} examples (requested sample size larger than dataset or equal)")
    else:  # easy / medium / hard need model outputs and scores
        if not args.model or not args.metric:
            raise ValueError("For selection_mode 'easy', 'medium', or 'hard', both --model and --metric must be provided.")

        # Step 1: make sure the policy model's inference outputs exist.
        inference_base_dir = os.path.join(args.output_dir, "inference_outputs")
        inference_results_path = os.path.join(inference_base_dir, f"{get_model_name(args.model)}_inference_results.csv")
        if not os.path.exists(inference_results_path):
            print(f"Inference results not found at {inference_results_path}. Running inference...")
            inference_run_args = argparse.Namespace(
                hf_dataset_path=args.hf_dataset_path,
                model=args.model,
                output_dir=inference_base_dir,
                seed=args.seed,
                max_new_tokens=args.inference_max_new_tokens,
            )
            returned_inference_path = run_inference(inference_run_args, preloaded_dataset=source_ds_for_grpo)
            if returned_inference_path is None or not os.path.exists(returned_inference_path):
                raise FileNotFoundError(f"Inference run failed or did not produce the expected output file at {inference_results_path}. Attempted path: {returned_inference_path}")
            inference_results_path = returned_inference_path
            print(f"Inference complete. Results saved to {inference_results_path}")
        else:
            print(f"Found existing inference results at {inference_results_path}")

        print(f"Loading model outputs from {inference_results_path}...")
        model_outputs_dict = _load_model_outputs(inference_results_path)
        print(f"Loaded {len(model_outputs_dict)} model outputs.")

        # Only examples with a model output can be scored.
        original_aggregated_count = len(aggregated_data)
        aggregated_data = [ex for ex in aggregated_data if str(ex["id"]) in model_outputs_dict]
        print(f"Filtered aggregated_data from {original_aggregated_count} to {len(aggregated_data)} based on available model outputs.")

        # Step 2: score (or reuse cached scores).
        if not aggregated_data:
            print("No examples in aggregated_data after filtering against model outputs. GRPO dataset will be empty or very small.")
            scored_data = []
        else:
            scored_data, _, score_field = _load_or_compute_scores(
                args,
                aggregated_data,
                model_outputs_dict,
                os.path.join(args.output_dir, "scored_outputs"),
            )

        # Step 3: select according to the mode.
        scored_data = _select_by_score(scored_data, score_field, args.selection_mode, args.num_examples)

    if not scored_data and args.selection_mode != "random":
        print("Warning: No data to save after selection process. Output dataset will be empty.")
    return save_grpo_dataset(args, scored_data, score_field)
def aggregate_references_for_grpo(args, source_dataset: "Dataset"):
    """Collect prompt, ground truth and reference outputs per example for GRPO.

    Args:
        args: Namespace with ``ref_models`` (model names; ``"gold"`` selects the
            ``ref_output_gold`` column, others map to ``ref_output_<model>``) and
            ``output_dir`` (created if missing).
        source_dataset: Train split holding the reference-output columns.

    Returns:
        A list of dicts with keys ``id``, ``source``, ``messages`` (None when the
        dataset has no messages), ``prompt``, ``ground_truth`` and ``references``
        (gold first, then the mapped models in ``args.ref_models`` order). Only
        examples whose requested references are all present and non-blank are kept.

    Raises:
        ValueError: if gold was not requested and no requested model has a column.
    """
    print(f"Aggregating data from reference models: {args.ref_models}")
    os.makedirs(args.output_dir, exist_ok=True)
    ds = source_dataset
    print(f"Using pre-loaded dataset with {len(ds)} examples for aggregation.")
    print(f"Dataset columns: {ds.column_names}")

    use_gold = "gold" in args.ref_models
    models_to_use = args.ref_models.copy()
    if use_gold:
        models_to_use.remove("gold")

    # Map each requested model to its ref_output_<model> column.
    available_ref_columns = [col for col in ds.column_names if col.startswith('ref_output_')]
    print(f"Available reference output columns: {available_ref_columns}")
    model_column_mapping = {}
    for model in models_to_use:
        expected_column = f"ref_output_{model}"
        if expected_column in available_ref_columns:
            model_column_mapping[model] = expected_column
            print(f"Mapped model '{model}' to column '{expected_column}'")
        else:
            print(f"Warning: Could not find column '{expected_column}' for model '{model}'")
    models_to_use = [model for model in models_to_use if model in model_column_mapping]
    if not models_to_use and not use_gold:
        raise ValueError("No reference models could be mapped to dataset columns")

    required_columns = (["ref_output_gold"] if use_gold else []) + [model_column_mapping[m] for m in models_to_use]

    def _is_blank(value):
        # None, NaN and whitespace-only strings all count as a missing reference.
        return value is None or pd.isna(value) or str(value).strip() == ""

    def has_all_references(example):
        return not any(_is_blank(example.get(col)) for col in required_columns)

    ds_filtered = ds.filter(has_all_references)
    print(f"After filtering for complete references: {len(ds_filtered)} examples")

    aggregated_data = []
    for example in ds_filtered:
        example_id = example.get("id", "unknown_id")
        messages = example.get("messages")  # optional column; don't assume it exists
        if "prompt" in example and example["prompt"] is not None:
            prompt = example["prompt"]
        else:
            prompt = _get_user_prompt_from_messages(messages, example_id)
        if "ref_output_gold" in example and example["ref_output_gold"] is not None:
            ground_truth = example["ref_output_gold"]
        else:
            ground_truth = _get_assistant_response_from_messages(messages, example_id)
        references = [ground_truth] if use_gold else []
        references.extend(example[model_column_mapping[model]] for model in models_to_use)
        aggregated_data.append({
            "id": example["id"],
            "source": example.get("source", "unknown"),
            "messages": messages,
            "prompt": prompt,
            "ground_truth": ground_truth,
            "references": references,
        })
    print(f"Aggregated {len(aggregated_data)} examples with specified references")
    return aggregated_data
def save_grpo_dataset(args, data, score_field=None):
    """Save ``data`` as a ``DatasetDict({"train": ...})`` under ``{output_dir}/data_grpo``.

    The directory name is ``args.output_dataset_name`` if given, otherwise one
    built from the dataset basename, metric, model, reference models, selection
    mode and example count. An existing directory is never overwritten. When
    scores are available, a histogram of ``score_field`` is written next to it.

    Returns:
        The dataset directory path.
    """
    # normpath drops a trailing separator so "path/to/ds/" still yields "ds".
    data_basename = os.path.basename(os.path.normpath(args.hf_dataset_path))
    nrefs, ref_models_str = get_ref_models_str(args.ref_models)

    sampling_str = f"_{args.selection_mode}" if args.selection_mode in ("random", "easy", "medium", "hard") else ""
    metric_str = ""
    model_str = ""
    plot_metric_name = ""
    if args.metric:
        metric_str = f"{args.metric}"
        plot_metric_name = args.metric
        if args.model:
            model_str = f"_{get_model_name(args.model)}_"
    elif score_field:
        plot_metric_name = score_field.replace("_score", "")

    if args.output_dataset_name:
        dataset_name_suffix = args.output_dataset_name
    else:
        dataset_name_suffix = f"{data_basename}_{metric_str}{model_str}{nrefs}ref{ref_models_str}{sampling_str}_{len(data)}"
    output_path = os.path.join(args.output_dir, "data_grpo", dataset_name_suffix)

    if os.path.exists(output_path):
        print(f"Dataset already exists at {output_path}, skipping save.")
    else:
        # Only materialize the Arrow dataset when it will actually be written.
        dataset = DatasetDict({
            "train": Dataset.from_list(data)
        })
        os.makedirs(output_path, exist_ok=True)
        dataset.save_to_disk(output_path)
        print(f"Saved GRPO dataset with {len(data)} examples to {output_path}")

    if score_field and plot_metric_name and data:
        scores_to_plot = [item[score_field] for item in data if score_field in item and pd.notna(item[score_field])]
        if scores_to_plot:
            fig_path = os.path.join(output_path, f"{plot_metric_name}_distribution.png")
            save_histogram(scores_to_plot, plot_metric_name, f"{plot_metric_name.upper()} Score Distribution for GRPO Output", f"{plot_metric_name.upper()} Score", fig_path)
        else:
            print(f"No valid scores found for field '{score_field}' in GRPO output to generate distribution plot.")
    elif not data:
        print("GRPO output data is empty, skipping histogram generation.")
    return output_path
def make_sft_data(args):
    """Expand a dataset with ``references`` into SFT chat examples, one per reference.

    Each input example yields ``len(example["references"])`` outputs with id
    ``"{id}_{j}"`` and a user/assistant message pair (prompt, reference j).
    Output goes to ``{output_dir}/data_sft/{input_basename}_SFT``; if that
    directory already exists it is returned without regenerating.

    Returns:
        The SFT dataset directory path.
    """
    # normpath drops a trailing separator so ".../name/" still yields "name".
    input_basename = os.path.basename(os.path.normpath(args.input_data_path))
    output_path = os.path.join(args.output_dir, "data_sft", f"{input_basename}_SFT")
    if os.path.isdir(output_path):
        print(f"SFT dataset already exists at {output_path}. Skipping generation.")
        return output_path

    loaded = load_from_disk(args.input_data_path)
    # Accept both a saved DatasetDict (use its train split) and a bare Dataset.
    data = loaded["train"] if isinstance(loaded, DatasetDict) else loaded
    print(f"Processing {len(data)} examples with references...")

    def build_message(prompt, response):
        return [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": response},
        ]

    sft_data = []
    for example in tqdm(data, desc="Converting to SFT format"):
        for j, ref in enumerate(example["references"]):
            sft_data.append({
                "id": f"{example['id']}_{j}",
                "source": example["source"],
                "messages": build_message(example["prompt"], ref),
            })
    print(f"Created {len(sft_data)} SFT examples from {len(data)} input examples")

    dataset = DatasetDict({
        "train": Dataset.from_list(sft_data),
    })
    os.makedirs(output_path, exist_ok=True)
    dataset.save_to_disk(output_path)
    print(f"Saved SFT dataset with {len(sft_data)} examples to {output_path}")
    return output_path
def set_seed(seed):
    """Seed Python, NumPy and torch (CPU + all CUDA devices) and make cuDNN deterministic.

    PYTHONHASHSEED is exported as well; note it only affects interpreters started
    afterwards (e.g. worker processes), not the hash seed of the running one.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Trade cuDNN autotuning speed for reproducible kernels.
    torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False
def format_message(prompt):
    """Wrap ``prompt`` as a single-turn chat: a fresh one-element list with the user message."""
    user_turn = dict(role="user", content=prompt)
    return [user_turn]
def run_inference(args, preloaded_dataset: "Optional[Dataset]" = None):
    """Generate one response per prompt with vLLM and persist them to a CSV.

    Results go to ``{args.output_dir}/{model_basename}_inference_results.csv``
    with columns id, source, prompt, response. Ids already present in that CSV
    are not regenerated, so interrupted runs can resume.

    Args:
        args: Namespace with ``model``, ``output_dir``, ``seed``,
            ``max_new_tokens`` and, when ``preloaded_dataset`` is None,
            ``hf_dataset_path``.
        preloaded_dataset: Already-loaded train split to use instead of loading
            ``args.hf_dataset_path``.

    Returns:
        The CSV path, or None when vLLM or the chat templates cannot be imported.
    """
    print("Running model inference...")
    try:
        from vllm import LLM, SamplingParams
        from chat_templates import QWEN_CHAT_TEMPLATE, LLAMA_CHAT_TEMPLATE, OLMO_CHAT_TEMPLATE
    except ImportError as e:
        print(f"Error importing vLLM or chat templates: {e}")
        print("Please install vLLM: pip install vllm")
        return None

    def setup_vllm(args):
        llm = LLM(
            model=args.model,
            dtype="bfloat16",
        )
        tokenizer = llm.get_tokenizer()
        # Fall back to a bundled chat template when the tokenizer ships none.
        if not tokenizer.chat_template:
            if "qwen" in args.model.lower():
                tokenizer.chat_template = QWEN_CHAT_TEMPLATE
            elif "llama" in args.model.lower():
                tokenizer.chat_template = LLAMA_CHAT_TEMPLATE
            elif "olmo" in args.model.lower():
                tokenizer.chat_template = OLMO_CHAT_TEMPLATE
            print(f"Chat template set to {tokenizer.chat_template}")
        return llm, tokenizer

    def generate_responses(llm, prompts, max_new_tokens=512):
        sampling_params = SamplingParams(
            max_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
        )
        outputs = llm.chat(
            prompts,
            sampling_params,
            add_generation_prompt=True
        )
        return [output.outputs[0].text.strip() for output in outputs]

    def load_inference_dataset(dataset_path):
        # save_to_disk directories (dataset_dict.json / state.json) must use
        # load_from_disk even though local paths usually contain "/".
        is_saved_dataset = any(
            os.path.isfile(os.path.join(dataset_path, marker))
            for marker in ("dataset_dict.json", "state.json")
        )
        if "/" in dataset_path and not is_saved_dataset:
            return load_dataset(dataset_path, split="train")
        loaded_ds = load_from_disk(dataset_path)
        return loaded_ds["train"] if isinstance(loaded_ds, DatasetDict) else loaded_ds

    set_seed(args.seed)
    logging.basicConfig(level=logging.INFO)
    if preloaded_dataset is not None:
        print(f"Using pre-loaded dataset for inference with {len(preloaded_dataset)} examples.")
        ds = preloaded_dataset
    else:
        print(f"Loading dataset from HuggingFace: {args.hf_dataset_path} for inference")
        ds = load_inference_dataset(args.hf_dataset_path)
        logging.info(f"Dataset loaded successfully from {args.hf_dataset_path} for inference")
    print(f"Loaded dataset with {len(ds)} examples for inference")

    os.makedirs(args.output_dir, exist_ok=True)
    model_basename = get_model_name(args.model)
    save_path = os.path.join(args.output_dir, f"{model_basename}_inference_results.csv")
    logging.info(f"Saving results to {save_path}")

    # Resume from existing results. keep_default_na=False preserves responses such
    # as "" or "None" verbatim when they are written back out.
    all_results = []
    processed_ids = set()
    if os.path.exists(save_path):
        existing_results = pd.read_csv(save_path, keep_default_na=False)
        all_results = existing_results.to_dict('records')
        logging.info(f"Loaded {len(all_results)} existing results from {save_path}")
        # Compare as strings: CSV ids may be parsed as numbers while dataset ids are str.
        processed_ids = {str(example_id) for example_id in existing_results['id']}

    examples_to_process = [ex for ex in ds if str(ex['id']) not in processed_ids]
    logging.info(f"Processing {len(examples_to_process)} new examples")
    if not examples_to_process:
        print("All examples have already been processed!")
        return save_path

    pending = []  # (example, prompt text) pairs that can be sent to the model
    for ex in examples_to_process:
        example_id = ex.get("id", "unknown_id")
        if "prompt" in ex and ex["prompt"] is not None:
            prompt_text = ex["prompt"]
        else:
            prompt_text = _get_user_prompt_from_messages(ex.get("messages"), example_id)
        if prompt_text is None:
            print(f"Warning: Skipping example {example_id} in inference due to missing prompt.")
            continue
        pending.append((ex, prompt_text))

    if not pending:
        print("No valid prompts to send for inference after filtering.")
    else:
        # Load the model only when there is actually something to generate.
        llm, _ = setup_vllm(args)
        logging.info("Model loaded successfully with vLLM")
        prompts_for_model = [format_message(prompt_text) for _, prompt_text in pending]
        logging.info(f"Generating responses for {len(prompts_for_model)} prompts")
        responses = generate_responses(llm, prompts_for_model, max_new_tokens=args.max_new_tokens)
        logging.info(f"Generated {len(responses)} responses")
        for (ex, prompt_text), response in zip(pending, responses):
            all_results.append({
                "id": ex["id"],
                "source": ex.get("source", "unknown"),
                "prompt": prompt_text,
                "response": response,
            })

    # str() key: resumed ids and new ids may differ in type; mixed types can't be sorted.
    all_results = sorted(all_results, key=lambda row: str(row["id"]))
    pd.DataFrame(all_results).to_csv(save_path, index=False, quoting=csv.QUOTE_ALL, escapechar='\\')
    logging.info(f"All results saved to {save_path}")
    print(f"Inference complete! Results saved to {save_path}")
    return save_path
def main():
    """Command-line entry point.

    ``grpo`` builds a GRPO dataset (see ``make_grpo_data``); ``sft`` converts a
    dataset with references into SFT format (see ``make_sft_data``). Without a
    sub-command the help text is printed. vLLM is imported lazily by
    ``run_inference`` only when inference is actually needed.
    """
    parser = argparse.ArgumentParser(description="Unified training data generation script")
    subparsers = parser.add_subparsers(dest="command", help="Command to run")

    # GRPO data command
    grpo_parser = subparsers.add_parser("grpo", help="Create GRPO dataset (aggregates references and optionally sorts by score)")
    grpo_parser.add_argument("--hf_dataset_path", type=str, default="yapeichang/BLEUBERI-Tulu3-50k", help="HuggingFace dataset path or local path with reference outputs")
    grpo_parser.add_argument("--inference_max_new_tokens", type=int, default=512, help="Max new tokens for on-the-fly inference if model outputs are missing.")
    # More options: "claude-3-7-sonnet@20250219", "deepseek-chat-v3", "gemini-2.5-pro-exp-03-25", "o4-mini-2025-04-16", "Llama-3.1-8B-Instruct"
    grpo_parser.add_argument("--ref_models", type=str, nargs="+", default=["gold"])
    grpo_parser.add_argument("--selection_mode", type=str, choices=["random", "easy", "medium", "hard"], default="hard", help="Selection mode: 'random' for random sampling, 'easy' for highest scores, 'medium' for middle scores, 'hard' for lowest scores")
    grpo_parser.add_argument("--output_dataset_name", type=str, default=None, help="Custom name for the GRPO output dataset directory. If None, a name will be automatically generated.")
    grpo_parser.add_argument("--metric", type=str, choices=["bleu", "rm", "rouge", "bertscore", "bleu_rouge_f1"], help="Metric to use for scoring (if not using pre-computed scores)")
    grpo_parser.add_argument("--model", type=str, help="Model to use for scoring (if not using pre-computed scores) and for generating inference outputs if missing.")
    grpo_parser.add_argument("--num_examples", type=int, default=None, help="Number of examples to include in final dataset")
    grpo_parser.add_argument("--output_dir", type=str, default="../data")
    grpo_parser.add_argument("--seed", type=int, default=42, help="Random seed for 'random' selection mode and for on-the-fly inference if model outputs are missing.")

    # SFT data command
    sft_parser = subparsers.add_parser("sft", help="Convert dataset with references to SFT format")
    sft_parser.add_argument("--input_data_path", type=str, required=True, help="Path to dataset with references")
    sft_parser.add_argument("--output_dir", type=str, default="../data")

    args = parser.parse_args()
    if args.command == "grpo":
        # Fail fast, before any dataset loading, when score-based modes lack inputs.
        if args.selection_mode != "random" and not (args.model and args.metric):
            raise ValueError("When selection_mode is 'easy', 'medium', or 'hard', both --model and --metric must be provided")
        make_grpo_data(args)
    elif args.command == "sft":
        make_sft_data(args)
    else:
        parser.print_help()


if __name__ == "__main__":
    main()