mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
wip eval script
This commit is contained in:
parent
661cbc4c71
commit
3babfbbd29
2 changed files with 60 additions and 0 deletions
0
training/eval/__init__.py
Normal file
0
training/eval/__init__.py
Normal file
60
training/eval/evaluate.py
Normal file
60
training/eval/evaluate.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import argparse
|
||||
from typing import Any
|
||||
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
import reasoning_gym
|
||||
from reasoning_gym.dataset import ProceduralDataset
|
||||
from reasoning_gym.utils import SYSTEM_PROMPTS, extract_answer
|
||||
|
||||
from ..utils import ReasoningGymDataset
|
||||
|
||||
|
||||
def get_model_response(model, prompt: str) -> str:
    """Return the model's raw text output for *prompt*.

    Placeholder implementation that delegates straight to ``model.generate``;
    any prompt templating or decoding is still to be added.
    """
    # TODO
    return model.generate(prompt)
|
||||
|
||||
|
||||
def process_entry(model, dataset: ProceduralDataset, entry: dict[str, Any]) -> float:
    """Score the model on a single dataset entry.

    Args:
        model: Model object forwarded to ``get_model_response``.
        dataset: Dataset whose ``score_answer`` grades the extracted answer.
        entry: Entry dict; ``entry["question"]`` is the prompt sent to the model.

    Returns:
        The score assigned by ``dataset.score_answer``.
    """
    response = get_model_response(model, entry["question"])
    answer = extract_answer(response)
    return dataset.score_answer(answer=answer, entry=entry)
|
||||
|
||||
|
||||
def evaluate(model, dataset: ReasoningGymDataset) -> float:
    """Evaluate *model* on every entry of *dataset* and return the mean score.

    Args:
        model: Model object passed through to ``process_entry``.
        dataset: Wrapper whose ``data`` attribute holds the underlying
            ``ProceduralDataset`` to iterate over.

    Returns:
        Mean per-entry score, or ``0.0`` for an empty dataset.
    """
    procedural_dataset = dataset.data

    total_score, n = 0.0, 0
    for entry in procedural_dataset:
        total_score += process_entry(model, procedural_dataset, entry)
        n += 1

    # Guard against an empty dataset: the unguarded division raised
    # ZeroDivisionError when the dataset yielded no entries.
    return total_score / n if n else 0.0
|
||||
|
||||
|
||||
def main(args):
    """Load the model and tokenizer, build the eval dataset, and print the mean score.

    Args:
        args: Parsed CLI namespace with ``model``, ``dataset_name``,
            ``dataset_seed``, ``dataset_size`` and ``developer_prompt``.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModel.from_pretrained(args.model)

    # NOTE(review): AutoModel loads a bare backbone; confirm the checkpoint
    # actually supports .generate() as assumed by get_model_response.
    source_dataset = reasoning_gym.create_dataset(
        args.dataset_name,
        seed=args.dataset_seed,
        size=args.dataset_size,
    )
    eval_dataset = ReasoningGymDataset(
        tokenizer=tokenizer,
        procedural_dataset=source_dataset,
        developer_prompt=SYSTEM_PROMPTS[args.developer_prompt],
    )

    mean_score = evaluate(model, eval_dataset)
    print(f"Score: {mean_score}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--dataset-name", type=str, required=True)
    parser.add_argument("--dataset-size", type=int, default=10000)
    parser.add_argument("--dataset-seed", type=int, default=2)
    parser.add_argument("--developer-prompt", type=str, default="DeepSeekZero")
    args = parser.parse_args()
    # Fix: arguments were parsed but main() was never invoked, so the
    # script did nothing when run.
    main(args)
|
||||
Loading…
Add table
Add a link
Reference in a new issue