mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-28 17:29:30 +00:00
Making saving data optional in config, add scores to saved data
This commit is contained in:
parent
a20886d720
commit
1fa798a69e
1 changed files with 54 additions and 33 deletions
|
|
@ -139,6 +139,10 @@ class SWERLEnvConfig(BaseEnvConfig):
|
|||
default=0.20,
|
||||
description="The train/avg_patch_format_accuracy threshold at which to switch from ICL to standard prompt.",
|
||||
)
|
||||
dump_rollouts: bool = Field(
|
||||
default=False,
|
||||
description="Whether to dump rollouts to JSONL files.",
|
||||
)
|
||||
|
||||
|
||||
class SWERLEnv(BaseEnv):
|
||||
|
|
@ -182,7 +186,7 @@ class SWERLEnv(BaseEnv):
|
|||
|
||||
# For saving rollouts to JSONL
|
||||
self.rollouts_to_save_buffer: List[
|
||||
Dict[str, Union[str, List[List[Dict[str, str]]]]]
|
||||
Dict[str, Union[str, List[Dict[str, Union[List[Dict[str, str]], float]]]]]
|
||||
] = []
|
||||
self.processed_item_count = 0
|
||||
# Creates .../atropos/environments/datadumps/ relative to the project structure
|
||||
|
|
@ -442,56 +446,71 @@ class SWERLEnv(BaseEnv):
|
|||
temperature=0.8,
|
||||
stop=stop_tokens,
|
||||
)
|
||||
to_score = []
|
||||
if not completions or not completions.choices:
|
||||
self.logger.warning(f"No completions received for item_id: {item_id}")
|
||||
return None, []
|
||||
|
||||
# Prepare to collect all conversations for this item
|
||||
all_conversations_for_this_item = []
|
||||
# Prepare to collect all conversations and their potential scores for this item
|
||||
# This list will hold tuples of (conversation_messages, oracle_patch, finish_reason)
|
||||
# which is the input format expected by the self.score method.
|
||||
raw_rollouts_for_scoring = []
|
||||
|
||||
for choice in completions.choices:
|
||||
# The 'messages_for_llm_prompt' already contains the full context
|
||||
# (system, ICL example if used, current user query)
|
||||
# We just need to add the assistant's response to this full context.
|
||||
current_trajectory_messages = messages_for_llm_prompt + [
|
||||
{"role": "assistant", "content": choice.text.strip()}
|
||||
]
|
||||
to_score.append(
|
||||
raw_rollouts_for_scoring.append(
|
||||
(current_trajectory_messages, oracle_patch, choice.finish_reason)
|
||||
)
|
||||
all_conversations_for_this_item.append(current_trajectory_messages)
|
||||
|
||||
if not to_score:
|
||||
if not raw_rollouts_for_scoring:
|
||||
return None, []
|
||||
|
||||
# If rollouts were generated for this item, create a single record
|
||||
# for this item containing all its conversations.
|
||||
if all_conversations_for_this_item:
|
||||
item_data_to_save = {
|
||||
"item_id": item_id,
|
||||
"conversations": all_conversations_for_this_item,
|
||||
}
|
||||
self.rollouts_to_save_buffer.append(item_data_to_save)
|
||||
self.processed_item_count += 1
|
||||
# Score all generated rollouts for the current item
|
||||
scored_data = await self.score(raw_rollouts_for_scoring)
|
||||
|
||||
# Check if it's time to save a batch of rollouts
|
||||
if self.processed_item_count % 100 == 0 and self.processed_item_count > 0:
|
||||
log_msg = (
|
||||
f"Reached {self.processed_item_count} processed items. "
|
||||
f"Triggering save for {len(self.rollouts_to_save_buffer)} items "
|
||||
f"(each with multiple conversations)."
|
||||
# If rollouts were generated and scored, and data dumping is enabled,
|
||||
# prepare them for saving.
|
||||
if scored_data and self.config.dump_rollouts:
|
||||
rollouts_with_scores_to_save = []
|
||||
|
||||
num_scored_rollouts = len(scored_data.get("scores", []))
|
||||
for i in range(num_scored_rollouts):
|
||||
# raw_rollouts_for_scoring[i][0] is the list of message dicts for the i-th rollout
|
||||
conversation_messages = raw_rollouts_for_scoring[i][0]
|
||||
score_for_rollout = scored_data["scores"][i]
|
||||
rollouts_with_scores_to_save.append(
|
||||
{
|
||||
"conversation": conversation_messages, # Full conversation history
|
||||
"score": score_for_rollout,
|
||||
}
|
||||
)
|
||||
self.logger.info(log_msg)
|
||||
await self._save_rollouts_to_jsonl()
|
||||
|
||||
scored_data = await self.score(to_score)
|
||||
if rollouts_with_scores_to_save:
|
||||
item_data_to_save = {
|
||||
"item_id": item_id,
|
||||
"rollouts": rollouts_with_scores_to_save, # Changed from "conversations"
|
||||
}
|
||||
self.rollouts_to_save_buffer.append(item_data_to_save)
|
||||
self.processed_item_count += 1
|
||||
|
||||
# Check if it's time to save a batch of rollouts
|
||||
if (
|
||||
self.config.dump_rollouts
|
||||
and self.processed_item_count % 100 == 0
|
||||
and self.processed_item_count > 0
|
||||
):
|
||||
log_msg = (
|
||||
f"Reached {self.processed_item_count} processed items. "
|
||||
f"Triggering save for {len(self.rollouts_to_save_buffer)} items "
|
||||
f"(each with multiple scored rollouts)."
|
||||
)
|
||||
self.logger.info(log_msg)
|
||||
await self._save_rollouts_to_jsonl()
|
||||
|
||||
if scored_data and item_id != "unknown_item":
|
||||
# The item_ids in scored_data should still refer to the base item_id,
|
||||
# as scoring happens per completion, but the saved data groups them.
|
||||
# This part might need adjustment if scores are also grouped in saved data,
|
||||
# but current requirement is only for conversations.
|
||||
scored_data["item_ids"] = [item_id] * len(scored_data.get("scores", []))
|
||||
|
||||
return scored_data, []
|
||||
|
||||
def _extract_content_after_think_tags(
|
||||
|
|
@ -1082,7 +1101,9 @@ class SWERLEnv(BaseEnv):
|
|||
self.logger.info(
|
||||
"Closing SWERLEnv. Attempting to save any remaining rollouts..."
|
||||
)
|
||||
if self.rollouts_to_save_buffer: # Check if there's anything to save
|
||||
if (
|
||||
self.config.dump_rollouts and self.rollouts_to_save_buffer
|
||||
): # Check if there's anything to save
|
||||
self.logger.info(
|
||||
f"Found {len(self.rollouts_to_save_buffer)} rollouts in buffer. Saving now."
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue