mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
linting
This commit is contained in:
parent
6c6a1c5d06
commit
4d0f919fd1
1 changed files with 26 additions and 16 deletions
|
|
@ -592,55 +592,61 @@ class BaseEnv(ABC):
|
|||
do_send_to_api: Whether to send the data to the API
|
||||
abort_on_any_max_length_exceeded: Whether to abort if any token length exceeds the max
|
||||
"""
|
||||
original_was_list = isinstance(scored_data, list) # not sure if this is needed
|
||||
original_was_list = isinstance(scored_data, list) # not sure if this is needed
|
||||
data_to_process = scored_data if original_was_list else [scored_data]
|
||||
|
||||
valid_groups = []
|
||||
for group in data_to_process:
|
||||
if group is None:
|
||||
continue
|
||||
|
||||
|
||||
group_size = group.get("group_overrides", {}).get(
|
||||
"group_size", self.config.group_size
|
||||
)
|
||||
|
||||
if not ((None not in group) and (len(group.get("tokens", [])) == group_size)):
|
||||
|
||||
if not (
|
||||
(None not in group) and (len(group.get("tokens", [])) == group_size)
|
||||
):
|
||||
logger.warning(
|
||||
f"Group structure invalid, or token count mismatch (expected {group_size}), or 'tokens' key missing. Skipping group: {str(group)[:200]}..."
|
||||
f"Group structure invalid, or token count mismatch (expected {group_size}), "
|
||||
f"or 'tokens' key missing. Skipping group: {str(group)[:200]}..."
|
||||
)
|
||||
continue
|
||||
|
||||
if self.config.ensure_scores_are_not_same and len(set(group["scores"])) == 1:
|
||||
|
||||
if (
|
||||
self.config.ensure_scores_are_not_same
|
||||
and len(set(group["scores"])) == 1
|
||||
):
|
||||
logger.warning("Scores are the same in a group, skipping...")
|
||||
continue
|
||||
|
||||
|
||||
group.setdefault("ref_logprobs", None)
|
||||
group.setdefault("overrides", None)
|
||||
group.setdefault("group_overrides", None)
|
||||
|
||||
|
||||
for mask in group["masks"]:
|
||||
self.completion_lengths.append(len(mask))
|
||||
|
||||
|
||||
if abort_on_any_max_length_exceeded and any(
|
||||
[len(x) >= self.max_token_len for x in group["tokens"]]
|
||||
):
|
||||
logger.warning("Token length is too long in a group, skipping...")
|
||||
continue
|
||||
|
||||
|
||||
if self.config.include_messages and group.get("messages") is None:
|
||||
group["messages"] = [
|
||||
self.tokenizer.decode(group["tokens"][i])
|
||||
for i in range(len(group["tokens"]))
|
||||
]
|
||||
|
||||
|
||||
await self.add_rollouts_for_wandb(group, item)
|
||||
|
||||
|
||||
if self.jsonl_writer is not None:
|
||||
self.jsonl_writer.write(group)
|
||||
print(f"Wrote scored group to {self.config.data_path_to_save_groups}")
|
||||
|
||||
|
||||
valid_groups.append(group)
|
||||
|
||||
|
||||
if valid_groups and do_send_to_api:
|
||||
data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]]
|
||||
# send single or list of scored data groups
|
||||
|
|
@ -653,7 +659,11 @@ class BaseEnv(ABC):
|
|||
self.items_sent_this_step += len(valid_groups)
|
||||
await self._send_scored_data_to_api(data_to_send_to_api)
|
||||
except (Exception, TimeoutError) as e:
|
||||
data_type_str = "single ScoredDataGroup" if isinstance(data_to_send_to_api, dict) else f"{len(data_to_send_to_api)} ScoredDataGroups"
|
||||
data_type_str = (
|
||||
"single ScoredDataGroup"
|
||||
if isinstance(data_to_send_to_api, dict)
|
||||
else f"{len(data_to_send_to_api)} ScoredDataGroups"
|
||||
)
|
||||
print(f"Failed to send {data_type_str} after retries: {e}")
|
||||
|
||||
async def handle_env(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue