mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Revert "Fix multiple scored data groups (#223)"
This reverts commit 67b3144113.
This commit is contained in:
parent
67b3144113
commit
1a808e2038
1 changed file with 33 additions and 46 deletions
|
|
@@ -758,7 +758,11 @@ class BaseEnv(ABC):
|
|||
else:
|
||||
scored_data["env_id"] = getattr(self, "env_id", None)
|
||||
|
||||
url = f"{self.config.rollout_server_url}/scored_data"
|
||||
url = (
|
||||
f"{self.config.rollout_server_url}/scored_data_list"
|
||||
if isinstance(scored_data, list)
|
||||
else f"{self.config.rollout_server_url}/scored_data"
|
||||
)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url,
|
||||
|
|
@@ -791,39 +795,24 @@ class BaseEnv(ABC):
|
|||
do_send_to_api: Whether to send the data to the API
|
||||
abort_on_any_max_length_exceeded: Whether to abort if any token length exceeds the max
|
||||
"""
|
||||
# Ensure we're working with a list of groups
|
||||
data_to_process = (
|
||||
scored_data if isinstance(scored_data, list) else [scored_data]
|
||||
)
|
||||
original_was_list = isinstance(scored_data, list) # not sure if this is needed
|
||||
data_to_process = scored_data if original_was_list else [scored_data]
|
||||
|
||||
valid_groups = []
|
||||
for group in data_to_process:
|
||||
if group is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
overrides = group.get("group_overrides", {})
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
group_size = overrides.get("group_size", self.config.group_size)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting group size: {e}")
|
||||
continue
|
||||
group_size = group.get("group_overrides", {}).get(
|
||||
"group_size", self.config.group_size
|
||||
)
|
||||
|
||||
tokens = group.get("tokens", [])
|
||||
|
||||
# Check for empty token sequences
|
||||
empty_tokens = [i for i, t in enumerate(tokens) if len(t) == 0]
|
||||
if empty_tokens:
|
||||
logger.debug(
|
||||
f"Found group with empty token sequences at indices {empty_tokens}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Check group size constraints
|
||||
if len(tokens) != group_size:
|
||||
if not (
|
||||
(None not in group) and (len(group.get("tokens", [])) == group_size)
|
||||
):
|
||||
logger.warning(
|
||||
f"Group size mismatch (expected {group_size}, got {len(tokens)})"
|
||||
f"Group structure invalid, or token count mismatch (expected {group_size}), "
|
||||
f"or 'tokens' key missing. Skipping group: {str(group)[:200]}..."
|
||||
)
|
||||
continue
|
||||
|
||||
|
|
@@ -831,9 +820,7 @@ class BaseEnv(ABC):
|
|||
self.config.ensure_scores_are_not_same
|
||||
and len(set(group["scores"])) == 1
|
||||
):
|
||||
logger.warning(
|
||||
f"All scores are the same ({group.get('scores', [])[:5]}...), skipping..."
|
||||
)
|
||||
logger.warning("Scores are the same in a group, skipping...")
|
||||
continue
|
||||
|
||||
group.setdefault("ref_logprobs", None)
|
||||
|
|
@@ -864,28 +851,28 @@ class BaseEnv(ABC):
|
|||
|
||||
if self.jsonl_writer is not None:
|
||||
self.jsonl_writer.write(group)
|
||||
logger.info(
|
||||
f"Wrote scored group to {self.config.data_path_to_save_groups}"
|
||||
)
|
||||
print(f"Wrote scored group to {self.config.data_path_to_save_groups}")
|
||||
|
||||
valid_groups.append(group)
|
||||
|
||||
logger.info(
|
||||
f"Valid groups: {len(valid_groups)}, do_send_to_api: {do_send_to_api}"
|
||||
)
|
||||
if valid_groups and do_send_to_api:
|
||||
# Always send groups individually to avoid distributed process issues
|
||||
for i, group in enumerate(valid_groups):
|
||||
try:
|
||||
await self._send_scored_data_to_api(group)
|
||||
except (Exception, TimeoutError) as e:
|
||||
logger.error(f"Failed to send group {i+1} after retries: {e}")
|
||||
data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]]
|
||||
# send single or list of scored data groups
|
||||
if not original_was_list and len(valid_groups) == 1:
|
||||
data_to_send_to_api = valid_groups[0]
|
||||
else:
|
||||
data_to_send_to_api = valid_groups
|
||||
|
||||
self.items_sent_this_step += len(valid_groups)
|
||||
else:
|
||||
logger.info(
|
||||
f"Not sending data: valid_groups={len(valid_groups)}, do_send_to_api={do_send_to_api}"
|
||||
)
|
||||
try:
|
||||
self.items_sent_this_step += len(valid_groups)
|
||||
await self._send_scored_data_to_api(data_to_send_to_api)
|
||||
except (Exception, TimeoutError) as e:
|
||||
data_type_str = (
|
||||
"single ScoredDataGroup"
|
||||
if isinstance(data_to_send_to_api, dict)
|
||||
else f"{len(data_to_send_to_api)} ScoredDataGroups"
|
||||
)
|
||||
print(f"Failed to send {data_type_str} after retries: {e}")
|
||||
|
||||
async def handle_env(
|
||||
self, item_uuid: str
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue