diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py
index 933a3ee0..9fe4c68f 100644
--- a/atroposlib/envs/base.py
+++ b/atroposlib/envs/base.py
@@ -185,11 +185,11 @@ class BaseEnv(ABC):
     server_cls: APIServer = APIServer

     def __init__(
-        self,
-        config: BaseEnvConfig,
-        server_configs: Union[ServerBaseline, List[APIServerConfig]],
-        slurm=False,
-        testing=False,
+        self,
+        config: BaseEnvConfig,
+        server_configs: Union[ServerBaseline, List[APIServerConfig]],
+        slurm=False,
+        testing=False,
     ):
         self.items_sent_this_step = 0
         self.eval_runner = None  # type: Optional[asyncio.Task]
@@ -272,7 +272,7 @@ class BaseEnv(ABC):
     @classmethod
     def config_init(
-        cls,
+        cls,
     ) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[APIServerConfig]]]:
         """
         Initialize the config
         """
@@ -280,7 +280,7 @@ class BaseEnv(ABC):
         return cls.env_config_cls(), ServerBaseline()

     async def collect_trajectory(
-        self, item: Item
+        self, item: Item
     ) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
         raise NotImplementedError(
             "Handle env single method must be implemented in subclass "
@@ -339,8 +339,8 @@ class BaseEnv(ABC):
         return to_postprocess, backlog

     async def postprocess_histories(
-        self,
-        trajectories: Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]],
+        self,
+        trajectories: Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]],
    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
        """
        Postprocess the histories, this is called after the collect_trajectories method
@@ -428,7 +428,7 @@ class BaseEnv(ABC):
         while self.wandb_project is None:
             async with aiohttp.ClientSession() as session:
                 async with session.get(
-                    f"{self.config.rollout_server_url}/wandb_info"
+                    f"{self.config.rollout_server_url}/wandb_info"
                 ) as resp:
                     data = await parse_http_response(resp, logger)
                     self.wandb_group = data["group"]
@@ -462,14 +462,14 @@ class BaseEnv(ABC):
         try:
             async with aiohttp.ClientSession() as session:
                 async with session.post(
-                    f"{self.config.rollout_server_url}/register-env",
-                    json={
-                        "max_token_length": self.config.max_token_length,
-                        "desired_name": self.config.wandb_name,
-                        "weight": self.config.inference_weight,
-                        "min_batch_allocation": self.config.min_batch_allocation,
-                        "group_size": self.config.group_size,
-                    },
+                    f"{self.config.rollout_server_url}/register-env",
+                    json={
+                        "max_token_length": self.config.max_token_length,
+                        "desired_name": self.config.wandb_name,
+                        "weight": self.config.inference_weight,
+                        "min_batch_allocation": self.config.min_batch_allocation,
+                        "group_size": self.config.group_size,
+                    },
                 ) as resp:
                     data = await parse_http_response(resp, logger)
                     return data
@@ -577,9 +577,9 @@ class BaseEnv(ABC):
         return wandb_metrics

     async def add_rollouts_for_wandb(
-        self,
-        scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
-        item: Item = None,
+        self,
+        scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
+        item: Item = None,
     ):
         # Save rollout to trajectory
         num_keep = self.config.num_rollouts_per_group_for_logging
@@ -626,7 +626,7 @@ class BaseEnv(ABC):
                 self.completion_lengths
             )
             wandb_metrics["train/completion_lengths_p95"] = (
-                np.array(self.completion_lengths) > (0.95 * self.max_token_len)
+                np.array(self.completion_lengths) > (0.95 * self.max_token_len)
             ).mean()
         wandb_metrics = await self.create_rollout_table(wandb_metrics)
         wandb_metrics = self.perf_stats(wandb_metrics)
@@ -642,15 +642,15 @@ class BaseEnv(ABC):
         wandb.log(wandb_metrics, step=self.curr_step)

     async def evaluate_log(
-        self,
-        metrics: Dict,
-        task_name: Optional[str] = None,
-        model_name: Optional[str] = None,
-        start_time: Optional[float] = None,
-        end_time: Optional[float] = None,
-        generation_parameters: Optional[Dict] = None,
-        samples: Optional[List[Dict]] = None,
-        verbose: bool = True,
+        self,
+        metrics: Dict,
+        task_name: Optional[str] = None,
+        model_name: Optional[str] = None,
+        start_time: Optional[float] = None,
+        end_time: Optional[float] = None,
+        generation_parameters: Optional[Dict] = None,
+        samples: Optional[List[Dict]] = None,
+        verbose: bool = True,
     ):
         """
         Log evaluation results to a JSON file in the format expected by nous-evals.
@@ -690,7 +690,7 @@ class BaseEnv(ABC):
             # Get model name from first server config
             first_server = self.server.servers[0]
             if hasattr(first_server, "config") and hasattr(
-                first_server.config, "model_name"
+                first_server.config, "model_name"
             ):
                 model_name = first_server.config.model_name
         if start_time is None:
@@ -767,8 +767,8 @@ class BaseEnv(ABC):
         )
         async with aiohttp.ClientSession() as session:
             async with session.post(
-                url,
-                json=scored_data,
+                url,
+                json=scored_data,
             ) as resp:
                 if resp.status >= 500:
                     # Server errors (5xx) should trigger a retry
@@ -782,11 +782,11 @@ class BaseEnv(ABC):
                     print(await resp.text())

     async def handle_send_to_api(
-        self,
-        scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
-        item: Item = None,
-        do_send_to_api: bool = True,
-        abort_on_any_max_length_exceeded: bool = True,
+        self,
+        scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
+        item: Item = None,
+        do_send_to_api: bool = True,
+        abort_on_any_max_length_exceeded: bool = True,
     ):
         """
         Send the chats to the API with robust error handling and support for multiple ScoredDataGroups.
@@ -810,7 +810,7 @@ class BaseEnv(ABC):
                 )
             if not (
-                (None not in group) and (len(group.get("tokens", [])) == group_size)
+                (None not in group) and (len(group.get("tokens", [])) == group_size)
             ):
                 logger.warning(
                     f"Group structure invalid, or token count mismatch (expected {group_size}), "
@@ -819,8 +819,8 @@ class BaseEnv(ABC):
                 )
                 continue
             if (
-                self.config.ensure_scores_are_not_same
-                and len(set(group["scores"])) == 1
+                self.config.ensure_scores_are_not_same
+                and len(set(group["scores"])) == 1
             ):
                 logger.warning("Scores are the same in a group, skipping...")
                 continue
@@ -838,7 +838,7 @@ class BaseEnv(ABC):
                     "ensure your trainer handles this appropriately."
                 )
             elif abort_on_any_max_length_exceeded and any(
-                [len(x) >= self.max_token_len for x in group["tokens"]]
+                [len(x) >= self.max_token_len for x in group["tokens"]]
             ):
                 logger.warning("Token length is too long in a group, skipping...")
                 continue
@@ -877,7 +877,7 @@ class BaseEnv(ABC):
             print(f"Failed to send {data_type_str} after retries: {e}")

     async def handle_env(
-        self, item_uuid: str
+        self, item_uuid: str
     ) -> Optional[Union[ScoredDataGroup, List[ScoredDataGroup]]]:
         """
         Handle the rollout of an item
@@ -938,8 +938,8 @@ class BaseEnv(ABC):
     async def get_status(self):
         async with aiohttp.ClientSession() as session:
             async with session.get(
-                f"{self.config.rollout_server_url}/status-env",
-                json={"env_id": self.env_id},
+                f"{self.config.rollout_server_url}/status-env",
+                json={"env_id": self.env_id},
             ) as resp:
                 self.status_dict = await parse_http_response(resp, logger)
                 new_weight = self.status_dict["env_weight"]
@@ -956,7 +956,7 @@ class BaseEnv(ABC):
         if self.curr_step != self.status_dict["current_step"]:
             if self.config.steps_per_eval > 0:
                 if (self.curr_step % self.config.steps_per_eval) > (
-                    self.status_dict["current_step"] % self.config.steps_per_eval
+                    self.status_dict["current_step"] % self.config.steps_per_eval
                 ):
                     if (self.eval_runner is None) or (self.eval_runner.done()):
                         eval_task = asyncio.create_task(self.evaluate())
@@ -976,11 +976,11 @@ class BaseEnv(ABC):
                        )
         if self.checkpoint_interval > 0:
             if (self.curr_step % self.checkpoint_interval) > (
-                self.status_dict["current_step"] % self.checkpoint_interval
+                self.status_dict["current_step"] % self.checkpoint_interval
             ):
                 checkpoint_step = (
-                    self.status_dict["current_step"] // self.checkpoint_interval
-                ) * self.checkpoint_interval
+                    self.status_dict["current_step"] // self.checkpoint_interval
+                ) * self.checkpoint_interval
                 self.save_checkpoint(checkpoint_step)
         self.curr_step = self.status_dict["current_step"]
         if self.items_sent_this_step >= self.config.min_items_sent_before_logging:
@@ -1003,14 +1003,12 @@ class BaseEnv(ABC):
             max_num_workers = min(
                 max_num_workers,
                 (
-                    self.config.max_batches_offpolicy
-                    * self.derived_batch_size
-                    // self.config.group_size
+                    self.config.max_batches_offpolicy
+                    * self.derived_batch_size
+                    // self.config.group_size
+                )
+                - (self.status_dict["queue_size"]),
             )
-                - (self.status_dict["self_queue_size"]),
-            )
-            # now minimum num workers based on allocation
-
             # Now if we have a minimum batch allocation, we need to add workers to fill the self queue, in case of
             # overruns by other environments
             if self.config.min_batch_allocation is not None:
@@ -1018,31 +1016,31 @@ class BaseEnv(ABC):
                     0,
                     math.ceil(
                         (
-                            (
                             (
-                                math.ceil(
-                                    self.config.min_batch_allocation
-                                    * self.config.batch_size
-                                    * self.config.max_batches_offpolicy
-                                    / self.status_dict["max_group_size"]
-                                )
-                                + (
-                                    self.status_dict["max_group_size"]
-                                    // self.config.group_size
+                                (
+                                    math.ceil(
+                                        self.config.min_batch_allocation
+                                        * self.config.batch_size
+                                        * self.config.max_batches_offpolicy
+                                        / self.status_dict["max_group_size"]
+                                    )
+                                    + (
+                                        self.status_dict["max_group_size"]
+                                        // self.config.group_size
+                                    )
+                                )
+                                * self.status_dict["max_group_size"]
+                            )
+                            - (
+                                (
+                                    self.status_dict["max_group_size"]
+                                    * self.status_dict["self_queue_size"]
+                                    // (
+                                        self.status_dict["max_group_size"]
+                                        / self.config.group_size
+                                    )
                                 )
                             )
-                                )
-                                * self.status_dict["max_group_size"]
-                            )
-                            - (
-                                (
-                                    self.status_dict["max_group_size"]
-                                    * self.status_dict["self_queue_size"]
-                                    // (
-                                        self.status_dict["max_group_size"]
-                                        / self.config.group_size
-                                    )
-                                )
-                            )
                         )
                         / self.config.group_size
                     ),
@@ -1124,45 +1122,45 @@ class BaseEnv(ABC):
             await self.env_step_checks()
             logger.info(f"env_manager: Status dict: {self.status_dict}")
             if (
-                self.status_dict["current_step"]
-                + (
-                    self.status_dict["self_queue_size"]
-                    * self.config.group_size
-                    // self.config.batch_size
-                )
+                self.status_dict["current_step"]
+                + (
+                    self.status_dict["queue_size"]
+                    * self.config.group_size
+                    // self.config.batch_size
+                )
             ) > self.config.total_steps:
                 for worker in self.workers:
                     worker.cancel()
                 break
             if (
-                (
-                    self.status_dict["self_queue_size"] * self.config.group_size
-                    >= self.config.max_batches_offpolicy * self.config.batch_size
-                )
-                and (self.config.max_batches_offpolicy > 0)
-                and (
-                    (self.config.min_batch_allocation is None)
-                    or (
-                        (
-                            (
-                                (
-                                    math.ceil(
-                                        self.config.min_batch_allocation
-                                        * self.config.batch_size
-                                        * self.config.max_batches_offpolicy
-                                        / self.status_dict["max_group_size"]
-                                    )
-                                    * (
-                                        self.status_dict["max_group_size"]
-                                        // self.config.group_size
-                                    )
-                                )
-                            )
-                            - (self.status_dict["self_queue_size"])
-                        )
-                        <= 0
-                    )
-                )
+                (
+                    self.status_dict["queue_size"] * self.config.group_size
+                    >= self.config.max_batches_offpolicy * self.config.batch_size
+                )
+                and (self.config.max_batches_offpolicy > 0)
+                and (
+                    (self.config.min_batch_allocation is None)
+                    or (
+                        (
+                            (
+                                (
+                                    math.ceil(
+                                        self.config.min_batch_allocation
+                                        * self.config.batch_size
+                                        * self.config.max_batches_offpolicy
+                                        / self.status_dict["max_group_size"]
+                                    )
+                                    * (
+                                        self.status_dict["max_group_size"]
+                                        // self.config.group_size
+                                    )
+                                )
+                            )
+                            - (self.status_dict["self_queue_size"])
+                        )
+                        <= 0
+                    )
+                )
             ) or (self.derived_batch_size == -1):
                 # We have too many, lets cleanup the tasks and wait a bit
                 self.backlog.extend([x["item"] for x in self.running_items.values()])
@@ -1335,8 +1333,8 @@ class BaseEnv(ABC):
         # Note: This modifies the 'self' instance based on CLI args before full parsing.
         wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
         if (
-            getattr(self, wandb_name_attr, None) is None
-            and cls.name is not None
+            getattr(self, wandb_name_attr, None) is None
+            and cls.name is not None
         ):
             setattr(self, wandb_name_attr, cls.name)

@@ -1367,14 +1365,14 @@ class BaseEnv(ABC):
         )  # CLI args
         yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
         if isinstance(default_server_configs, ServerBaseline) and (
-            oai_cli_passed_args or yaml_oai_config
+            oai_cli_passed_args or yaml_oai_config
         ):
             raise ValueError(
                 "ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig."  # noqa: E501
             )
         if (
-            isinstance(default_server_configs, list)
-            and len(default_server_configs) == 1
+            isinstance(default_server_configs, list)
+            and len(default_server_configs) == 1
         ):
             # can't use the same var name because it shadows the class variable and we get an error
             default_openai_config_ = default_server_configs[0]
@@ -1383,7 +1381,7 @@ class BaseEnv(ABC):
             if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
                 yaml_oai_config = yaml_oai_config[0]
             if isinstance(default_openai_config_, APIServerConfig) and isinstance(
-                yaml_oai_config, dict
+                yaml_oai_config, dict
             ):
                 openai_config_dict = merge_dicts(
                     default_openai_config_.model_dump(),  # Default APIServerConfig (or from class init)
@@ -1487,7 +1485,7 @@ class BaseEnv(ABC):
         # If it's ServerBaseline, we use APIServerConfig type for CLI args to allow overrides.
         if isinstance(default_server_configs_from_init, list):
             if default_server_configs_from_init and isinstance(
-                default_server_configs_from_init[0], APIServerConfig
+                default_server_configs_from_init[0], APIServerConfig
             ):
                 openai_config_cls_for_cli = type(default_server_configs_from_init[0])
                 # Use the actual instance for default values later if it's a single config
@@ -1530,8 +1528,8 @@ class BaseEnv(ABC):
         # Set default wandb name if not provided and class has a name
         wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
         if (
-            getattr(self, wandb_name_attr, None) is None
-            and cls.name is not None
+            getattr(self, wandb_name_attr, None) is None
+            and cls.name is not None
         ):
             setattr(self, wandb_name_attr, cls.name)

@@ -1580,7 +1578,7 @@ class BaseEnv(ABC):
         )

         if isinstance(default_server_configs_from_init, ServerBaseline) and (
-            oai_cli_passed_args or yaml_oai_config
+            oai_cli_passed_args or yaml_oai_config
         ):
             # If config_init provided ServerBaseline, but CLI/YAML provides OpenAI specifics,
             # it implies an override intent for a single server.
@@ -1659,13 +1657,13 @@ class BaseEnv(ABC):
         if isinstance(final_openai_configs, list):
             for cfg in final_openai_configs:
                 if (
-                    isinstance(cfg, APIServerConfig)
-                    and cfg.base_url
-                    and (
+                    isinstance(cfg, APIServerConfig)
+                    and cfg.base_url
+                    and (
                         "localhost" in cfg.base_url
                         or "0.0.0.0" in cfg.base_url
                         or "127.0.0.1" in cfg.base_url
-                    )
+                    )
                 ):
                     warnings.warn(
                         "You are using a local Base URL for an OpenAI compatible server in 'process' mode. "
@@ -1674,13 +1672,13 @@ class BaseEnv(ABC):
                     )
                     break  # Warn once
         elif (
-            isinstance(final_openai_configs, APIServerConfig)
-            and final_openai_configs.base_url
-            and (
-                "localhost" in final_openai_configs.base_url
-                or "0.0.0.0" in final_openai_configs.base_url
-                or "127.0.0.1" in final_openai_configs.base_url
-            )
+            isinstance(final_openai_configs, APIServerConfig)
+            and final_openai_configs.base_url
+            and (
+                "localhost" in final_openai_configs.base_url
+                or "0.0.0.0" in final_openai_configs.base_url
+                or "127.0.0.1" in final_openai_configs.base_url
+            )
         ):
             warnings.warn(
                 "You are using a local Base URL for an OpenAI compatible server in 'process' mode. "
@@ -1756,7 +1754,7 @@ class BaseEnv(ABC):
         # If it's ServerBaseline, we use APIServerConfig type for CLI args to allow overrides.
         if isinstance(default_server_configs_from_init, list):
             if default_server_configs_from_init and isinstance(
-                default_server_configs_from_init[0], APIServerConfig
+                default_server_configs_from_init[0], APIServerConfig
             ):
                 openai_config_cls_for_cli = type(default_server_configs_from_init[0])
                 # Use the actual instance for default values later if it's a single config
@@ -1799,8 +1797,8 @@ class BaseEnv(ABC):
         # Set default wandb name if not provided and class has a name
         wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
         if (
-            getattr(self, wandb_name_attr, None) is None
-            and cls.name is not None
+            getattr(self, wandb_name_attr, None) is None
+            and cls.name is not None
         ):
             setattr(self, wandb_name_attr, cls.name)

@@ -1843,7 +1841,7 @@ class BaseEnv(ABC):
         )

         if isinstance(default_server_configs_from_init, ServerBaseline) and (
-            oai_cli_passed_args or yaml_oai_config
+            oai_cli_passed_args or yaml_oai_config
         ):
             # If config_init provided ServerBaseline, but CLI/YAML provides OpenAI specifics,
             # it implies an override intent for a single server.
@@ -1922,13 +1920,13 @@ class BaseEnv(ABC):
         if isinstance(final_openai_configs, list):
             for cfg in final_openai_configs:
                 if (
-                    isinstance(cfg, APIServerConfig)
-                    and cfg.base_url
-                    and (
+                    isinstance(cfg, APIServerConfig)
+                    and cfg.base_url
+                    and (
                         "localhost" in cfg.base_url
                         or "0.0.0.0" in cfg.base_url
                         or "127.0.0.1" in cfg.base_url
-                    )
+                    )
                 ):
                     warnings.warn(
                         "You are using a local Base URL for an OpenAI compatible server in 'evaluate' mode. "
@@ -1937,13 +1935,13 @@ class BaseEnv(ABC):
                     )
                     break  # Warn once
         elif (
-            isinstance(final_openai_configs, APIServerConfig)
-            and final_openai_configs.base_url
-            and (
-                "localhost" in final_openai_configs.base_url
-                or "0.0.0.0" in final_openai_configs.base_url
-                or "127.0.0.1" in final_openai_configs.base_url
-            )
+            isinstance(final_openai_configs, APIServerConfig)
+            and final_openai_configs.base_url
+            and (
+                "localhost" in final_openai_configs.base_url
+                or "0.0.0.0" in final_openai_configs.base_url
+                or "127.0.0.1" in final_openai_configs.base_url
+            )
         ):
             warnings.warn(
                 "You are using a local Base URL for an OpenAI compatible server in 'evaluate' mode. "