diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 825baf62..64d2eb04 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -758,7 +758,11 @@ class BaseEnv(ABC): else: scored_data["env_id"] = getattr(self, "env_id", None) - url = f"{self.config.rollout_server_url}/scored_data" + url = ( + f"{self.config.rollout_server_url}/scored_data_list" + if isinstance(scored_data, list) + else f"{self.config.rollout_server_url}/scored_data" + ) async with aiohttp.ClientSession() as session: async with session.post( url, @@ -791,39 +795,24 @@ class BaseEnv(ABC): do_send_to_api: Whether to send the data to the API abort_on_any_max_length_exceeded: Whether to abort if any token length exceeds the max """ - # Ensure we're working with a list of groups - data_to_process = ( - scored_data if isinstance(scored_data, list) else [scored_data] - ) + original_was_list = isinstance(scored_data, list) # not sure if this is needed + data_to_process = scored_data if original_was_list else [scored_data] valid_groups = [] for group in data_to_process: if group is None: continue - try: - overrides = group.get("group_overrides", {}) - if overrides is None: - overrides = {} - group_size = overrides.get("group_size", self.config.group_size) - except Exception as e: - logger.error(f"Error getting group size: {e}") - continue + group_size = group.get("group_overrides", {}).get( + "group_size", self.config.group_size + ) - tokens = group.get("tokens", []) - - # Check for empty token sequences - empty_tokens = [i for i, t in enumerate(tokens) if len(t) == 0] - if empty_tokens: - logger.debug( - f"Found group with empty token sequences at indices {empty_tokens}" - ) - continue - - # Check group size constraints - if len(tokens) != group_size: + if not ( + (None not in group) and (len(group.get("tokens", [])) == group_size) + ): logger.warning( - f"Group size mismatch (expected {group_size}, got {len(tokens)})" + f"Group structure invalid, or token count mismatch (expected {group_size}), " + f"or 'tokens' key missing. Skipping group: {str(group)[:200]}..." ) continue @@ -831,9 +820,7 @@ class BaseEnv(ABC): self.config.ensure_scores_are_not_same and len(set(group["scores"])) == 1 ): - logger.warning( - f"All scores are the same ({group.get('scores', [])[:5]}...), skipping..." - ) + logger.warning("Scores are the same in a group, skipping...") continue group.setdefault("ref_logprobs", None) @@ -864,28 +851,28 @@ class BaseEnv(ABC): if self.jsonl_writer is not None: self.jsonl_writer.write(group) - logger.info( - f"Wrote scored group to {self.config.data_path_to_save_groups}" - ) + print(f"Wrote scored group to {self.config.data_path_to_save_groups}") valid_groups.append(group) - logger.info( - f"Valid groups: {len(valid_groups)}, do_send_to_api: {do_send_to_api}" - ) if valid_groups and do_send_to_api: - # Always send groups individually to avoid distributed process issues - for i, group in enumerate(valid_groups): - try: - await self._send_scored_data_to_api(group) - except (Exception, TimeoutError) as e: - logger.error(f"Failed to send group {i+1} after retries: {e}") + data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]] + # send single or list of scored data groups + if not original_was_list and len(valid_groups) == 1: + data_to_send_to_api = valid_groups[0] + else: + data_to_send_to_api = valid_groups - self.items_sent_this_step += len(valid_groups) - else: - logger.info( - f"Not sending data: valid_groups={len(valid_groups)}, do_send_to_api={do_send_to_api}" - ) + try: + self.items_sent_this_step += len(valid_groups) + await self._send_scored_data_to_api(data_to_send_to_api) + except (Exception, TimeoutError) as e: + data_type_str = ( + "single ScoredDataGroup" + if isinstance(data_to_send_to_api, dict) + else f"{len(data_to_send_to_api)} ScoredDataGroups" + ) + print(f"Failed to send {data_type_str} after retries: {e}") async def handle_env( self, item_uuid: str