This commit is contained in:
Shannon Sands 2025-05-10 09:10:31 +10:00
parent 6c6a1c5d06
commit 4d0f919fd1

View file

@ -592,55 +592,61 @@ class BaseEnv(ABC):
do_send_to_api: Whether to send the data to the API
abort_on_any_max_length_exceeded: Whether to abort if any token length exceeds the max
"""
original_was_list = isinstance(scored_data, list) # not sure if this is needed
original_was_list = isinstance(scored_data, list) # not sure if this is needed
data_to_process = scored_data if original_was_list else [scored_data]
valid_groups = []
for group in data_to_process:
if group is None:
continue
group_size = group.get("group_overrides", {}).get(
"group_size", self.config.group_size
)
if not ((None not in group) and (len(group.get("tokens", [])) == group_size)):
if not (
(None not in group) and (len(group.get("tokens", [])) == group_size)
):
logger.warning(
f"Group structure invalid, or token count mismatch (expected {group_size}), or 'tokens' key missing. Skipping group: {str(group)[:200]}..."
f"Group structure invalid, or token count mismatch (expected {group_size}), "
f"or 'tokens' key missing. Skipping group: {str(group)[:200]}..."
)
continue
if self.config.ensure_scores_are_not_same and len(set(group["scores"])) == 1:
if (
self.config.ensure_scores_are_not_same
and len(set(group["scores"])) == 1
):
logger.warning("Scores are the same in a group, skipping...")
continue
group.setdefault("ref_logprobs", None)
group.setdefault("overrides", None)
group.setdefault("group_overrides", None)
for mask in group["masks"]:
self.completion_lengths.append(len(mask))
if abort_on_any_max_length_exceeded and any(
[len(x) >= self.max_token_len for x in group["tokens"]]
):
logger.warning("Token length is too long in a group, skipping...")
continue
if self.config.include_messages and group.get("messages") is None:
group["messages"] = [
self.tokenizer.decode(group["tokens"][i])
for i in range(len(group["tokens"]))
]
await self.add_rollouts_for_wandb(group, item)
if self.jsonl_writer is not None:
self.jsonl_writer.write(group)
print(f"Wrote scored group to {self.config.data_path_to_save_groups}")
valid_groups.append(group)
if valid_groups and do_send_to_api:
data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]]
# send single or list of scored data groups
@ -653,7 +659,11 @@ class BaseEnv(ABC):
self.items_sent_this_step += len(valid_groups)
await self._send_scored_data_to_api(data_to_send_to_api)
except (Exception, TimeoutError) as e:
data_type_str = "single ScoredDataGroup" if isinstance(data_to_send_to_api, dict) else f"{len(data_to_send_to_api)} ScoredDataGroups"
data_type_str = (
"single ScoredDataGroup"
if isinstance(data_to_send_to_api, dict)
else f"{len(data_to_send_to_api)} ScoredDataGroups"
)
print(f"Failed to send {data_type_str} after retries: {e}")
async def handle_env(