mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
fix olympiadbench due to upstream changes
This commit is contained in:
parent
3863ece98b
commit
e09ae8d3d3
1 changed files with 21 additions and 14 deletions
|
|
@ -9,6 +9,7 @@ import re
|
|||
from concurrent.futures import ProcessPoolExecutor
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
|
|
@ -16,13 +17,12 @@ from math_verify.errors import TimeoutException
|
|||
from pydantic import Field
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
import wandb
|
||||
from atroposlib.envs.base import (
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
EvalHandlingEnum,
|
||||
OpenaiConfig,
|
||||
ScoredDataGroup,
|
||||
ServerBaseline,
|
||||
)
|
||||
|
||||
prompt_format = (
|
||||
|
|
@ -115,7 +115,7 @@ class MathEnv(BaseEnv):
|
|||
def __init__(
|
||||
self,
|
||||
config: RSConfig,
|
||||
server_configs: List[OpenaiConfig],
|
||||
server_configs: ServerBaseline,
|
||||
slurm=True,
|
||||
testing=False,
|
||||
):
|
||||
|
|
@ -133,7 +133,7 @@ class MathEnv(BaseEnv):
|
|||
self.iter = 0
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[RSConfig, List[OpenaiConfig]]:
|
||||
def config_init(cls) -> Tuple[RSConfig, ServerBaseline]:
|
||||
env_config = RSConfig(
|
||||
tokenizer_name="Qwen/Qwen2.5-7B",
|
||||
group_size=8,
|
||||
|
|
@ -147,14 +147,10 @@ class MathEnv(BaseEnv):
|
|||
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
|
||||
eval_limit_ratio=0.1,
|
||||
)
|
||||
server_configs = [
|
||||
OpenaiConfig(
|
||||
model_name="default",
|
||||
base_url="http://localhost:9004/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=256, # since evaling only on one...
|
||||
),
|
||||
]
|
||||
server_configs = ServerBaseline(
|
||||
model_name="default",
|
||||
num_requests_for_eval=256, # since evaling only on one...
|
||||
)
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
|
|
@ -222,8 +218,8 @@ class MathEnv(BaseEnv):
|
|||
)
|
||||
)
|
||||
for name, t_dataset in zip(
|
||||
["amc23", "minerva", "olympiad"],
|
||||
[amc_test_data, minerva_test_data, olympiad_test_data],
|
||||
["amc23", "minerva"],
|
||||
[amc_test_data, minerva_test_data],
|
||||
):
|
||||
for item in t_dataset:
|
||||
self.test.append(
|
||||
|
|
@ -235,6 +231,17 @@ class MathEnv(BaseEnv):
|
|||
name,
|
||||
)
|
||||
)
|
||||
for name, t_dataset in zip(["olympiad"], [olympiad_test_data]):
|
||||
for item in t_dataset:
|
||||
self.test.append(
|
||||
(
|
||||
prompt_format.format(
|
||||
prompt=problem_format.format(problem=item["question"])
|
||||
),
|
||||
item["final_answer"][0],
|
||||
name,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
async def rollout_and_score_eval(self, question, answer, subset):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue