add multi-turn interaction.

This commit is contained in:
97hongjun 2025-05-18 17:08:19 -07:00
parent 128dce55bc
commit c818de8bec

View file

@@ -191,41 +191,47 @@ class GSM8kEnv(BaseEnv):
self, item: CatRow
) -> Tuple[ScoredDataGroup, list[Item]]:
user_message = {"role": "user", "content": item["scenario"]}
# gold_answer = (
# "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}"
# )
cat_completions = await self.server.chat_completion(
messages=[{"role": "system", "content": cat_system_prompt}, user_message],
n=self.config.group_size,
max_tokens=self.config.max_token_length,
)
for i, cat_completion in enumerate(cat_completions.choices):
if i == 0:
cat_message = cat_completion.message.content
caretaker_message = {"role": "user", "content": cat_message}
caretaker_completions = await self.server.chat_completion(
messages=[{"role": "system", "content": caretaker_system_prompt}, caretaker_message],
n=self.config.group_size,
max_tokens=self.config.max_token_length,
)
to_score = list()
to_backlog = list()
for i, caretaker_completion in enumerate(caretaker_completions.choices):
messages = (
{"role": "system", "content": cat_system_prompt},
user_message,
{"role": "system", "content": cat_message},
{"role": "assistant", "content": caretaker_completion.message.content},
)
to_score.append(
{
"messages": messages,
}
)
for j in range(self.config.group_size):
all_messages = []
history = []
cat_history = [user_message]
for i in range(5):
cat_completions = await self.server.chat_completion(
messages=[{"role": "system", "content": cat_system_prompt}] + cat_history,
n=self.config.group_size,
max_tokens=self.config.max_token_length,
)
for i, cat_completion in enumerate(cat_completions.choices):
if i == 0:
cat_message = cat_completion.message.content
cat_response = {"role": "system", "content": cat_message}
cat_history.append(cat_response)
caretaker_message = {"role": "user", "content": cat_message}
history.append(caretaker_message)
caretaker_completions = await self.server.chat_completion(
messages=[{"role": "system", "content": caretaker_system_prompt}] + history,
n=1,
max_tokens=self.config.max_token_length,
)
caretaker_response = {"role": "assistant", "content": caretaker_completions.choices[0].message.content}
cat_history.append(caretaker_response)
history.append(caretaker_response)
messages = [
{"role": "system", "content": cat_system_prompt},
user_message,
cat_response,
caretaker_response
]
all_messages.extend(messages)
all_messages = tuple(all_messages)
to_score.append({
"messages": all_messages,
})
import pdb; pdb.set_trace()
to_postprocess = await self.score(to_score)
# import pdb; pdb.set_trace()
return to_postprocess, to_backlog