mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
update handle_send_to_api
This commit is contained in:
parent
9efd8c1529
commit
6c6a1c5d06
1 changed file with 57 additions and 41 deletions
|
|
@@ -587,58 +587,74 @@ class BaseEnv(ABC):
|
|||
Send the chats to the API with robust error handling and support for multiple ScoredDataGroups.
|
||||
|
||||
Args:
|
||||
scored_data: List of scored items to send
|
||||
scored_data: Single ScoredDataGroup or List of ScoredDataGroups to send
|
||||
item: Optional item for context
|
||||
do_send_to_api: Whether to send the data to the API
|
||||
abort_on_any_max_length_exceeded: Whether to abort if any token length exceeds the max
|
||||
"""
|
||||
group_size = scored_data.get("group_overrides", {}).get(
|
||||
"group_size", self.config.group_size
|
||||
)
|
||||
if (
|
||||
(scored_data is not None)
|
||||
and (None not in scored_data)
|
||||
and (len(scored_data["tokens"]) == group_size)
|
||||
):
|
||||
if self.config.ensure_scores_are_not_same:
|
||||
if len(set(scored_data["scores"])) == 1:
|
||||
# Scores are the same, don't send to API
|
||||
logger.warning("Scores are the same, skipping...")
|
||||
return
|
||||
await self.add_rollouts_for_wandb(scored_data, item)
|
||||
# Check for ref_logprobs
|
||||
if "ref_logprobs" not in scored_data:
|
||||
# Strongly typed dict, so we need to add it
|
||||
scored_data["ref_logprobs"] = None
|
||||
if "overrides" not in scored_data:
|
||||
scored_data["overrides"] = None
|
||||
if "group_overrides" not in scored_data:
|
||||
scored_data["group_overrides"] = None
|
||||
original_was_list = isinstance(scored_data, list) # not sure if this is needed
|
||||
data_to_process = scored_data if original_was_list else [scored_data]
|
||||
|
||||
# Track completion lengths
|
||||
for mask in scored_data["masks"]:
|
||||
valid_groups = []
|
||||
for group in data_to_process:
|
||||
if group is None:
|
||||
continue
|
||||
|
||||
group_size = group.get("group_overrides", {}).get(
|
||||
"group_size", self.config.group_size
|
||||
)
|
||||
|
||||
if not ((None not in group) and (len(group.get("tokens", [])) == group_size)):
|
||||
logger.warning(
|
||||
f"Group structure invalid, or token count mismatch (expected {group_size}), or 'tokens' key missing. Skipping group: {str(group)[:200]}..."
|
||||
)
|
||||
continue
|
||||
|
||||
if self.config.ensure_scores_are_not_same and len(set(group["scores"])) == 1:
|
||||
logger.warning("Scores are the same in a group, skipping...")
|
||||
continue
|
||||
|
||||
group.setdefault("ref_logprobs", None)
|
||||
group.setdefault("overrides", None)
|
||||
group.setdefault("group_overrides", None)
|
||||
|
||||
for mask in group["masks"]:
|
||||
self.completion_lengths.append(len(mask))
|
||||
# Add the scores to the queue
|
||||
|
||||
if abort_on_any_max_length_exceeded and any(
|
||||
[len(x) >= self.max_token_len for x in scored_data["tokens"]]
|
||||
[len(x) >= self.max_token_len for x in group["tokens"]]
|
||||
):
|
||||
# Don't send to API if the token length is too long
|
||||
logger.warning("Token length is too long, skipping...")
|
||||
return
|
||||
# Save data, if applicable:
|
||||
if self.config.include_messages and scored_data.get("messages") is None:
|
||||
scored_data["messages"] = [
|
||||
self.tokenizer.decode(scored_data["tokens"][i])
|
||||
for i in range(group_size)
|
||||
logger.warning("Token length is too long in a group, skipping...")
|
||||
continue
|
||||
|
||||
if self.config.include_messages and group.get("messages") is None:
|
||||
group["messages"] = [
|
||||
self.tokenizer.decode(group["tokens"][i])
|
||||
for i in range(len(group["tokens"]))
|
||||
]
|
||||
|
||||
await self.add_rollouts_for_wandb(group, item)
|
||||
|
||||
if self.jsonl_writer is not None:
|
||||
self.jsonl_writer.write(scored_data)
|
||||
self.jsonl_writer.write(group)
|
||||
print(f"Wrote scored group to {self.config.data_path_to_save_groups}")
|
||||
# Send data with retries and error handling
|
||||
|
||||
valid_groups.append(group)
|
||||
|
||||
if valid_groups and do_send_to_api:
|
||||
data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]]
|
||||
# send single or list of scored data groups
|
||||
if not original_was_list and len(valid_groups) == 1:
|
||||
data_to_send_to_api = valid_groups[0]
|
||||
else:
|
||||
data_to_send_to_api = valid_groups
|
||||
|
||||
try:
|
||||
if do_send_to_api:
|
||||
self.items_sent_this_step += 1
|
||||
await self._send_scored_data_to_api(scored_data)
|
||||
self.items_sent_this_step += len(valid_groups)
|
||||
await self._send_scored_data_to_api(data_to_send_to_api)
|
||||
except (Exception, TimeoutError) as e:
|
||||
print(f"Failed to send scored data after retries: {e}")
|
||||
data_type_str = "single ScoredDataGroup" if isinstance(data_to_send_to_api, dict) else f"{len(data_to_send_to_api)} ScoredDataGroups"
|
||||
print(f"Failed to send {data_type_str} after retries: {e}")
|
||||
|
||||
async def handle_env(
|
||||
self, item_uuid: str
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue