mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add eval runner
This commit is contained in:
parent
405efa8302
commit
8ec5066998
3 changed files with 410 additions and 8 deletions
225
atroposlib/envs/eval.py
Normal file
225
atroposlib/envs/eval.py
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
from openai.types.chat import ChatCompletion
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
from atroposlib.envs.server_handling.server_manager import ServerManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def pass_at_k(m, c, k):
    """Unbiased pass@k estimator (Chen et al., 2021, "Evaluating LLMs on Code").

    Estimates the probability that at least one of k samples drawn (without
    replacement) from m generations is correct, given c correct generations.

    Args:
        m: total number of samples generated.
        c: number of correct samples among them.
        k: the k in pass@k.

    Returns:
        Estimated pass@k in [0, 1].
    """
    incorrect = m - c
    if incorrect < k:
        # Fewer incorrect samples than k: every size-k subset must contain
        # at least one correct sample.
        return 1.0
    # 1 - C(incorrect, k) / C(m, k), evaluated as a numerically stable
    # telescoping product over the top `c` terms.
    return 1.0 - np.prod(1.0 - k / np.arange(incorrect + 1, m + 1))
|
||||
|
||||
|
||||
def evaluate_log(
    metrics: Dict,
    eval_dir: Optional[str] = None,
    task_name: Optional[str] = None,
    model_name: Optional[str] = None,
    start_time: Optional[float] = None,
    end_time: Optional[float] = None,
    generation_parameters: Optional[Dict] = None,
    samples: Optional[List[Dict]] = None,
    verbose: bool = True,
):
    """
    Log evaluation results to a JSON file in the format expected by nous-evals.

    Args:
        metrics: Dictionary of metrics to log (same format as wandb_log)
        eval_dir: Directory to save evaluation results to
        task_name: Name of the evaluation task (defaults to env name)
        model_name: Name of the model being evaluated
        start_time: Start time of evaluation (unix timestamp)
        end_time: End time of evaluation (unix timestamp)
        generation_parameters: Dictionary of generation parameters used
        samples: List of sample dictionaries to save to samples.jsonl
        verbose: If True, print a markdown table of the metrics
    """
    if eval_dir is None:
        logger.warning("eval_dir is not set, skipping evaluation logging")
        return

    os.makedirs(eval_dir, exist_ok=True)
    filepath = os.path.join(eval_dir, "metrics.json")

    # Fill in missing timestamps / parameters with sensible defaults.
    if start_time is None:
        start_time = time.time()
    if end_time is None:
        end_time = time.time()
    if generation_parameters is None:
        generation_parameters = {}

    if verbose:
        # Imported lazily so the display dependency is only needed when used.
        from atroposlib.utils.display import display_metrics_table

        display_metrics_table(task_name, metrics, start_time, end_time)

    # Mirror lighteval's result skeleton: a per-task entry plus an "all" rollup.
    result_payload = {
        "config_general": {
            "model_name": model_name,
            "total_evaluation_time_seconds": str(end_time - start_time),
            "generation_parameters": generation_parameters,
        },
        "results": {
            f"atropos|{task_name}|0": metrics,
            "all": metrics,
        },
    }

    with open(filepath, "w") as f:
        json.dump(result_payload, f, indent=2)
    print(f"Evaluation results saved to {filepath}")

    # Optionally persist the raw samples, one JSON object per line.
    if samples:
        samples_filepath = os.path.join(eval_dir, "samples.jsonl")
        with jsonlines.open(samples_filepath, "w") as writer:
            for sample in samples:
                writer.write(sample)
        print(f"Evaluation samples saved to {samples_filepath}")
|
||||
|
||||
|
||||
class EvalBase(ABC):
    """Base class for evaluation tasks.

    Subclasses implement ``setup_data`` (build the list of items to score) and
    ``run_item`` (score one item against a server). Awaiting the instance runs
    every item concurrently, averages the per-item metrics, and logs the
    aggregate via ``evaluate_log``.
    """

    def __init__(self, pass_at_n=1, pass_at_n_samples=1, **kwargs):
        # pass_at_n: the k reported as pass@k.
        # pass_at_n_samples: how many samples to draw per item (allows
        # pass@k:m style reporting when it differs from pass_at_n).
        self.pass_at_n = pass_at_n
        self.pass_at_n_samples = pass_at_n_samples
        # Stash any extra configuration (temperature, top_p, max_tokens,
        # eval_dir, verbose, ...) on the instance; helpers read it back
        # via getattr with defaults.
        for k, v in kwargs.items():
            setattr(self, k, v)
        self.data = self.setup_data()

    def get_generation_params(self):
        """
        Generation params to be sent to an openai server
        """
        temp = getattr(self, "temperature", 0.0)
        # Draw enough samples for the larger of pass_at_n / pass_at_n_samples.
        n = max(self.pass_at_n_samples, self.pass_at_n)
        top_p = getattr(self, "top_p", 1.0)
        max_tokens = getattr(self, "max_tokens", -1)
        return {"temperature": temp, "n": n, "top_p": top_p, "max_tokens": max_tokens}

    async def chat_completion(self, server, messages) -> ChatCompletion:
        """Request a chat completion using this eval's generation params."""
        gen_params = self.get_generation_params()
        return await server.chat_completion(messages=messages, **gen_params)

    @abstractmethod
    def setup_data(self) -> list:
        """Return the list of data items this eval will iterate over."""
        raise NotImplementedError("Setup data method must be implemented in subclass")

    @abstractmethod
    async def run_item(
        self, server: ServerManager, data_item: dict
    ) -> Tuple[dict, list]:
        """
        Process a single data item against the provided server.

        Args:
            server (ServerManager): An instance of ServerManager used to manage
                and interact with server operations during item processing.
            data_item (dict): A dictionary representing the data item to be
                processed; its structure depends on the specific eval.

        Returns:
            Tuple[dict, list]: A tuple of (per-item metrics dict, list of
            sample records collected while processing the item).

        Raises:
            NotImplementedError: When the method is not implemented in a
                subclass.
        """
        raise NotImplementedError("Run item method must be implemented in subclass")

    async def __call__(self, server_manager: ServerManager):
        """Run every item, aggregate metrics, log them, and return them."""
        # Guard against an empty dataset up front; previously this surfaced
        # as a confusing IndexError on metrics_list[0] below.
        if not self.data:
            raise ValueError(
                f"{self.__class__.__name__}.setup_data() returned no items"
            )
        start_time = time.time()
        task_coros = [self.run_item(server_manager, item) for item in self.data]
        task_results = await tqdm_asyncio.gather(*task_coros)
        end_time = time.time()
        # Split out per-item metrics and sample records.
        metrics_list = [result[0] for result in task_results]
        samples = [result[1] for result in task_results]
        # Average each metric key across items; assumes every item reports
        # the same keys as the first one.
        keys = list(metrics_list[0].keys())
        metrics = {
            key: sum(result[key] for result in metrics_list) / len(metrics_list)
            for key in keys
        }
        # Task name encodes pass@n (and the sample count when it differs).
        task_name = self.__class__.__name__
        task_name += f"@{self.pass_at_n}"
        if self.pass_at_n != self.pass_at_n_samples:
            task_name += f":{self.pass_at_n_samples}"
        print(f"{task_name} metrics: {metrics}")
        evaluate_log(
            metrics,
            eval_dir=getattr(self, "eval_dir", None),
            task_name=task_name,
            model_name=server_manager.servers[0].config.model_name,
            start_time=start_time,
            end_time=end_time,
            generation_parameters=self.get_generation_params(),
            samples=samples,
            verbose=getattr(self, "verbose", False),
        )

        return metrics
|
||||
|
||||
|
||||
async def eval_runner(eval_env: EvalBase):
    """Run *eval_env* against a single OpenAI-compatible server.

    Reads ``--server-url`` and ``--model-name`` from the command line, builds
    a one-server ``ServerManager``, and awaits the eval, returning its
    aggregated metrics.
    """
    import argparse

    from atroposlib.envs.server_handling.server_baseline import APIServerConfig

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--server-url",
        type=str,
        default="http://localhost:8000",
        help="URL of the server to connect to.",
    )
    parser.add_argument("--model-name", type=str, default=None, help="Model name")
    args = parser.parse_args()

    # A single server config; health checks are skipped since the endpoint
    # may be any OpenAI-compatible service.
    server_config = APIServerConfig(
        api_key="dummy",
        base_url=args.server_url,
        model_name=args.model_name,
        health_check=False,
    )
    server_manager = ServerManager(configs=[server_config])
    return await eval_env(server_manager)
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator, List, Union
|
||||
|
||||
|
|
@ -352,12 +353,18 @@ class ServerManager:
|
|||
most_available_server_num_slots = server.sem._value
|
||||
|
||||
# Create ManagedServer wrapping the selected server
|
||||
managed = ManagedServer(
|
||||
server=self.servers[most_available_server], tokenizer=tokenizer
|
||||
)
|
||||
if isinstance(self.servers[most_available_server], OpenAIServer):
|
||||
warnings.warn(
|
||||
"Using OpenAIServer with managed_server does not allow for state tracking"
|
||||
)
|
||||
yield self.servers[most_available_server]
|
||||
else:
|
||||
managed = ManagedServer(
|
||||
server=self.servers[most_available_server], tokenizer=tokenizer
|
||||
)
|
||||
|
||||
try:
|
||||
yield managed
|
||||
finally:
|
||||
# Clean up: reset tracked sequences
|
||||
managed.reset()
|
||||
try:
|
||||
yield managed
|
||||
finally:
|
||||
# Clean up: reset tracked sequences
|
||||
managed.reset()
|
||||
|
|
|
|||
170
environments/eval_environments/aime24.py
Normal file
170
environments/eval_environments/aime24.py
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
import asyncio
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from datasets import load_dataset
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
from math_verify.errors import TimeoutException
|
||||
|
||||
from atroposlib.envs.eval import EvalBase, eval_runner, pass_at_k
|
||||
from atroposlib.envs.server_handling.server_manager import ServerManager
|
||||
|
||||
# Default "deep thinking" system prompt (Hermes style): asks the model to
# reason inside <think></think> tags before giving its final answer.
hermes_system_prompt = (
    "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
    "problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
    "solution prior to answering. You should enclose your thoughts and internal monologue inside <think> "
    "</think> tags, and then provide your solution or response to the problem."
)
|
||||
|
||||
|
||||
# Exceptions treated as "could not score". math_verify's parse/verify can fail
# in many ways, so scoring is deliberately best-effort. (The extra members are
# redundant with Exception but kept to preserve the original behavior in case
# TimeoutException is not Exception-derived.)
_SCORING_ERRORS = (Exception, TimeoutException, KeyError, TypeError, NotImplementedError)


def score_answer(gold, resp) -> Optional[bool]:
    """Verify *resp* against the *gold* LaTeX answer using math_verify.

    Args:
        gold: Ground-truth answer as a LaTeX string (e.g. ``\\boxed{42}``).
        resp: Model response text to extract a candidate answer from.

    Returns:
        True/False for a completed verification, or None when either side
        could not be parsed or verification itself failed.
    """
    try:
        gold_parsed = parse(
            gold,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
    except _SCORING_ERRORS:
        return None
    if len(gold_parsed) == 0:
        # Gold answer itself is unparseable; nothing to compare against.
        return None
    try:
        answer_parsed = parse(
            resp,
            extraction_config=[
                LatexExtractionConfig(
                    normalization_config=NormalizationConfig(
                        nits=False,
                        malformed_operators=False,
                        basic_latex=True,
                        boxed="all",
                        units=True,
                    ),
                    # Ensures that boxed is tried first
                    boxed_match_priority=0,
                    try_extract_without_anchor=False,
                )
            ],
            extraction_mode="first_match",
        )
    except _SCORING_ERRORS:
        # Can't parse, so we skip
        return None
    # True iff the extracted answer matches the ground truth.
    try:
        return verify(answer_parsed, gold_parsed)
    except _SCORING_ERRORS:
        return None
|
||||
|
||||
|
||||
class AIME24(EvalBase):
    """
    AIME24 Eval Environment

    Scores model answers on HuggingFaceH4/aime_2024 with math_verify,
    reporting pass@n (or pass@n:m when more samples than n are drawn).

    kwargs:
        use_system_prompt (bool): Whether to use the system prompt in the evaluation.
        num_workers (int): Size of the process pool used for answer scoring
            (defaults to 8).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Answer verification is CPU-bound (sympy), so run it in subprocesses.
        self.mp_executor = ProcessPoolExecutor(getattr(self, "num_workers", 8))

    def setup_data(self):
        """Load the AIME 2024 problems and answers into plain dicts."""
        aime_test_data = load_dataset("HuggingFaceH4/aime_2024", split="train")
        return [
            {
                "problem": item["problem"],
                "answer": item["answer"],
            }
            for item in aime_test_data
        ]

    async def run_item(
        self, server: ServerManager, data_item: dict
    ) -> Tuple[dict, list]:
        """
        Generate n completions for one AIME problem and score each of them.

        Args:
            server (ServerManager): Server manager used to request completions.
            data_item (dict): Item with "problem" and "answer" keys.

        Returns:
            Tuple[dict, list]: ({"pass@k[:n]": value}, list of per-choice
            sample dicts holding messages, raw answer text, and score).
        """
        answer = data_item["answer"]
        question = data_item["problem"]
        use_sys_prompt = getattr(self, "use_system_prompt", False)
        async with server.managed_server() as managed:
            messages = (
                [{"role": "system", "content": hermes_system_prompt}]
                if use_sys_prompt
                else []
            )
            messages.append({"role": "user", "content": question})
            completion = await self.chat_completion(managed, messages)
            loop = asyncio.get_event_loop()
            # math_verify expects a boxed gold answer.
            gold = "\\boxed{" + answer + "}" if "\\boxed" not in answer else answer
            tasks = []
            for choice in completion.choices:
                # Only score the text after the reasoning block, if present.
                resp = choice.message.content.split("</think>")[-1]
                tasks.append(
                    loop.run_in_executor(self.mp_executor, score_answer, gold, resp)
                )
            rewards = await asyncio.gather(*tasks)
            # score_answer may return None (unscorable); only True counts.
            rewards = [1.0 if reward else 0.0 for reward in rewards]
            passing = sum(rewards)
            n = self.get_generation_params()["n"]
            # BUGFIX: report pass@`pass_at_n` (the configured k). The previous
            # getattr(self, "pass_at_k", 1) referenced an attribute that is
            # never set (the constructor argument is pass_at_n), so it always
            # fell back to k=1.
            k = self.pass_at_n
            acc_at_k = pass_at_k(n, passing, k)
            key = f"pass@{k}"
            if n != k:
                key += f":{n}"
            return {
                key: acc_at_k,
            }, [
                {
                    "messages": messages,
                    "answer": choice.message.content,
                    "score": rewards[i],
                }
                for i, choice in enumerate(completion.choices)
            ]
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Evaluate AIME24 drawing 4 samples per problem (pass@1:4 reporting),
    # at temperature 1.0 with the Hermes thinking system prompt.
    eval_env = AIME24(
        pass_at_n=1,
        pass_at_n_samples=4,
        temperature=1.0,
        max_tokens=32768,
        use_system_prompt=True,
    )
    asyncio.run(eval_runner(eval_env))
|
||||
Loading…
Add table
Add a link
Reference in a new issue