diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 59c2adf3..25183f6b 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -587,58 +587,74 @@ class BaseEnv(ABC): Send the chats to the API with robust error handling and support for multiple ScoredDataGroups. Args: - scored_data: List of scored items to send + scored_data: Single ScoredDataGroup or List of ScoredDataGroups to send item: Optional item for context + do_send_to_api: Whether to send the data to the API + abort_on_any_max_length_exceeded: Whether to abort if any token length exceeds the max """ - group_size = scored_data.get("group_overrides", {}).get( - "group_size", self.config.group_size - ) - if ( - (scored_data is not None) - and (None not in scored_data) - and (len(scored_data["tokens"]) == group_size) - ): - if self.config.ensure_scores_are_not_same: - if len(set(scored_data["scores"])) == 1: - # Scores are the same, don't send to API - logger.warning("Scores are the same, skipping...") - return - await self.add_rollouts_for_wandb(scored_data, item) - # Check for ref_logprobs - if "ref_logprobs" not in scored_data: - # Strongly typed dict, so we need to add it - scored_data["ref_logprobs"] = None - if "overrides" not in scored_data: - scored_data["overrides"] = None - if "group_overrides" not in scored_data: - scored_data["group_overrides"] = None + original_was_list = isinstance(scored_data, list) # not sure if this is needed + data_to_process = scored_data if original_was_list else [scored_data] - # Track completion lengths - for mask in scored_data["masks"]: + valid_groups = [] + for group in data_to_process: + if group is None: + continue + + group_size = group.get("group_overrides", {}).get( + "group_size", self.config.group_size + ) + + if not ((None not in group) and (len(group.get("tokens", [])) == group_size)): + logger.warning( + f"Group structure invalid, or token count mismatch (expected {group_size}), or 'tokens' key missing. Skipping group: {str(group)[:200]}..." + ) + continue + + if self.config.ensure_scores_are_not_same and len(set(group["scores"])) == 1: + logger.warning("Scores are the same in a group, skipping...") + continue + + group.setdefault("ref_logprobs", None) + group.setdefault("overrides", None) + group.setdefault("group_overrides", None) + + for mask in group["masks"]: self.completion_lengths.append(len(mask)) - # Add the scores to the queue + if abort_on_any_max_length_exceeded and any( - [len(x) >= self.max_token_len for x in scored_data["tokens"]] + [len(x) >= self.max_token_len for x in group["tokens"]] ): - # Don't send to API if the token length is too long - logger.warning("Token length is too long, skipping...") - return - # Save data, if applicable: - if self.config.include_messages and scored_data.get("messages") is None: - scored_data["messages"] = [ - self.tokenizer.decode(scored_data["tokens"][i]) - for i in range(group_size) + logger.warning("Token length is too long in a group, skipping...") + continue + + if self.config.include_messages and group.get("messages") is None: + group["messages"] = [ + self.tokenizer.decode(group["tokens"][i]) + for i in range(len(group["tokens"])) ] + + await self.add_rollouts_for_wandb(group, item) + if self.jsonl_writer is not None: - self.jsonl_writer.write(scored_data) + self.jsonl_writer.write(group) print(f"Wrote scored group to {self.config.data_path_to_save_groups}") - # Send data with retries and error handling + + valid_groups.append(group) + + if valid_groups and do_send_to_api: + data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]] + # send single or list of scored data groups + if not original_was_list and len(valid_groups) == 1: + data_to_send_to_api = valid_groups[0] + else: + data_to_send_to_api = valid_groups + try: - if do_send_to_api: - self.items_sent_this_step += 1 - await self._send_scored_data_to_api(scored_data) + self.items_sent_this_step += len(valid_groups) + await self._send_scored_data_to_api(data_to_send_to_api) except (Exception, TimeoutError) as e: - print(f"Failed to send scored data after retries: {e}") + data_type_str = "single ScoredDataGroup" if isinstance(data_to_send_to_api, dict) else f"{len(data_to_send_to_api)} ScoredDataGroups" + print(f"Failed to send {data_type_str} after retries: {e}") async def handle_env( self, item_uuid: str