mirror of https://github.com/lilakk/BLEUBERI.git
synced 2026-04-19 12:58:12 +00:00
727 lines · 34 KiB · Python
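
"""Unified training-data generation for BLEUBERI.

Aggregates reference-model outputs, optionally runs vLLM inference and
metric- or reward-model-based scoring, and writes GRPO- and SFT-format
datasets to disk.

Example invocations (illustrative; the script filename and model name are
assumptions, not fixed by this file):

    python data_gen.py grpo --model Qwen/Qwen2.5-7B-Instruct --metric bleu \
        --selection_mode hard --num_examples 1000
    python data_gen.py sft --input_data_path ../data/data_grpo/<dataset_dir>
"""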
import argparse
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import json
import logging
import csv
from tqdm import tqdm
from typing import List, Optional, Union, Tuple, Dict, Any
from datasets import load_from_disk, Dataset, DatasetDict, load_dataset
from evaluate import load
from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer
from chat_templates import QWEN_CHAT_TEMPLATE, LLAMA_CHAT_TEMPLATE, OLMO_CHAT_TEMPLATE
from utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
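
# Reference-based metric callables share the signature (pred, refs, prompt);
# the prompt argument is accepted for interface uniformity but ignored here.
# The "rm" metric is handled separately inside score_dataset.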
METRIC_FUNCTIONS = {
    "bleu": lambda pred, refs, prompt: bleu_reward(pred, refs),
    "rouge": lambda pred, refs, prompt: rouge_reward(pred, refs),
    "bertscore": lambda pred, refs, prompt: bertscore_reward(pred, refs),
    "bleu_rouge_f1": lambda pred, refs, prompt: bleu_rouge_f1_reward(pred, refs),
}

def score_dataset(data, model_outputs_dict, metric):
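    """Score each model output against its references (or with a reward model
    when metric == "rm") and attach the result to the example.

    Returns (scored_data, scores_by_id, score_field_name).
    """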
    score_field = f"{metric}_score"
    scored_data = []
    scores = {}

    rm_model, rm_tokenizer = None, None
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if metric == "rm":
        rm_model, rm_tokenizer = get_reward_model(device)

    valid_examples = []
    valid_responses = []
    valid_prompts_for_rm = []  # Only populated if metric is 'rm'

    for example in data:
        example_id_str = str(example["id"])
        if example_id_str not in model_outputs_dict:
            print(f"Skipping example {example_id_str}: No model output found")
            continue

        response = model_outputs_dict[example_id_str]

        valid_examples.append(example)
        valid_responses.append(response)
        if metric == "rm":
            prompt = example.get("prompt")
            if prompt is None:
                print(f"Warning: Example {example_id_str} has no 'prompt'; it will be dropped before RM scoring.")
            # Append even when None so the lists stay aligned; None entries are
            # filtered out below before calling the reward model.
            valid_prompts_for_rm.append(prompt)

    if not valid_examples:
        print("No valid examples found for scoring")
        return scored_data, scores, score_field

    batch_scores = []
    if metric == "rm":
        temp_valid_examples = []
        temp_valid_responses = []
        temp_valid_prompts_for_rm = []
        for ex, resp, prmpt in zip(valid_examples, valid_responses, valid_prompts_for_rm):
            if prmpt is not None:
                temp_valid_examples.append(ex)
                temp_valid_responses.append(resp)
                temp_valid_prompts_for_rm.append(prmpt)
            else:
                print(f"Skipping example {ex['id']} for RM scoring as prompt is missing.")

        valid_examples = temp_valid_examples
        valid_responses = temp_valid_responses
        valid_prompts_for_rm = temp_valid_prompts_for_rm

        if valid_examples:  # Only run RM if there are examples left
            batch_scores = rm_reward(valid_responses, valid_prompts_for_rm, rm_model=rm_model, rm_tokenizer=rm_tokenizer, device=device)
        else:
            print("No valid examples with prompts for RM scoring.")

    elif metric in METRIC_FUNCTIONS:
        metric_func = METRIC_FUNCTIONS[metric]
        desc = f"Processing {metric.upper()} scores"
        for response, example in tqdm(zip(valid_responses, valid_examples), desc=desc, total=len(valid_responses)):
            references_for_metric = example.get("references")

            if not isinstance(references_for_metric, list) or not references_for_metric or not all(isinstance(r, str) and r for r in references_for_metric):
                print(f"Warning: Invalid or empty references (expected a non-empty list of non-empty strings) found for example {example['id']} for metric {metric}. Skipping score calculation.")
                batch_scores.append(0.0)
                continue

            score_val = metric_func(response, references_for_metric, None)  # Pass the full reference list
            batch_scores.append(score_val)
    else:
        print(f"Unsupported metric: {metric}")
        batch_scores = [0.0] * len(valid_examples)

    # Ensure batch_scores has the same length as valid_examples if any specific metric path failed
    if len(batch_scores) != len(valid_examples):
        print(f"Warning: Mismatch in number of scores ({len(batch_scores)}) and examples ({len(valid_examples)}) for metric {metric}. Padding with 0.0.")
        num_missing = len(valid_examples) - len(batch_scores)
        batch_scores.extend([0.0] * num_missing)

    for example, score in zip(valid_examples, batch_scores):
        example[score_field] = score
        scores[str(example["id"])] = {  # Ensure ID is string for consistency
            "score": float(score),
            "metric": metric
        }
        scored_data.append(example)

    print(f"Scored {len(scored_data)} examples using {metric} metric")
    return scored_data, scores, score_field

def _get_user_prompt_from_messages(messages: List[Dict[str, str]], example_id: Optional[Any] = None) -> Optional[str]:
    """Extracts the user prompt from a list of messages."""
    if not messages:
        if example_id:
            print(f"Warning: Messages list is empty for example {example_id}.")
        return None
    for item in messages:
        if item.get("role") == "user":
            return item.get("content")
    if example_id:
        print(f"Warning: No user prompt found in messages for example {example_id}.")
    return None

def _get_assistant_response_from_messages(messages: List[Dict[str, str]], example_id: Optional[Any] = None) -> Optional[str]:
    """Extracts the assistant response from a list of messages."""
    if not messages:
        if example_id:
            print(f"Warning: Messages list is empty for example {example_id}.")
        return None
    for item in messages:
        if item.get("role") == "assistant":
            return item.get("content")
    if example_id:
        print(f"Warning: No assistant response found in messages for example {example_id}.")
    return None

def _load_or_compute_scores(args: argparse.Namespace,
                            current_aggregated_data: List[Dict[str, Any]],
                            model_outputs_dict: Dict[str, str],
                            score_cache_dir: str) -> Tuple[List[Dict[str, Any]], Dict[str, Any], Optional[str]]:
    """
    Loads scores from cache if available, otherwise computes them using score_dataset.
    Returns: scored_data_list, scores_dict, score_field_name
    """
    if not current_aggregated_data:
        print("No aggregated data provided to _load_or_compute_scores. Returning empty.")
        return [], {}, None

    nrefs, ref_models_str = get_ref_models_str(args.ref_models)
    potential_score_file = build_score_path(
        score_cache_dir,
        args.hf_dataset_path,
        args.metric,
        args.model,
        nrefs,
        ref_models_str
    )

    final_scored_data = []
    final_scores_dict = {}
    final_score_field = None
    needs_recompute = True  # Default to recompute

    if os.path.exists(potential_score_file):
        print(f"INFO: Found existing score file at {potential_score_file}")
        try:
            with open(potential_score_file, 'r') as f:
                loaded_scores_cache = json.load(f)

            first_entry = next(iter(loaded_scores_cache.values()), None)
            # Also check that the cached scores were computed with the requested metric
            if (not loaded_scores_cache
                    or not isinstance(first_entry, dict)
                    or first_entry.get('metric') != args.metric):
                print(f"Warning: Loaded score file {potential_score_file} seems empty, invalid, or for a different metric. Recomputing scores.")
            else:
                print(f"INFO: Using existing scores from {potential_score_file}")
                final_score_field = f"{args.metric}_score"  # Construct based on current args.metric
                final_scores_dict = loaded_scores_cache
                temp_scored_data = []
                missing_count = 0
                for example in current_aggregated_data:
                    example_id_str = str(example["id"])
                    if example_id_str in final_scores_dict:
                        example[final_score_field] = final_scores_dict[example_id_str]["score"]
                        temp_scored_data.append(example)
                    else:
                        print(f"Warning: No score found for example {example_id_str} in {potential_score_file}. Skipping.")
                        missing_count += 1
                final_scored_data = temp_scored_data
                if missing_count > 0:
                    print(f"Warning: Skipped {missing_count} examples due to missing scores in the loaded file.")
                print(f"Applied existing {final_score_field} scores to {len(final_scored_data)} examples")
                needs_recompute = False  # Scores loaded successfully
        except json.JSONDecodeError:
            print(f"Warning: Could not decode JSON from {potential_score_file}. Recomputing scores.")
        except Exception as e:
            print(f"Warning: Error reading score file {potential_score_file}: {e}. Recomputing scores.")

    if needs_recompute:
        print(f"Scoring {len(current_aggregated_data)} examples with model {args.model} using {args.metric} metric")
        processed_scored_data, computed_scores_dict, computed_score_field = score_dataset(
            current_aggregated_data, model_outputs_dict, args.metric
        )
        final_scored_data = processed_scored_data
        final_scores_dict = computed_scores_dict
        final_score_field = computed_score_field

        # Save the newly computed scores
        os.makedirs(score_cache_dir, exist_ok=True)  # Ensure dir exists before saving
        save_scores(final_scores_dict, potential_score_file)
        print(f"Saved computed scores to {potential_score_file}")

    return final_scored_data, final_scores_dict, final_score_field

def make_grpo_data(args):
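    """Build a GRPO training dataset: aggregate reference outputs, optionally run
    inference and scoring for difficulty-based selection, then save the dataset
    to disk and return its path."""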
    print("Creating GRPO dataset...")
    print(f"Received ref_models as: {args.ref_models}")

    # Check if the final dataset directory already exists if a specific dataset name is provided
    if args.output_dataset_name:
        grpo_base_output_dir = os.path.join(args.output_dir, "data_grpo")
        potential_output_path = os.path.join(grpo_base_output_dir, args.output_dataset_name)
        if os.path.exists(potential_output_path) and os.path.isdir(potential_output_path):
            print(f"Final dataset directory already exists at {potential_output_path}. Skipping generation.")
            return potential_output_path

    # Load dataset once for the GRPO process, potentially passing it to run_inference
    print(f"Loading source dataset from HuggingFace: {args.hf_dataset_path} for GRPO aggregation")
    if '/' in args.hf_dataset_path:  # anything containing '/' is treated as a HuggingFace Hub ID
        source_ds_for_grpo = load_dataset(args.hf_dataset_path, split="train")
    else:
        loaded_ds = load_from_disk(args.hf_dataset_path)
        if isinstance(loaded_ds, DatasetDict):
            source_ds_for_grpo = loaded_ds["train"]
        else:
            source_ds_for_grpo = loaded_ds  # Assuming it's already a Dataset object
    print(f"Loaded source dataset with {len(source_ds_for_grpo)} examples for GRPO aggregation.")

    aggregated_data = aggregate_references_for_grpo(args, source_ds_for_grpo)  # Pass the loaded dataset

    scored_data = []  # Initialize with a default
    score_field = None
    scores = {}  # To hold scores for saving, if computed
    model_outputs_dict = {}

    if args.selection_mode == "random":
        if not args.num_examples:
            raise ValueError("When selection_mode is 'random', --num_examples must be provided")

        scored_data = aggregated_data
        if len(aggregated_data) > args.num_examples:
            print(f"Randomly sampling {args.num_examples} examples using seed {args.seed}")
            np.random.seed(args.seed)
            random.seed(args.seed)
            indices = np.random.choice(len(aggregated_data), args.num_examples, replace=False)
            scored_data = [aggregated_data[i] for i in indices]
            print(f"Randomly selected {args.num_examples} examples")
        else:
            print(f"Using all {len(aggregated_data)} examples (requested sample size is at least the dataset size)")
    else:  # Modes: easy, medium, hard (require model outputs and scores)
        if not args.model or not args.metric:
            raise ValueError("For selection_mode 'easy', 'medium', or 'hard', both --model and --metric must be provided.")

        # Step 1: Ensure model inference outputs are available
        inference_base_dir = os.path.join(args.output_dir, "inference_outputs")
        model_basename = get_model_name(args.model)
        inference_results_path = os.path.join(inference_base_dir, f"{model_basename}_inference_results.csv")

        if not os.path.exists(inference_results_path):
            print(f"Inference results not found at {inference_results_path}. Running inference...")
            # Prepare args for run_inference; hf_dataset_path is still needed if source_ds is None
            inference_run_args = argparse.Namespace(
                hf_dataset_path=args.hf_dataset_path,
                model=args.model,
                output_dir=inference_base_dir,
                seed=args.seed,
                max_new_tokens=args.inference_max_new_tokens
            )
            # Pass the already loaded source_ds_for_grpo to run_inference
            returned_inference_path = run_inference(inference_run_args, preloaded_dataset=source_ds_for_grpo)
            if returned_inference_path is None or not os.path.exists(returned_inference_path):
                raise FileNotFoundError(f"Inference run failed or did not produce the expected output file at {inference_results_path}. Attempted path: {returned_inference_path}")
            inference_results_path = returned_inference_path  # Use the path returned by the function
            print(f"Inference complete. Results saved to {inference_results_path}")
        else:
            print(f"Found existing inference results at {inference_results_path}")

        # Step 2: Load model outputs from the CSV
        print(f"Loading model outputs from {inference_results_path}...")
        df_inference = pd.read_csv(inference_results_path)
        for _, row in df_inference.iterrows():
            if 'id' in row and 'response' in row:
                model_outputs_dict[str(row['id'])] = row['response']  # Ensure ID is string for consistency
        print(f"Loaded {len(model_outputs_dict)} model outputs.")

        # Filter aggregated_data based on available model outputs
        original_aggregated_count = len(aggregated_data)
        aggregated_data = [ex for ex in aggregated_data if str(ex["id"]) in model_outputs_dict]
        print(f"Filtered aggregated_data from {original_aggregated_count} to {len(aggregated_data)} based on available model outputs.")

        if not aggregated_data:
            print("No examples in aggregated_data after filtering against model outputs. GRPO dataset will be empty or very small.")
            scored_data = []
            scores = {}
            score_field = None
        else:
            score_cache_dir = os.path.join(args.output_dir, "scored_outputs")
            scored_data, scores, score_field = _load_or_compute_scores(
                args,
                aggregated_data,  # This is the filtered list
                model_outputs_dict,
                score_cache_dir
            )

        # Step 3: Selection logic based on mode and number of examples
        if scored_data and score_field:  # Ensure there is data and a score field to sort by
            if args.num_examples and len(scored_data) > args.num_examples:
                if args.selection_mode == "easy":
                    print(f"Selecting the {args.num_examples} highest scoring examples")
                    scored_data = sorted(scored_data, key=lambda x: x[score_field], reverse=True)[:args.num_examples]
                elif args.selection_mode == "medium":
                    print(f"Selecting {args.num_examples} examples from the middle of the distribution")
                    # Sort first to ensure consistent selection
                    temp_sorted_for_medium = sorted(scored_data, key=lambda x: x[score_field])
                    start_idx = (len(temp_sorted_for_medium) - args.num_examples) // 2
                    scored_data = temp_sorted_for_medium[start_idx:start_idx + args.num_examples]
                    # Re-sort by score descending for consistency with the other modes
                    scored_data = sorted(scored_data, key=lambda x: x[score_field], reverse=True)
                    print("Re-sorted selected examples from highest to lowest score")
                else:  # args.selection_mode == "hard"
                    print(f"Selecting the {args.num_examples} lowest scoring examples")
                    scored_data = sorted(scored_data, key=lambda x: x[score_field])[:args.num_examples]
                    # Re-sort by score descending for consistency
                    scored_data = sorted(scored_data, key=lambda x: x[score_field], reverse=True)
                    print("Re-sorted selected examples from highest to lowest score")
            else:  # Not reducing num_examples, or scored_data is already small enough
                scored_data = sorted(scored_data, key=lambda x: x[score_field], reverse=True)
                if args.num_examples is None:
                    print(f"Using all {len(scored_data)} examples, sorted by {score_field.replace('_score', '')} score (highest first)")
                # If args.num_examples is set but len(scored_data) <= args.num_examples, all examples are kept and sorted.

            print(f"Sorted {len(scored_data)} examples by {score_field.replace('_score', '')} score (highest first)")
            if scored_data:
                print(f"Highest {score_field} score: {scored_data[0][score_field]}, "
                      f"Lowest {score_field} score: {scored_data[-1][score_field]}")
        elif not scored_data:
            print("No examples left after scoring/filtering to select from.")
        else:  # scored_data exists but no score_field (should not happen in this branch)
            print("Warning: Scored data exists but score_field is not set. Cannot sort or select by score.")

    # If selection_mode was "random", scored_data is already set.
    # For other modes, scored_data has now been selected and sorted.
    if not scored_data and args.selection_mode != "random":
        print("Warning: No data to save after selection process. Output dataset will be empty.")
        # scored_data could be an empty list here, which is fine for save_grpo_dataset

    output_path = save_grpo_dataset(args, scored_data, score_field)
    return output_path

def aggregate_references_for_grpo(args, source_dataset: Dataset):
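    """Collect reference outputs (gold and/or reference-model columns) for each
    example and return a list of dicts with prompt, ground truth, and references."""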
    print(f"Aggregating data from reference models: {args.ref_models}")
    os.makedirs(args.output_dir, exist_ok=True)
    ds = source_dataset  # Use the passed dataset

    print(f"Using pre-loaded dataset with {len(ds)} examples for aggregation.")
    print(f"Dataset columns: {ds.column_names}")

    # Map reference model names to dataset column names
    use_gold = "gold" in args.ref_models
    models_to_use = args.ref_models.copy()
    if use_gold:
        models_to_use.remove("gold")

    # Create mapping from model names to column names
    model_column_mapping = {}
    available_ref_columns = [col for col in ds.column_names if col.startswith('ref_output_')]

    print(f"Available reference output columns: {available_ref_columns}")

    for model in models_to_use:
        expected_column = f"ref_output_{model}"

        if expected_column in available_ref_columns:
            model_column_mapping[model] = expected_column
            print(f"Mapped model '{model}' to column '{expected_column}'")
        else:
            print(f"Warning: Could not find column '{expected_column}' for model '{model}'")

    models_to_use = [model for model in models_to_use if model in model_column_mapping]
    if not models_to_use and not use_gold:
        raise ValueError("No reference models could be mapped to dataset columns")

    def has_all_references(example):
        if use_gold and (example.get("ref_output_gold") is None or
                         pd.isna(example.get("ref_output_gold")) or
                         str(example.get("ref_output_gold")).strip() == ""):
            return False
        for model in models_to_use:
            col_name = model_column_mapping[model]
            if (example.get(col_name) is None or
                    pd.isna(example.get(col_name)) or
                    str(example.get(col_name)).strip() == ""):
                return False
        return True

    ds_filtered = ds.filter(has_all_references)

    print(f"After filtering for complete references: {len(ds_filtered)} examples")

    aggregated_data = []
    for example in ds_filtered:
        example_id = example.get("id", "unknown_id")
        if "prompt" in example and example["prompt"] is not None:
            prompt = example["prompt"]
        else:
            prompt = _get_user_prompt_from_messages(example.get("messages"), example_id)

        if "ref_output_gold" in example and example["ref_output_gold"] is not None:
            ground_truth = example["ref_output_gold"]
        else:
            ground_truth = _get_assistant_response_from_messages(example.get("messages"), example_id)

        references = []
        if use_gold:
            references.append(ground_truth)
        for model in models_to_use:
            col_name = model_column_mapping[model]
            references.append(example[col_name])

        aggregated_data.append({
            "id": example["id"],
            "source": example.get("source", "unknown"),
            "messages": example["messages"],
            "prompt": prompt,
            "ground_truth": ground_truth,
            "references": references,
        })

    print(f"Aggregated {len(aggregated_data)} examples with specified references")
    return aggregated_data

def save_grpo_dataset(args, data, score_field=None):
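    """Save the (optionally scored) GRPO data to disk as a DatasetDict and, when
    scores are present, plot their distribution alongside it."""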
    data_basename = os.path.basename(args.hf_dataset_path)
    nrefs, ref_models_str = get_ref_models_str(args.ref_models)

    model_str = ""
    metric_str = ""
    sampling_str = ""

    if args.selection_mode == "random":
        sampling_str = "_random"
    elif args.selection_mode in ["easy", "medium", "hard"]:
        sampling_str = f"_{args.selection_mode}"

    plot_metric_name = ""
    if args.metric:
        metric_str = args.metric
        plot_metric_name = args.metric
        if args.model:
            model_str = f"_{get_model_name(args.model)}_"
    elif score_field:
        plot_metric_name = score_field.replace("_score", "")

    grpo_base_output_dir = os.path.join(args.output_dir, "data_grpo")

    if args.output_dataset_name:
        dataset_name_suffix = args.output_dataset_name
    else:
        dataset_name_suffix = f"{data_basename}_{metric_str}{model_str}{nrefs}ref{ref_models_str}{sampling_str}_{len(data)}"

    output_path = os.path.join(grpo_base_output_dir, dataset_name_suffix)

    dataset = DatasetDict({
        "train": Dataset.from_list(data)
    })

    if os.path.exists(output_path):
        print(f"Dataset already exists at {output_path}, skipping save.")
    else:
        os.makedirs(output_path, exist_ok=True)
        dataset.save_to_disk(output_path)
        print(f"Saved GRPO dataset with {len(data)} examples to {output_path}")

    if score_field and plot_metric_name and data:
        scores_to_plot = [item[score_field] for item in data if score_field in item and pd.notna(item[score_field])]
        if scores_to_plot:
            fig_path = os.path.join(output_path, f"{plot_metric_name}_distribution.png")
            save_histogram(scores_to_plot, plot_metric_name, f"{plot_metric_name.upper()} Score Distribution for GRPO Output", f"{plot_metric_name.upper()} Score", fig_path)
        else:
            print(f"No valid scores found for field '{score_field}' in GRPO output to generate distribution plot.")
    elif not data:
        print("GRPO output data is empty, skipping histogram generation.")

    return output_path

def make_sft_data(args):
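    """Expand each example into one SFT chat example per reference and save the
    result as a DatasetDict."""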
    # Construct the potential output path first
    input_basename = os.path.basename(args.input_data_path)
    sft_base_output_dir = os.path.join(args.output_dir, "data_sft")
    dataset_name = f"{input_basename}_SFT"
    output_path = os.path.join(sft_base_output_dir, dataset_name)

    # Check if the dataset already exists
    if os.path.exists(output_path) and os.path.isdir(output_path):
        print(f"SFT dataset already exists at {output_path}. Skipping generation.")
        return output_path

    data = load_from_disk(args.input_data_path)["train"]
    print(f"Processing {len(data)} examples with references...")

    def build_message(prompt, response):
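        """Wrap a prompt/response pair as a two-turn chat message list."""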
        return [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": response},
        ]

    sft_data = []
    for example in tqdm(data, desc="Converting to SFT format"):
        for j, ref in enumerate(example["references"]):
            new_id = f"{example['id']}_{j}"
            sft_data.append({
                "id": new_id,
                "source": example["source"],
                "messages": build_message(example["prompt"], ref),
            })

    print(f"Created {len(sft_data)} SFT examples from {len(data)} input examples")

    dataset = DatasetDict({
        "train": Dataset.from_list(sft_data),
    })
    os.makedirs(output_path, exist_ok=True)  # Ensure dir exists before saving, even if we checked earlier
    dataset.save_to_disk(output_path)
    print(f"Saved SFT dataset with {len(sft_data)} examples to {output_path}")

    return output_path

def set_seed(seed):
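    """Seed Python, NumPy, and PyTorch RNGs and force deterministic cuDNN behavior."""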
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def format_message(prompt):
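    """Wrap a raw prompt string as a single-turn chat message list."""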
    return [
        {"role": "user", "content": prompt},
    ]

def run_inference(args, preloaded_dataset: Optional[Dataset] = None):
    """Run model inference on the data pool and save results"""
    print("Running model inference...")

    try:
        from vllm import LLM, SamplingParams
        from chat_templates import QWEN_CHAT_TEMPLATE, LLAMA_CHAT_TEMPLATE, OLMO_CHAT_TEMPLATE
    except ImportError as e:
        print(f"Error importing vLLM or chat templates: {e}")
        print("Please install vLLM: pip install vllm")
        return None

    def setup_vllm(args):
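        """Load the model with vLLM and attach a fallback chat template if the tokenizer lacks one."""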
        llm = LLM(
            model=args.model,
            dtype="bfloat16",
            # hf_token=os.getenv("HF_TOKEN")
        )
        tokenizer = llm.get_tokenizer()
        if not tokenizer.chat_template:
            if "qwen" in args.model.lower():
                tokenizer.chat_template = QWEN_CHAT_TEMPLATE
            elif "llama" in args.model.lower():
                tokenizer.chat_template = LLAMA_CHAT_TEMPLATE
            elif "olmo" in args.model.lower():
                tokenizer.chat_template = OLMO_CHAT_TEMPLATE
            print(f"Chat template set to {tokenizer.chat_template}")

        return llm, tokenizer

    def generate_responses(llm, prompts, max_new_tokens=512):
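        """Generate one sampled completion per chat-formatted prompt via vLLM's chat API."""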
        sampling_params = SamplingParams(
            max_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
        )
        outputs = llm.chat(
            prompts,
            sampling_params,
            add_generation_prompt=True
        )
        return [output.outputs[0].text.strip() for output in outputs]

    set_seed(args.seed)
    logging.basicConfig(level=logging.INFO)

    if preloaded_dataset is not None:
        print(f"Using pre-loaded dataset for inference with {len(preloaded_dataset)} examples.")
        ds = preloaded_dataset
    else:
        print(f"Loading dataset from HuggingFace: {args.hf_dataset_path} for inference")
        if '/' in args.hf_dataset_path:  # anything containing '/' is treated as a HuggingFace Hub ID
            ds = load_dataset(args.hf_dataset_path, split="train")
        else:
            loaded_ds = load_from_disk(args.hf_dataset_path)
            if isinstance(loaded_ds, DatasetDict):
                ds = loaded_ds["train"]
            else:
                ds = loaded_ds
        logging.info(f"Dataset loaded successfully from {args.hf_dataset_path} for inference")
        print(f"Loaded dataset with {len(ds)} examples for inference")

    llm, tokenizer = setup_vllm(args)
    logging.info("Model loaded successfully with vLLM")

    os.makedirs(args.output_dir, exist_ok=True)
    model_basename = get_model_name(args.model)
    save_path = os.path.join(args.output_dir, f"{model_basename}_inference_results.csv")
    logging.info(f"Saving results to {save_path}")

    # Load existing results if they exist
    existing_results = None
    all_results = []
    if os.path.exists(save_path):
        existing_results = pd.read_csv(save_path)
        all_results = existing_results.to_dict('records')
        logging.info(f"Loaded {len(all_results)} existing results from {save_path}")
        processed_ids = set(existing_results['id'].values)
    else:
        processed_ids = set()

    examples_to_process = [ex for ex in ds if ex['id'] not in processed_ids]
    logging.info(f"Processing {len(examples_to_process)} new examples")

    if examples_to_process:
        prompts_for_model = []
        original_prompts_content = []

        for ex in examples_to_process:
            example_id = ex.get("id", "unknown_id")
            if "prompt" in ex and ex["prompt"] is not None:
                current_prompt_content = ex["prompt"]
            else:
                current_prompt_content = _get_user_prompt_from_messages(ex.get("messages"), example_id)

            if current_prompt_content is None:
                print(f"Warning: Skipping example {example_id} in inference due to missing prompt.")
                # Append placeholders so both lists stay index-aligned; these
                # entries are filtered out below via valid_indices.
                original_prompts_content.append(None)
                prompts_for_model.append(format_message(""))
                continue

            original_prompts_content.append(current_prompt_content)
            prompts_for_model.append(format_message(current_prompt_content))

        valid_indices = [i for i, p in enumerate(original_prompts_content) if p is not None]
        prompts_for_model = [prompts_for_model[i] for i in valid_indices]
        examples_to_process_filtered = [examples_to_process[i] for i in valid_indices]
        original_prompts_content_filtered = [original_prompts_content[i] for i in valid_indices]

        if not prompts_for_model:
            print("No valid prompts to send for inference after filtering.")
        else:
            logging.info(f"Generating responses for {len(prompts_for_model)} prompts")
            responses = generate_responses(llm, prompts_for_model, max_new_tokens=args.max_new_tokens)
            logging.info(f"Generated {len(responses)} responses")

            response_idx = 0
            for ex, prompt_content in zip(examples_to_process_filtered, original_prompts_content_filtered):
                all_results.append({
                    "id": ex["id"],
                    "source": ex.get("source", "unknown"),
                    "prompt": prompt_content,
                    "response": responses[response_idx]
                })
                response_idx += 1

            all_results = sorted(all_results, key=lambda x: x["id"])
            pd.DataFrame(all_results).to_csv(save_path, index=False, quoting=csv.QUOTE_ALL, escapechar='\\')
            logging.info(f"All results saved to {save_path}")
            print(f"Inference complete! Results saved to {save_path}")
    else:
        print("All examples have already been processed!")

    return save_path

def main():
    parser = argparse.ArgumentParser(description="Unified training data generation script")
    subparsers = parser.add_subparsers(dest="command", help="Command to run")

    # GRPO data command
    grpo_parser = subparsers.add_parser("grpo", help="Create GRPO dataset (aggregates references and optionally sorts by score)")
    grpo_parser.add_argument("--hf_dataset_path", type=str, default="yapeichang/BLEUBERI-Tulu3-50k", help="HuggingFace dataset path or local path with reference outputs")
    grpo_parser.add_argument("--inference_max_new_tokens", type=int, default=512, help="Max new tokens for on-the-fly inference if model outputs are missing.")
    grpo_parser.add_argument("--ref_models", type=str, nargs="+", default=["gold"])  # More options: "claude-3-7-sonnet@20250219", "deepseek-chat-v3", "gemini-2.5-pro-exp-03-25", "o4-mini-2025-04-16", "Llama-3.1-8B-Instruct"
    grpo_parser.add_argument("--selection_mode", type=str, choices=["random", "easy", "medium", "hard"], default="hard", help="Selection mode: 'random' for random sampling, 'easy' for highest scores, 'medium' for middle scores, 'hard' for lowest scores")
    grpo_parser.add_argument("--output_dataset_name", type=str, default=None, help="Custom name for the GRPO output dataset directory. If None, a name will be automatically generated.")
    grpo_parser.add_argument("--metric", type=str, choices=["bleu", "rm", "rouge", "bertscore", "bleu_rouge_f1"], help="Metric to use for scoring (if not using pre-computed scores)")
    grpo_parser.add_argument("--model", type=str, help="Model to use for scoring (if not using pre-computed scores) and for generating inference outputs if missing.")
    grpo_parser.add_argument("--num_examples", type=int, default=None, help="Number of examples to include in final dataset")
    grpo_parser.add_argument("--output_dir", type=str, default="../data")
    grpo_parser.add_argument("--seed", type=int, default=42, help="Random seed for on-the-fly inference if model outputs are missing.")

    # SFT data command
    sft_parser = subparsers.add_parser("sft", help="Convert dataset with references to SFT format")
    sft_parser.add_argument("--input_data_path", type=str, required=True, help="Path to dataset with references")
    sft_parser.add_argument("--output_dir", type=str, default="../data")

    args = parser.parse_args()

    if args.command == "grpo":
        if args.selection_mode != "random" and not (args.model and args.metric):
            raise ValueError("When selection_mode is 'easy', 'medium', or 'hard', both --model and --metric must be provided")
        make_grpo_data(args)
    elif args.command == "sft":
        make_sft_data(args)
    else:
        parser.print_help()

if __name__ == "__main__":
    main()