diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py index 766e27ec..6ae5285b 100644 --- a/environments/gsm8k_server.py +++ b/environments/gsm8k_server.py @@ -272,6 +272,7 @@ class GSM8kEnv(BaseEnv): scores["tokens"] = list() scores["masks"] = list() scores["scores"] = list() + scores["inference_logprobs"] = list() gold_parsed = parse( rollout_group_data[0]["gold_answer"], extraction_mode="first_match",