diff --git a/environments/cat_server.py b/environments/cat_server.py index 5fdcddf6..58fe3c3f 100644 --- a/environments/cat_server.py +++ b/environments/cat_server.py @@ -260,7 +260,7 @@ class GSM8kEnv(BaseEnv): scores["scores"] = list() # # random.shuffle(rollout_group_data) for item in rollout_group_data: - final_question = list(item["messages"]) + [{'role': 'system', 'content': 'The conversation is over. Say meow if the caretaker did a bad job, or purr if the caretaker did a good job.'}] + final_question = list(item["messages"]) + [{'role': 'system', 'content': 'The conversation is over. Say purr if the caretaker did everything perfectly and there was nothing that the caretaker could have done even slightly better. Otherwise, say meow. Make sure it is rare that you rate the caretaker with a purr.'}] caretaker_completions = await self.server.chat_completion( messages=final_question, n=1,