add eval runner

This commit is contained in:
Dakota 2025-12-19 19:56:59 -06:00
parent 405efa8302
commit 8ec5066998
3 changed files with 410 additions and 8 deletions

225
atroposlib/envs/eval.py Normal file
View file

@ -0,0 +1,225 @@
import json
import logging
import os
import time
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple
import jsonlines
import numpy as np
from openai.types.chat import ChatCompletion
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.server_handling.server_manager import ServerManager
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
def pass_at_k(m, c, k):
    """
    Unbiased pass@k estimator.

    Args:
        m: total number of samples drawn
        c: number of correct samples
        k: k in pass@k

    Returns:
        Estimated probability that at least one of k samples is correct.
    """
    incorrect = m - c
    # Fewer wrong samples than k: any k-subset must contain a correct one.
    if incorrect < k:
        return 1.0
    # P(all k draws wrong) = prod over i in (incorrect+1 .. m) of (1 - k/i).
    fail_probability = np.prod(1.0 - k / np.arange(incorrect + 1, m + 1))
    return 1.0 - fail_probability
def evaluate_log(
    metrics: Dict,
    eval_dir: Optional[str] = None,
    task_name: Optional[str] = None,
    model_name: Optional[str] = None,
    start_time: Optional[float] = None,
    end_time: Optional[float] = None,
    generation_parameters: Optional[Dict] = None,
    samples: Optional[List[Dict]] = None,
    verbose: bool = True,
):
    """
    Persist evaluation results as JSON in the format expected by nous-evals.

    Args:
        metrics: Metric-name -> value mapping (same format as wandb_log)
        eval_dir: Directory to save evaluation results to; when None,
            nothing is written.
        task_name: Name of the evaluation task (defaults to env name)
        model_name: Name of the model being evaluated
        start_time: Start time of evaluation (unix timestamp); defaults to now
        end_time: End time of evaluation (unix timestamp); defaults to now
        generation_parameters: Dictionary of generation parameters used
        samples: List of sample dictionaries to save to samples.jsonl
        verbose: If True, print a markdown table of the metrics
    """
    if eval_dir is None:
        logger.warning("eval_dir is not set, skipping evaluation logging")
        return

    os.makedirs(eval_dir, exist_ok=True)
    filepath = os.path.join(eval_dir, "metrics.json")

    if start_time is None:
        start_time = time.time()
    if end_time is None:
        end_time = time.time()
    if generation_parameters is None:
        generation_parameters = {}

    if verbose:
        # Imported lazily so non-verbose runs avoid the display dependency.
        from atroposlib.utils.display import display_metrics_table

        display_metrics_table(task_name, metrics, start_time, end_time)

    # Mirror the skeleton of lighteval's result layout.
    task_key = f"atropos|{task_name}|0"
    eval_result = {
        "config_general": {
            "model_name": model_name,
            "total_evaluation_time_seconds": str(end_time - start_time),
            "generation_parameters": generation_parameters,
        },
        "results": {
            task_key: metrics,
            "all": metrics,
        },
    }

    with open(filepath, "w") as f:
        json.dump(eval_result, f, indent=2)
    print(f"Evaluation results saved to {filepath}")

    # Per-item samples are optional and go to a sibling JSONL file.
    if samples:
        samples_filepath = os.path.join(eval_dir, "samples.jsonl")
        with jsonlines.open(samples_filepath, "w") as writer:
            for sample in samples:
                writer.write(sample)
        print(f"Evaluation samples saved to {samples_filepath}")
class EvalBase(ABC):
    """
    Base class for evaluation tasks.

    Subclasses implement ``setup_data`` to build the evaluation set and
    ``run_item`` to score a single item. Calling the instance with a
    ``ServerManager`` runs every item concurrently, averages the per-item
    metrics, and logs results via ``evaluate_log``.
    """

    def __init__(self, pass_at_n=1, pass_at_n_samples=1, **kwargs):
        # k in pass@k, and the number of samples drawn to estimate it.
        self.pass_at_n = pass_at_n
        self.pass_at_n_samples = pass_at_n_samples
        # Remaining keyword arguments become attributes; subclasses read them
        # via getattr (e.g. temperature, top_p, max_tokens, eval_dir, verbose).
        for key, value in kwargs.items():
            setattr(self, key, value)
        self.data = self.setup_data()

    def get_generation_params(self):
        """
        Generation params to be sent to an openai server.

        Returns:
            dict with ``temperature``, ``n``, ``top_p`` and ``max_tokens``.
            ``n`` is the larger of pass_at_n and pass_at_n_samples so enough
            completions are drawn for the pass@k estimate.
        """
        return {
            "temperature": getattr(self, "temperature", 0.0),
            "n": max(self.pass_at_n_samples, self.pass_at_n),
            "top_p": getattr(self, "top_p", 1.0),
            # -1 presumably means "no explicit limit" to the server — confirm.
            "max_tokens": getattr(self, "max_tokens", -1),
        }

    async def chat_completion(self, server, messages) -> ChatCompletion:
        """Request a chat completion from ``server`` using this task's params."""
        gen_params = self.get_generation_params()
        return await server.chat_completion(messages=messages, **gen_params)

    @abstractmethod
    def setup_data(self) -> list:
        """Build and return the list of evaluation items (called from __init__)."""
        raise NotImplementedError("Setup data method must be implemented in subclass")

    @abstractmethod
    async def run_item(
        self, server: ServerManager, data_item: dict
    ) -> Tuple[dict, list]:
        """
        An abstract method that must be overridden in a subclass to define how a
        specific item should be processed. This method encapsulates the logic required
        to run or process the given data item on the provided server instance.

        Args:
            server (ServerManager): An instance of ServerManager used to manage and
                interact with server operations during the item processing.
            data_item (dict): A dictionary representing the data item to be processed.
                The structure and content of the dictionary would depend on the
                specific application and use case.

        Returns:
            Tuple[dict, list]: A tuple where the first element is a dictionary
                containing the processed results or output of the operation, and
                the second element is a list containing any additional data generated
                or collected during the item's processing.

        Raises:
            NotImplementedError: This error is raised when the method is not
                implemented in subclasses.
        """
        raise NotImplementedError("Run item method must be implemented in subclass")

    async def __call__(self, server_manager: ServerManager):
        """
        Run the full evaluation and return the aggregated metrics dict.

        Raises:
            ValueError: if ``setup_data`` produced no items (previously this
                surfaced as a cryptic IndexError during aggregation).
        """
        if not self.data:
            raise ValueError("setup_data() returned no items to evaluate")
        start_time = time.time()
        task_coros = [self.run_item(server_manager, item) for item in self.data]
        task_results = await tqdm_asyncio.gather(*task_coros)
        end_time = time.time()
        # Average each metric key across items; every item is expected to
        # report the same keys as the first one.
        metrics_list = [result[0] for result in task_results]
        metrics = {
            key: sum(result[key] for result in metrics_list) / len(metrics_list)
            for key in metrics_list[0]
        }
        samples = [result[1] for result in task_results]
        # Task name encodes pass@n, plus the sample count when it differs.
        task_name = f"{self.__class__.__name__}@{self.pass_at_n}"
        if self.pass_at_n != self.pass_at_n_samples:
            task_name += f":{self.pass_at_n_samples}"
        print(f"{task_name} metrics: {metrics}")
        evaluate_log(
            metrics,
            eval_dir=getattr(self, "eval_dir", None),
            task_name=task_name,
            model_name=server_manager.servers[0].config.model_name,
            start_time=start_time,
            end_time=end_time,
            generation_parameters=self.get_generation_params(),
            samples=samples,
            verbose=getattr(self, "verbose", False),
        )
        return metrics
async def eval_runner(eval_env: EvalBase):
    """
    Minimal CLI driver: build a single-server ServerManager from command-line
    arguments and run the given evaluation environment against it.
    """
    import argparse

    from atroposlib.envs.server_handling.server_baseline import APIServerConfig

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--server-url",
        type=str,
        default="http://localhost:8000",
        help="URL of the server to connect to.",
    )
    cli.add_argument("--model-name", type=str, default=None, help="Model name")
    parsed = cli.parse_args()

    # Single server, no health check; the API key is unused by local servers.
    server_config = APIServerConfig(
        api_key="dummy",
        base_url=parsed.server_url,
        model_name=parsed.model_name,
        health_check=False,
    )
    manager = ServerManager(configs=[server_config])
    return await eval_env(manager)

View file

@ -1,6 +1,7 @@
import asyncio
import inspect
import os
import warnings
from contextlib import asynccontextmanager
from typing import AsyncGenerator, List, Union
@ -352,12 +353,18 @@ class ServerManager:
most_available_server_num_slots = server.sem._value
# Create ManagedServer wrapping the selected server
managed = ManagedServer(
server=self.servers[most_available_server], tokenizer=tokenizer
)
if isinstance(self.servers[most_available_server], OpenAIServer):
warnings.warn(
"Using OpenAIServer with managed_server does not allow for state tracking"
)
yield self.servers[most_available_server]
else:
managed = ManagedServer(
server=self.servers[most_available_server], tokenizer=tokenizer
)
try:
yield managed
finally:
# Clean up: reset tracked sequences
managed.reset()
try:
yield managed
finally:
# Clean up: reset tracked sequences
managed.reset()

View file

@ -0,0 +1,170 @@
import asyncio
from concurrent.futures import ProcessPoolExecutor
from typing import Optional, Tuple
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from math_verify.errors import TimeoutException
from atroposlib.envs.eval import EvalBase, eval_runner, pass_at_k
from atroposlib.envs.server_handling.server_manager import ServerManager
# System prompt for Hermes-style "thinking" models: asks the model to put long
# deliberate reasoning inside <think></think> tags before its final answer.
# run_item later splits responses on "</think>" and scores only the tail.
hermes_system_prompt = (
    "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
    "problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
    "solution prior to answering. You should enclose your thoughts and internal monologue inside <think> "
    "</think> tags, and then provide your solution or response to the problem."
)
def score_answer(gold, resp) -> Optional[bool]:
    """
    Score a model response against a gold LaTeX answer with math_verify.

    Args:
        gold: Gold answer as a LaTeX string (e.g. "\\boxed{42}").
        resp: Model response text to extract an answer from.

    Returns:
        True/False from ``verify`` when both sides parse, or None whenever
        parsing or verification fails (callers treat None as incorrect).
    """
    # NOTE: Exception already subsumes KeyError/TypeError/NotImplementedError,
    # which the original listed redundantly; TimeoutException is kept explicit
    # in case math_verify raises it outside the Exception hierarchy.
    try:
        gold_parsed = parse(
            gold,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
    except (Exception, TimeoutException):
        return None

    # Gold answer itself is unparseable: nothing to compare against.
    if len(gold_parsed) == 0:
        return None

    try:
        answer_parsed = parse(
            resp,
            extraction_config=[
                LatexExtractionConfig(
                    normalization_config=NormalizationConfig(
                        nits=False,
                        malformed_operators=False,
                        basic_latex=True,
                        boxed="all",
                        units=True,
                    ),
                    # Ensures that boxed is tried first
                    boxed_match_priority=0,
                    try_extract_without_anchor=False,
                )
            ],
            extraction_mode="first_match",
        )
    except (Exception, TimeoutException):
        # Can't parse, so we skip
        return None

    # True iff the extracted answer matches the ground truth.
    try:
        return verify(answer_parsed, gold_parsed)
    except (Exception, TimeoutException):
        return None
class AIME24(EvalBase):
    """
    AIME24 Eval Environment

    kwargs:
        use_system_prompt (bool): Whether to use the system prompt in the evaluation.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Answer scoring (math_verify parsing) runs in worker processes so
        # slow or pathological LaTeX parses don't block the event loop.
        # NOTE(review): the executor is never shut down; fine for a one-shot
        # script, revisit if this class is embedded in a long-lived process.
        self.mp_executor = ProcessPoolExecutor(8)

    def setup_data(self):
        """Load AIME 2024 into a list of {"problem": str, "answer": str}."""
        aime_test_data = load_dataset("HuggingFaceH4/aime_2024", split="train")
        return [
            {"problem": item["problem"], "answer": item["answer"]}
            for item in aime_test_data
        ]

    async def run_item(
        self, server: ServerManager, data_item: dict
    ) -> Tuple[dict, list]:
        """
        Generate n completions for one problem and score each against gold.

        Args:
            server (ServerManager): Manager used to acquire a managed server.
            data_item (dict): {"problem": str, "answer": str} from setup_data.

        Returns:
            Tuple[dict, list]: metrics dict with a single "pass@k[:n]" entry,
            and a list of per-choice sample dicts (messages/answer/score).
        """
        answer = data_item["answer"]
        question = data_item["problem"]
        use_sys_prompt = getattr(self, "use_system_prompt", False)
        async with server.managed_server() as managed:
            messages = (
                [{"role": "system", "content": hermes_system_prompt}]
                if use_sys_prompt
                else []
            )
            messages.append({"role": "user", "content": question})
            completion = await self.chat_completion(managed, messages)

            loop = asyncio.get_event_loop()
            # AIME gold answers are plain integers; wrap in \boxed{} so
            # math_verify extracts them the same way as model output.
            gold = "\\boxed{" + answer + "}" if "\\boxed" not in answer else answer
            tasks = []
            for choice in completion.choices:
                # Score only the text after the reasoning block.
                resp = choice.message.content.split("</think>")[-1]
                tasks.append(
                    loop.run_in_executor(self.mp_executor, score_answer, gold, resp)
                )
            rewards = await asyncio.gather(*tasks)
            # score_answer yields True/False/None; None (parse failure)
            # counts as incorrect.
            rewards = [1.0 if reward else 0.0 for reward in rewards]
            passing = sum(rewards)
            n = self.get_generation_params()["n"]
            # Bug fix: the previous code read a never-set "pass_at_k"
            # attribute (always falling back to 1) and ignored the
            # pass_at_n configured on the base class.
            k = self.pass_at_n
            acc_at_k = pass_at_k(n, passing, k)
            key = f"pass@{k}"
            if n != k:
                key += f":{n}"
            return {
                key: acc_at_k,
            }, [
                {
                    "messages": messages,
                    "answer": choice.message.content,
                    "score": rewards[i],
                }
                for i, choice in enumerate(completion.choices)
            ]
if __name__ == "__main__":
    # Run AIME24 standalone: pass@1 estimated from 4 samples per problem,
    # temperature 1.0, up to 32768 tokens per completion, with the Hermes
    # thinking system prompt enabled. Server URL/model come from the CLI
    # arguments parsed inside eval_runner.
    asyncio.run(
        eval_runner(
            AIME24(
                pass_at_n=1,
                pass_at_n_samples=4,
                temperature=1.0,
                max_tokens=32768,
                use_system_prompt=True,
            )
        )
    )