Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-27 17:23:08 +00:00
add multi-turn interaction.
This commit is contained in:
parent 128dce55bc
commit c818de8bec
1 changed file with 39 additions and 33 deletions
@@ -191,41 +191,47 @@ class GSM8kEnv(BaseEnv):
        self, item: CatRow
    ) -> Tuple[ScoredDataGroup, list[Item]]:
        user_message = {"role": "user", "content": item["scenario"]}
        # gold_answer = (
        #     "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}"
        # )

        cat_completions = await self.server.chat_completion(
            messages=[{"role": "system", "content": cat_system_prompt}, user_message],
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
        )

        # Only the first of the group_size sampled cat replies is used as the
        # cat's opening message.
        for i, cat_completion in enumerate(cat_completions.choices):
            if i == 0:
                cat_message = cat_completion.message.content

        # The cat's message is replayed to the caretaker as a user turn.
        caretaker_message = {"role": "user", "content": cat_message}
        caretaker_completions = await self.server.chat_completion(
            messages=[{"role": "system", "content": caretaker_system_prompt}, caretaker_message],
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
        )

        to_score = list()
        to_backlog = list()
        for i, caretaker_completion in enumerate(caretaker_completions.choices):
            messages = (
                {"role": "system", "content": cat_system_prompt},
                user_message,
                {"role": "system", "content": cat_message},
                {"role": "assistant", "content": caretaker_completion.message.content},
            )
            to_score.append(
                {
                    "messages": messages,
                }
            )
        # Multi-turn interaction: roll out group_size independent five-turn
        # cat/caretaker conversations and queue each full transcript for scoring.
        for j in range(self.config.group_size):
            all_messages = []
            history = []                  # the conversation as the caretaker sees it
            cat_history = [user_message]  # the conversation as the cat sees it
            for i in range(5):
                cat_completions = await self.server.chat_completion(
                    messages=[{"role": "system", "content": cat_system_prompt}] + cat_history,
                    n=self.config.group_size,
                    max_tokens=self.config.max_token_length,
                )

                # Only the first sampled cat reply is kept each turn (the inner
                # loop variable shadows the turn counter i, which is harmless in
                # Python because the for statement reassigns i on the next turn).
                for i, cat_completion in enumerate(cat_completions.choices):
                    if i == 0:
                        cat_message = cat_completion.message.content
                cat_response = {"role": "system", "content": cat_message}
                cat_history.append(cat_response)
                caretaker_message = {"role": "user", "content": cat_message}
                history.append(caretaker_message)
                caretaker_completions = await self.server.chat_completion(
                    messages=[{"role": "system", "content": caretaker_system_prompt}] + history,
                    n=1,
                    max_tokens=self.config.max_token_length,
                )
                caretaker_response = {
                    "role": "assistant",
                    "content": caretaker_completions.choices[0].message.content,
                }
                cat_history.append(caretaker_response)
                history.append(caretaker_response)

                # Each turn re-embeds the system prompt and scenario in the
                # transcript that will be scored.
                messages = [
                    {"role": "system", "content": cat_system_prompt},
                    user_message,
                    cat_response,
                    caretaker_response,
                ]
                all_messages.extend(messages)
            all_messages = tuple(all_messages)
            to_score.append(
                {
                    "messages": all_messages,
                }
            )
        to_postprocess = await self.score(to_score)
        # import pdb; pdb.set_trace()
        return to_postprocess, to_backlog
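For readers tracing the new loop, the following is a minimal, self-contained sketch of the same dual-history bookkeeping with the model server stubbed out. fake_chat_completion, both prompt strings, and the canned replies are hypothetical stand-ins, not part of this commit; the sketch only illustrates how the cat-side and caretaker-side histories interleave.

# Minimal sketch of the dual-history bookkeeping in the multi-turn loop above.
# fake_chat_completion and both prompts are illustrative stand-ins for
# self.server.chat_completion and the real system prompts.

cat_system_prompt = "You are the cat."
caretaker_system_prompt = "You are the caretaker. Respond to the cat."

def fake_chat_completion(messages, n=1):
    # Stub: return n canned reply strings instead of querying a model server.
    return [f"reply after {len(messages)} messages" for _ in range(n)]

user_message = {"role": "user", "content": "The cat is meowing at the door."}
history = []                   # the conversation as the caretaker sees it
cat_history = [user_message]   # the conversation as the cat sees it
all_messages = []

for turn in range(5):
    # Cat speaks; as in the diff, only the first sampled reply is kept.
    cat_message = fake_chat_completion(
        [{"role": "system", "content": cat_system_prompt}] + cat_history, n=4
    )[0]
    cat_response = {"role": "system", "content": cat_message}
    cat_history.append(cat_response)

    # The cat's words re-enter the caretaker's history as a user turn.
    history.append({"role": "user", "content": cat_message})
    caretaker_text = fake_chat_completion(
        [{"role": "system", "content": caretaker_system_prompt}] + history, n=1
    )[0]
    caretaker_response = {"role": "assistant", "content": caretaker_text}
    cat_history.append(caretaker_response)
    history.append(caretaker_response)

    # Each turn re-embeds the system prompt and scenario, as the diff does.
    all_messages.extend([
        {"role": "system", "content": cat_system_prompt},
        user_message,
        cat_response,
        caretaker_response,
    ])

for m in all_messages:
    print(m["role"].ljust(9), "|", m["content"])

Note how each turn appends the caretaker's reply to both histories: the cat model sees its own lines under the system role and the caretaker's under the assistant role, while the caretaker model sees the cat's lines as user turns.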