diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py index 24e650fb..6b2aae62 100644 --- a/environments/gsm8k_server.py +++ b/environments/gsm8k_server.py @@ -12,7 +12,7 @@ from atroposlib.envs.base import ( BaseEnvConfig, ScoredDataGroup, ) -from atroposlib.type_definitions import Item, number +from atroposlib.type_definitions import Item from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer system_prompt = ( @@ -120,45 +120,7 @@ class GSM8kEnv(BaseEnv): data["iter"] = self.iter super().save_checkpoint(step, data) - async def rollout_and_score_eval(self, question: str, answer: str) -> number: - completion = await self.server.chat_completion( - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": question}, - ], - n=1, - max_tokens=self.config.max_token_length, - temperature=0.0, - split="eval", - ) - gold_parsed = parse( - "\\boxed{" + answer + "}", - extraction_mode="first_match", - extraction_config=[LatexExtractionConfig()], - ) - answer_parsed = parse( - completion.choices[0].message.content.split("")[-1], - extraction_config=[ - LatexExtractionConfig( - normalization_config=NormalizationConfig( - nits=False, - malformed_operators=False, - basic_latex=True, - equations=True, - boxed="all", - units=True, - ), - # Ensures that boxed is tried first - boxed_match_priority=0, - try_extract_without_anchor=False, - ) - ], - extraction_mode="first_match", - ) - score = 1 if verify(answer_parsed, gold_parsed) else 0 - return score - - async def rollout_and_score_eval_detailed(self, question: str, answer: str) -> dict: + async def rollout_and_score_eval(self, question: str, answer: str) -> dict: """Rollout and score evaluation with detailed sample data collection.""" completion = await self.server.chat_completion( messages=[ @@ -234,9 +196,7 @@ class GSM8kEnv(BaseEnv): eval_tasks = [] for item in self.test: eval_tasks.append( - self.rollout_and_score_eval_detailed( - item["question"], item["gold_answer"] - ) + self.rollout_and_score_eval(item["question"], item["gold_answer"]) ) results = await tqdm_asyncio.gather(*eval_tasks)