BLEUBERI/training/run_inference.py

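# Example invocation (the model name and split below are illustrative placeholders,
# not values prescribed by the repo; --data_path and --out_dir show the defaults):
#   python run_inference.py \
#       --model Qwen/Qwen2.5-7B-Instruct \
#       --data_path ../data/tulu3_50k \
#       --split train \
#       --out_dir ../data/inference_outputs
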
import torch
import argparse
import csv
import logging
import os
import pandas as pd
import random
import numpy as np
from tqdm import tqdm
from datasets import load_from_disk
from chat_templates import QWEN_CHAT_TEMPLATE, LLAMA_CHAT_TEMPLATE, OLMO_CHAT_TEMPLATE
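

# Seed all RNG sources (Python, NumPy, PyTorch CPU/CUDA) for reproducible runs.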
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
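

# Wrap a raw prompt string as a single-turn chat message list for vLLM's chat API.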
def format_message(prompt):
    return [
        {"role": "user", "content": prompt},
    ]
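

# Load the dataset, batch-generate responses with vLLM, and write them to a CSV,
# skipping any example ids that already appear in the output file.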
def main():
    from vllm import LLM, SamplingParams
    # Use the "spawn" start method for vLLM worker processes (fork is unsafe once CUDA is initialized).
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

    def setup_vllm(args):
        llm = LLM(
            model=args.model,
            dtype="bfloat16",
            # hf_token=os.getenv("HF_TOKEN")
        )
        tokenizer = llm.get_tokenizer()
        # If the tokenizer ships without a chat template, fall back to a model-family default.
        if not tokenizer.chat_template:
            if "qwen" in args.model.lower():
                tokenizer.chat_template = QWEN_CHAT_TEMPLATE
            elif "llama" in args.model.lower():
                tokenizer.chat_template = LLAMA_CHAT_TEMPLATE
            elif "olmo" in args.model.lower():
                tokenizer.chat_template = OLMO_CHAT_TEMPLATE
            print(f"Chat template set to {tokenizer.chat_template}")
        return llm, tokenizer
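
    # Generate one sampled completion per chat-formatted prompt
    # (temperature 0.7, top-p 0.9, up to max_new_tokens tokens).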
    def generate_responses(llm, prompts, max_new_tokens=512):
        sampling_params = SamplingParams(
            max_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
        )
        outputs = llm.chat(
            prompts,
            sampling_params,
            add_generation_prompt=True
        )
        return [output.outputs[0].text.strip() for output in outputs]
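
    # CLI arguments: model path, dataset location, optional split, seed, and output directory.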
    parser = argparse.ArgumentParser(description='Model inference script')
    parser.add_argument('--model', type=str, required=True, help='Path to the trained model')
    parser.add_argument('--data_path', type=str, default='../data/tulu3_50k', help='Path to the data file')
    parser.add_argument('--split', type=str, default=None, help='Split to use for inference')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--out_dir', type=str, default='../data/inference_outputs', help='Output directory')
    args = parser.parse_args()

    set_seed(args.seed)
    logging.basicConfig(level=logging.INFO)

    logging.info("Loading model with vLLM...")
    ds = load_from_disk(args.data_path)
    if args.split:
        ds = ds[args.split]
    logging.info(f"Dataset loaded successfully from {args.data_path}")

    llm, tokenizer = setup_vllm(args)
    logging.info("Model loaded successfully with vLLM")
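
    # Name the output CSV after the model; for checkpoint paths, join the last two
    # path components (e.g. run name and checkpoint step) so filenames stay distinct.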
    if "ckpts" in args.model:
        model_basename = args.model.split("/")[-2:]
        model_basename = "_".join(model_basename)
    else:
        model_basename = os.path.basename(args.model)
    save_path = os.path.join(args.out_dir, f"{model_basename}.csv")
    logging.info(f"Saving results to {save_path}")
    os.makedirs(args.out_dir, exist_ok=True)
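
    # Resume support: reload any existing CSV and skip ids that already have responses.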
    existing_results = None
    all_results = []
    if os.path.exists(save_path):
        existing_results = pd.read_csv(save_path)
        all_results = existing_results.to_dict('records')
        logging.info(f"Loaded {len(all_results)} existing results from {save_path}")
        processed_ids = set(existing_results['id'].values)
    else:
        processed_ids = set()

    examples_to_process = [ex for ex in ds if ex['id'] not in processed_ids]
    logging.info(f"Processing {len(examples_to_process)} new examples")
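
    # Use the first user turn of each example as the prompt.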
    if examples_to_process:
        prompts = []
        for ex in examples_to_process:
            prompt = next((item['content'] for item in ex["messages"] if item['role'] == 'user'), None)
            prompts.append(format_message(prompt))

        logging.info(f"Generating responses for {len(prompts)} prompts")
        responses = generate_responses(llm, prompts)
        logging.info(f"Generated {len(responses)} responses")

        for j, ex in enumerate(examples_to_process):
            prompt = next((item['content'] for item in ex["messages"] if item['role'] == 'user'), None)
            all_results.append({
                "id": ex["id"],
                "source": ex["source"],
                "prompt": prompt,
                "response": responses[j]
            })

    # Write existing and newly generated results back to a single CSV, sorted by id.
    all_results = sorted(all_results, key=lambda x: x["id"])
    pd.DataFrame(all_results).to_csv(save_path, index=False, quoting=csv.QUOTE_ALL, escapechar='\\')
    logging.info(f"All results saved to {save_path}")


if __name__ == "__main__":
    main()