diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 25183f6b..0577363a 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -592,55 +592,61 @@ class BaseEnv(ABC): do_send_to_api: Whether to send the data to the API abort_on_any_max_length_exceeded: Whether to abort if any token length exceeds the max """ - original_was_list = isinstance(scored_data, list) # not sure if this is needed + original_was_list = isinstance(scored_data, list) # not sure if this is needed data_to_process = scored_data if original_was_list else [scored_data] valid_groups = [] for group in data_to_process: if group is None: continue - + group_size = group.get("group_overrides", {}).get( "group_size", self.config.group_size ) - - if not ((None not in group) and (len(group.get("tokens", [])) == group_size)): + + if not ( + (None not in group) and (len(group.get("tokens", [])) == group_size) + ): logger.warning( - f"Group structure invalid, or token count mismatch (expected {group_size}), or 'tokens' key missing. Skipping group: {str(group)[:200]}..." + f"Group structure invalid, or token count mismatch (expected {group_size}), " + f"or 'tokens' key missing. Skipping group: {str(group)[:200]}..." ) continue - - if self.config.ensure_scores_are_not_same and len(set(group["scores"])) == 1: + + if ( + self.config.ensure_scores_are_not_same + and len(set(group["scores"])) == 1 + ): logger.warning("Scores are the same in a group, skipping...") continue - + group.setdefault("ref_logprobs", None) group.setdefault("overrides", None) group.setdefault("group_overrides", None) - + for mask in group["masks"]: self.completion_lengths.append(len(mask)) - + if abort_on_any_max_length_exceeded and any( [len(x) >= self.max_token_len for x in group["tokens"]] ): logger.warning("Token length is too long in a group, skipping...") continue - + if self.config.include_messages and group.get("messages") is None: group["messages"] = [ self.tokenizer.decode(group["tokens"][i]) for i in range(len(group["tokens"])) ] - + await self.add_rollouts_for_wandb(group, item) - + if self.jsonl_writer is not None: self.jsonl_writer.write(group) print(f"Wrote scored group to {self.config.data_path_to_save_groups}") - + valid_groups.append(group) - + if valid_groups and do_send_to_api: data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]] # send single or list of scored data groups @@ -653,7 +659,11 @@ class BaseEnv(ABC): self.items_sent_this_step += len(valid_groups) await self._send_scored_data_to_api(data_to_send_to_api) except (Exception, TimeoutError) as e: - data_type_str = "single ScoredDataGroup" if isinstance(data_to_send_to_api, dict) else f"{len(data_to_send_to_api)} ScoredDataGroups" + data_type_str = ( + "single ScoredDataGroup" + if isinstance(data_to_send_to_api, dict) + else f"{len(data_to_send_to_api)} ScoredDataGroups" + ) print(f"Failed to send {data_type_str} after retries: {e}") async def handle_env(