diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 64d2eb04..825baf62 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -758,11 +758,7 @@ class BaseEnv(ABC): else: scored_data["env_id"] = getattr(self, "env_id", None) - url = ( - f"{self.config.rollout_server_url}/scored_data_list" - if isinstance(scored_data, list) - else f"{self.config.rollout_server_url}/scored_data" - ) + url = f"{self.config.rollout_server_url}/scored_data" async with aiohttp.ClientSession() as session: async with session.post( url, @@ -795,24 +791,39 @@ class BaseEnv(ABC): do_send_to_api: Whether to send the data to the API abort_on_any_max_length_exceeded: Whether to abort if any token length exceeds the max """ - original_was_list = isinstance(scored_data, list) # not sure if this is needed - data_to_process = scored_data if original_was_list else [scored_data] + # Ensure we're working with a list of groups + data_to_process = ( + scored_data if isinstance(scored_data, list) else [scored_data] + ) valid_groups = [] for group in data_to_process: if group is None: continue - group_size = group.get("group_overrides", {}).get( - "group_size", self.config.group_size - ) + try: + overrides = group.get("group_overrides", {}) + if overrides is None: + overrides = {} + group_size = overrides.get("group_size", self.config.group_size) + except Exception as e: + logger.error(f"Error getting group size: {e}") + continue - if not ( - (None not in group) and (len(group.get("tokens", [])) == group_size) - ): + tokens = group.get("tokens", []) + + # Check for empty token sequences + empty_tokens = [i for i, t in enumerate(tokens) if len(t) == 0] + if empty_tokens: + logger.debug( + f"Found group with empty token sequences at indices {empty_tokens}" + ) + continue + + # Check group size constraints + if len(tokens) != group_size: logger.warning( - f"Group structure invalid, or token count mismatch (expected {group_size}), " - f"or 'tokens' key missing. Skipping group: {str(group)[:200]}..." + f"Group size mismatch (expected {group_size}, got {len(tokens)})" ) continue @@ -820,7 +831,9 @@ class BaseEnv(ABC): self.config.ensure_scores_are_not_same and len(set(group["scores"])) == 1 ): - logger.warning("Scores are the same in a group, skipping...") + logger.warning( + f"All scores are the same ({group.get('scores', [])[:5]}...), skipping..." + ) continue group.setdefault("ref_logprobs", None) @@ -851,28 +864,28 @@ class BaseEnv(ABC): if self.jsonl_writer is not None: self.jsonl_writer.write(group) - print(f"Wrote scored group to {self.config.data_path_to_save_groups}") + logger.info( + f"Wrote scored group to {self.config.data_path_to_save_groups}" + ) valid_groups.append(group) + logger.info( + f"Valid groups: {len(valid_groups)}, do_send_to_api: {do_send_to_api}" + ) if valid_groups and do_send_to_api: - data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]] - # send single or list of scored data groups - if not original_was_list and len(valid_groups) == 1: - data_to_send_to_api = valid_groups[0] - else: - data_to_send_to_api = valid_groups + # Always send groups individually to avoid distributed process issues + for i, group in enumerate(valid_groups): + try: + await self._send_scored_data_to_api(group) + except (Exception, TimeoutError) as e: + logger.error(f"Failed to send group {i+1} after retries: {e}") - try: - self.items_sent_this_step += len(valid_groups) - await self._send_scored_data_to_api(data_to_send_to_api) - except (Exception, TimeoutError) as e: - data_type_str = ( - "single ScoredDataGroup" - if isinstance(data_to_send_to_api, dict) - else f"{len(data_to_send_to_api)} ScoredDataGroups" - ) - print(f"Failed to send {data_type_str} after retries: {e}") + self.items_sent_this_step += len(valid_groups) + else: + logger.info( + f"Not sending data: valid_groups={len(valid_groups)}, do_send_to_api={do_send_to_api}" + ) async def handle_env( self, item_uuid: str