Making saving data optional in config, add scores to saved data

This commit is contained in:
teknium1 2025-05-23 14:11:11 -07:00
parent a20886d720
commit 1fa798a69e

View file

@ -139,6 +139,10 @@ class SWERLEnvConfig(BaseEnvConfig):
default=0.20,
description="The train/avg_patch_format_accuracy threshold at which to switch from ICL to standard prompt.",
)
dump_rollouts: bool = Field(
default=False,
description="Whether to dump rollouts to JSONL files.",
)
class SWERLEnv(BaseEnv):
@ -182,7 +186,7 @@ class SWERLEnv(BaseEnv):
# For saving rollouts to JSONL
self.rollouts_to_save_buffer: List[
Dict[str, Union[str, List[List[Dict[str, str]]]]]
Dict[str, Union[str, List[Dict[str, Union[List[Dict[str, str]], float]]]]]
] = []
self.processed_item_count = 0
# Creates .../atropos/environments/datadumps/ relative to the project structure
@ -442,56 +446,71 @@ class SWERLEnv(BaseEnv):
temperature=0.8,
stop=stop_tokens,
)
to_score = []
if not completions or not completions.choices:
self.logger.warning(f"No completions received for item_id: {item_id}")
return None, []
# Prepare to collect all conversations for this item
all_conversations_for_this_item = []
# Prepare to collect all conversations and their potential scores for this item
# This list will hold tuples of (conversation_messages, oracle_patch, finish_reason)
# which is the input format expected by the self.score method.
raw_rollouts_for_scoring = []
for choice in completions.choices:
# The 'messages_for_llm_prompt' already contains the full context
# (system, ICL example if used, current user query)
# We just need to add the assistant's response to this full context.
current_trajectory_messages = messages_for_llm_prompt + [
{"role": "assistant", "content": choice.text.strip()}
]
to_score.append(
raw_rollouts_for_scoring.append(
(current_trajectory_messages, oracle_patch, choice.finish_reason)
)
all_conversations_for_this_item.append(current_trajectory_messages)
if not to_score:
if not raw_rollouts_for_scoring:
return None, []
# If rollouts were generated for this item, create a single record
# for this item containing all its conversations.
if all_conversations_for_this_item:
item_data_to_save = {
"item_id": item_id,
"conversations": all_conversations_for_this_item,
}
self.rollouts_to_save_buffer.append(item_data_to_save)
self.processed_item_count += 1
# Score all generated rollouts for the current item
scored_data = await self.score(raw_rollouts_for_scoring)
# Check if it's time to save a batch of rollouts
if self.processed_item_count % 100 == 0 and self.processed_item_count > 0:
log_msg = (
f"Reached {self.processed_item_count} processed items. "
f"Triggering save for {len(self.rollouts_to_save_buffer)} items "
f"(each with multiple conversations)."
# If rollouts were generated and scored, and data dumping is enabled,
# prepare them for saving.
if scored_data and self.config.dump_rollouts:
rollouts_with_scores_to_save = []
num_scored_rollouts = len(scored_data.get("scores", []))
for i in range(num_scored_rollouts):
# raw_rollouts_for_scoring[i][0] is the list of message dicts for the i-th rollout
conversation_messages = raw_rollouts_for_scoring[i][0]
score_for_rollout = scored_data["scores"][i]
rollouts_with_scores_to_save.append(
{
"conversation": conversation_messages, # Full conversation history
"score": score_for_rollout,
}
)
self.logger.info(log_msg)
await self._save_rollouts_to_jsonl()
scored_data = await self.score(to_score)
if rollouts_with_scores_to_save:
item_data_to_save = {
"item_id": item_id,
"rollouts": rollouts_with_scores_to_save, # Changed from "conversations"
}
self.rollouts_to_save_buffer.append(item_data_to_save)
self.processed_item_count += 1
# Check if it's time to save a batch of rollouts
if (
self.config.dump_rollouts
and self.processed_item_count % 100 == 0
and self.processed_item_count > 0
):
log_msg = (
f"Reached {self.processed_item_count} processed items. "
f"Triggering save for {len(self.rollouts_to_save_buffer)} items "
f"(each with multiple scored rollouts)."
)
self.logger.info(log_msg)
await self._save_rollouts_to_jsonl()
if scored_data and item_id != "unknown_item":
# The item_ids in scored_data should still refer to the base item_id,
# as scoring happens per completion, but the saved data groups them.
# This part might need adjustment if scores are also grouped in saved data,
# but current requirement is only for conversations.
scored_data["item_ids"] = [item_id] * len(scored_data.get("scores", []))
return scored_data, []
def _extract_content_after_think_tags(
@ -1082,7 +1101,9 @@ class SWERLEnv(BaseEnv):
self.logger.info(
"Closing SWERLEnv. Attempting to save any remaining rollouts..."
)
if self.rollouts_to_save_buffer: # Check if there's anything to save
if (
self.config.dump_rollouts and self.rollouts_to_save_buffer
): # Check if there's anything to save
self.logger.info(
f"Found {len(self.rollouts_to_save_buffer)} rollouts in buffer. Saving now."
)