Update handle_send_to_api to accept a single ScoredDataGroup or a list of ScoredDataGroups

This commit is contained in:
Shannon Sands 2025-05-10 09:07:54 +10:00
parent 9efd8c1529
commit 6c6a1c5d06

View file

@@ -587,58 +587,74 @@ class BaseEnv(ABC):
Send the chats to the API with robust error handling and support for multiple ScoredDataGroups.
Args:
scored_data: List of scored items to send
scored_data: Single ScoredDataGroup or List of ScoredDataGroups to send
item: Optional item for context
do_send_to_api: Whether to send the data to the API
abort_on_any_max_length_exceeded: Whether to abort if any token length exceeds the max
"""
group_size = scored_data.get("group_overrides", {}).get(
"group_size", self.config.group_size
)
if (
(scored_data is not None)
and (None not in scored_data)
and (len(scored_data["tokens"]) == group_size)
):
if self.config.ensure_scores_are_not_same:
if len(set(scored_data["scores"])) == 1:
# Scores are the same, don't send to API
logger.warning("Scores are the same, skipping...")
return
await self.add_rollouts_for_wandb(scored_data, item)
# Check for ref_logprobs
if "ref_logprobs" not in scored_data:
# Strongly typed dict, so we need to add it
scored_data["ref_logprobs"] = None
if "overrides" not in scored_data:
scored_data["overrides"] = None
if "group_overrides" not in scored_data:
scored_data["group_overrides"] = None
original_was_list = isinstance(scored_data, list) # not sure if this is needed
data_to_process = scored_data if original_was_list else [scored_data]
# Track completion lengths
for mask in scored_data["masks"]:
valid_groups = []
for group in data_to_process:
if group is None:
continue
group_size = group.get("group_overrides", {}).get(
"group_size", self.config.group_size
)
if not ((None not in group) and (len(group.get("tokens", [])) == group_size)):
logger.warning(
f"Group structure invalid, or token count mismatch (expected {group_size}), or 'tokens' key missing. Skipping group: {str(group)[:200]}..."
)
continue
if self.config.ensure_scores_are_not_same and len(set(group["scores"])) == 1:
logger.warning("Scores are the same in a group, skipping...")
continue
group.setdefault("ref_logprobs", None)
group.setdefault("overrides", None)
group.setdefault("group_overrides", None)
for mask in group["masks"]:
self.completion_lengths.append(len(mask))
# Add the scores to the queue
if abort_on_any_max_length_exceeded and any(
[len(x) >= self.max_token_len for x in scored_data["tokens"]]
[len(x) >= self.max_token_len for x in group["tokens"]]
):
# Don't send to API if the token length is too long
logger.warning("Token length is too long, skipping...")
return
# Save data, if applicable:
if self.config.include_messages and scored_data.get("messages") is None:
scored_data["messages"] = [
self.tokenizer.decode(scored_data["tokens"][i])
for i in range(group_size)
logger.warning("Token length is too long in a group, skipping...")
continue
if self.config.include_messages and group.get("messages") is None:
group["messages"] = [
self.tokenizer.decode(group["tokens"][i])
for i in range(len(group["tokens"]))
]
await self.add_rollouts_for_wandb(group, item)
if self.jsonl_writer is not None:
self.jsonl_writer.write(scored_data)
self.jsonl_writer.write(group)
print(f"Wrote scored group to {self.config.data_path_to_save_groups}")
# Send data with retries and error handling
valid_groups.append(group)
if valid_groups and do_send_to_api:
data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]]
# send single or list of scored data groups
if not original_was_list and len(valid_groups) == 1:
data_to_send_to_api = valid_groups[0]
else:
data_to_send_to_api = valid_groups
try:
if do_send_to_api:
self.items_sent_this_step += 1
await self._send_scored_data_to_api(scored_data)
self.items_sent_this_step += len(valid_groups)
await self._send_scored_data_to_api(data_to_send_to_api)
except (Exception, TimeoutError) as e:
print(f"Failed to send scored data after retries: {e}")
data_type_str = "single ScoredDataGroup" if isinstance(data_to_send_to_api, dict) else f"{len(data_to_send_to_api)} ScoredDataGroups"
print(f"Failed to send {data_type_str} after retries: {e}")
async def handle_env(
self, item_uuid: str