diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py index d0470985..1eaaa3fd 100644 --- a/environments/gsm8k_server.py +++ b/environments/gsm8k_server.py @@ -122,18 +122,19 @@ class GSM8kEnv(BaseEnv): async def rollout_and_score_eval(self, question: str, answer: str) -> dict: """Rollout and score evaluation with detailed sample data collection.""" - completion = await self.server.chat_completion( - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": question}, - ], - n=1, - max_tokens=self.config.max_token_length, - temperature=0.0, - split="eval", - ) + + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + completion = await managed.chat_completion( + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question}, + ], + n=1, + max_tokens=self.config.max_token_length, + temperature=0.0, + ) - response_content = completion.choices[0].message.content + response_content = completion.choices[0].message.content # Parse gold answer gold_parsed = parse(