diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py index 941526ce..b528a880 100644 --- a/environments/gsm8k_server.py +++ b/environments/gsm8k_server.py @@ -309,7 +309,6 @@ class GSM8kEnv(BaseEnv): masks = item["masks"] logprobs = item["logprobs"] - # remove obviously bad examples if len([1 for i in masks if i != -100]) < 10: continue @@ -317,7 +316,7 @@ class GSM8kEnv(BaseEnv): scores["masks"].append(masks) scores["inference_logprobs"].append(logprobs) scores["scores"].append(1.0 if reward else -1.0) - + if len(scores["tokens"]) >= self.config.group_size: break