Revert "Fix multiple scored data groups (#223)"

This reverts commit 67b3144113.
This commit is contained in:
shannonsands 2025-08-29 17:55:45 +10:00 committed by GitHub
parent 67b3144113
commit 1a808e2038
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -758,7 +758,11 @@ class BaseEnv(ABC):
else:
scored_data["env_id"] = getattr(self, "env_id", None)
url = f"{self.config.rollout_server_url}/scored_data"
url = (
f"{self.config.rollout_server_url}/scored_data_list"
if isinstance(scored_data, list)
else f"{self.config.rollout_server_url}/scored_data"
)
async with aiohttp.ClientSession() as session:
async with session.post(
url,
@ -791,39 +795,24 @@ class BaseEnv(ABC):
do_send_to_api: Whether to send the data to the API
abort_on_any_max_length_exceeded: Whether to abort if any token length exceeds the max
"""
# Ensure we're working with a list of groups
data_to_process = (
scored_data if isinstance(scored_data, list) else [scored_data]
)
original_was_list = isinstance(scored_data, list) # not sure if this is needed
data_to_process = scored_data if original_was_list else [scored_data]
valid_groups = []
for group in data_to_process:
if group is None:
continue
try:
overrides = group.get("group_overrides", {})
if overrides is None:
overrides = {}
group_size = overrides.get("group_size", self.config.group_size)
except Exception as e:
logger.error(f"Error getting group size: {e}")
continue
group_size = group.get("group_overrides", {}).get(
"group_size", self.config.group_size
)
tokens = group.get("tokens", [])
# Check for empty token sequences
empty_tokens = [i for i, t in enumerate(tokens) if len(t) == 0]
if empty_tokens:
logger.debug(
f"Found group with empty token sequences at indices {empty_tokens}"
)
continue
# Check group size constraints
if len(tokens) != group_size:
if not (
(None not in group) and (len(group.get("tokens", [])) == group_size)
):
logger.warning(
f"Group size mismatch (expected {group_size}, got {len(tokens)})"
f"Group structure invalid, or token count mismatch (expected {group_size}), "
f"or 'tokens' key missing. Skipping group: {str(group)[:200]}..."
)
continue
@ -831,9 +820,7 @@ class BaseEnv(ABC):
self.config.ensure_scores_are_not_same
and len(set(group["scores"])) == 1
):
logger.warning(
f"All scores are the same ({group.get('scores', [])[:5]}...), skipping..."
)
logger.warning("Scores are the same in a group, skipping...")
continue
group.setdefault("ref_logprobs", None)
@ -864,28 +851,28 @@ class BaseEnv(ABC):
if self.jsonl_writer is not None:
self.jsonl_writer.write(group)
logger.info(
f"Wrote scored group to {self.config.data_path_to_save_groups}"
)
print(f"Wrote scored group to {self.config.data_path_to_save_groups}")
valid_groups.append(group)
logger.info(
f"Valid groups: {len(valid_groups)}, do_send_to_api: {do_send_to_api}"
)
if valid_groups and do_send_to_api:
# Always send groups individually to avoid distributed process issues
for i, group in enumerate(valid_groups):
try:
await self._send_scored_data_to_api(group)
except (Exception, TimeoutError) as e:
logger.error(f"Failed to send group {i+1} after retries: {e}")
data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]]
# send single or list of scored data groups
if not original_was_list and len(valid_groups) == 1:
data_to_send_to_api = valid_groups[0]
else:
data_to_send_to_api = valid_groups
self.items_sent_this_step += len(valid_groups)
else:
logger.info(
f"Not sending data: valid_groups={len(valid_groups)}, do_send_to_api={do_send_to_api}"
)
try:
self.items_sent_this_step += len(valid_groups)
await self._send_scored_data_to_api(data_to_send_to_api)
except (Exception, TimeoutError) as e:
data_type_str = (
"single ScoredDataGroup"
if isinstance(data_to_send_to_api, dict)
else f"{len(data_to_send_to_api)} ScoredDataGroups"
)
print(f"Failed to send {data_type_str} after retries: {e}")
async def handle_env(
self, item_uuid: str