revert base.py

Author: Dakota
Date:   2025-10-29 10:11:05 -05:00
parent c483840f59
commit 17bb7bdf15


@@ -185,11 +185,11 @@ class BaseEnv(ABC):
     server_cls: APIServer = APIServer
 
     def __init__(
-        self,
-        config: BaseEnvConfig,
-        server_configs: Union[ServerBaseline, List[APIServerConfig]],
-        slurm=False,
-        testing=False,
+        self,
+        config: BaseEnvConfig,
+        server_configs: Union[ServerBaseline, List[APIServerConfig]],
+        slurm=False,
+        testing=False,
     ):
         self.items_sent_this_step = 0
         self.eval_runner = None  # type: Optional[asyncio.Task]
@@ -272,7 +272,7 @@ class BaseEnv(ABC):
     @classmethod
     def config_init(
-        cls,
+        cls,
     ) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[APIServerConfig]]]:
         """
         Initialize the config
@@ -280,7 +280,7 @@ class BaseEnv(ABC):
         return cls.env_config_cls(), ServerBaseline()
 
     async def collect_trajectory(
-        self, item: Item
+        self, item: Item
     ) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
         raise NotImplementedError(
             "Handle env single method must be implemented in subclass "
@@ -339,8 +339,8 @@ class BaseEnv(ABC):
         return to_postprocess, backlog
 
     async def postprocess_histories(
-        self,
-        trajectories: Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]],
+        self,
+        trajectories: Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]],
     ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
         """
         Postprocess the histories, this is called after the collect_trajectories method
@@ -428,7 +428,7 @@ class BaseEnv(ABC):
         while self.wandb_project is None:
             async with aiohttp.ClientSession() as session:
                 async with session.get(
-                    f"{self.config.rollout_server_url}/wandb_info"
+                    f"{self.config.rollout_server_url}/wandb_info"
                 ) as resp:
                     data = await parse_http_response(resp, logger)
                     self.wandb_group = data["group"]
@@ -462,14 +462,14 @@ class BaseEnv(ABC):
         try:
             async with aiohttp.ClientSession() as session:
                 async with session.post(
-                    f"{self.config.rollout_server_url}/register-env",
-                    json={
-                        "max_token_length": self.config.max_token_length,
-                        "desired_name": self.config.wandb_name,
-                        "weight": self.config.inference_weight,
-                        "min_batch_allocation": self.config.min_batch_allocation,
-                        "group_size": self.config.group_size,
-                    },
+                    f"{self.config.rollout_server_url}/register-env",
+                    json={
+                        "max_token_length": self.config.max_token_length,
+                        "desired_name": self.config.wandb_name,
+                        "weight": self.config.inference_weight,
+                        "min_batch_allocation": self.config.min_batch_allocation,
+                        "group_size": self.config.group_size,
+                    },
                 ) as resp:
                     data = await parse_http_response(resp, logger)
                     return data
@@ -577,9 +577,9 @@ class BaseEnv(ABC):
         return wandb_metrics
 
     async def add_rollouts_for_wandb(
-        self,
-        scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
-        item: Item = None,
+        self,
+        scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
+        item: Item = None,
     ):
         # Save rollout to trajectory
         num_keep = self.config.num_rollouts_per_group_for_logging
@@ -626,7 +626,7 @@ class BaseEnv(ABC):
                 self.completion_lengths
             )
             wandb_metrics["train/completion_lengths_p95"] = (
-                np.array(self.completion_lengths) > (0.95 * self.max_token_len)
+                np.array(self.completion_lengths) > (0.95 * self.max_token_len)
             ).mean()
         wandb_metrics = await self.create_rollout_table(wandb_metrics)
         wandb_metrics = self.perf_stats(wandb_metrics)
@@ -642,15 +642,15 @@ class BaseEnv(ABC):
             wandb.log(wandb_metrics, step=self.curr_step)
 
     async def evaluate_log(
-        self,
-        metrics: Dict,
-        task_name: Optional[str] = None,
-        model_name: Optional[str] = None,
-        start_time: Optional[float] = None,
-        end_time: Optional[float] = None,
-        generation_parameters: Optional[Dict] = None,
-        samples: Optional[List[Dict]] = None,
-        verbose: bool = True,
+        self,
+        metrics: Dict,
+        task_name: Optional[str] = None,
+        model_name: Optional[str] = None,
+        start_time: Optional[float] = None,
+        end_time: Optional[float] = None,
+        generation_parameters: Optional[Dict] = None,
+        samples: Optional[List[Dict]] = None,
+        verbose: bool = True,
     ):
         """
         Log evaluation results to a JSON file in the format expected by nous-evals.
@@ -690,7 +690,7 @@ class BaseEnv(ABC):
                 # Get model name from first server config
                 first_server = self.server.servers[0]
                 if hasattr(first_server, "config") and hasattr(
-                    first_server.config, "model_name"
+                    first_server.config, "model_name"
                 ):
                     model_name = first_server.config.model_name
         if start_time is None:
@@ -767,8 +767,8 @@ class BaseEnv(ABC):
                 )
                 async with aiohttp.ClientSession() as session:
                     async with session.post(
-                        url,
-                        json=scored_data,
+                        url,
+                        json=scored_data,
                     ) as resp:
                         if resp.status >= 500:
                             # Server errors (5xx) should trigger a retry
@@ -782,11 +782,11 @@ class BaseEnv(ABC):
                             print(await resp.text())
 
     async def handle_send_to_api(
-        self,
-        scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
-        item: Item = None,
-        do_send_to_api: bool = True,
-        abort_on_any_max_length_exceeded: bool = True,
+        self,
+        scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
+        item: Item = None,
+        do_send_to_api: bool = True,
+        abort_on_any_max_length_exceeded: bool = True,
     ):
         """
         Send the chats to the API with robust error handling and support for multiple ScoredDataGroups.
@@ -810,7 +810,7 @@ class BaseEnv(ABC):
                 )
 
                 if not (
-                    (None not in group) and (len(group.get("tokens", [])) == group_size)
+                    (None not in group) and (len(group.get("tokens", [])) == group_size)
                 ):
                     logger.warning(
                         f"Group structure invalid, or token count mismatch (expected {group_size}), "
@@ -819,8 +819,8 @@ class BaseEnv(ABC):
                     continue
 
                 if (
-                    self.config.ensure_scores_are_not_same
-                    and len(set(group["scores"])) == 1
+                    self.config.ensure_scores_are_not_same
+                    and len(set(group["scores"])) == 1
                 ):
                     logger.warning("Scores are the same in a group, skipping...")
                     continue
@@ -838,7 +838,7 @@ class BaseEnv(ABC):
                         "ensure your trainer handles this appropriately."
                     )
                 elif abort_on_any_max_length_exceeded and any(
-                    [len(x) >= self.max_token_len for x in group["tokens"]]
+                    [len(x) >= self.max_token_len for x in group["tokens"]]
                 ):
                     logger.warning("Token length is too long in a group, skipping...")
                     continue
@@ -877,7 +877,7 @@ class BaseEnv(ABC):
             print(f"Failed to send {data_type_str} after retries: {e}")
 
     async def handle_env(
-        self, item_uuid: str
+        self, item_uuid: str
    ) -> Optional[Union[ScoredDataGroup, List[ScoredDataGroup]]]:
         """
         Handle the rollout of an item
@@ -938,8 +938,8 @@ class BaseEnv(ABC):
     async def get_status(self):
         async with aiohttp.ClientSession() as session:
             async with session.get(
-                f"{self.config.rollout_server_url}/status-env",
-                json={"env_id": self.env_id},
+                f"{self.config.rollout_server_url}/status-env",
+                json={"env_id": self.env_id},
             ) as resp:
                 self.status_dict = await parse_http_response(resp, logger)
                 new_weight = self.status_dict["env_weight"]
@@ -956,7 +956,7 @@ class BaseEnv(ABC):
         if self.curr_step != self.status_dict["current_step"]:
             if self.config.steps_per_eval > 0:
                 if (self.curr_step % self.config.steps_per_eval) > (
-                    self.status_dict["current_step"] % self.config.steps_per_eval
+                    self.status_dict["current_step"] % self.config.steps_per_eval
                 ):
                     if (self.eval_runner is None) or (self.eval_runner.done()):
                         eval_task = asyncio.create_task(self.evaluate())
@@ -976,11 +976,11 @@ class BaseEnv(ABC):
                 )
             if self.checkpoint_interval > 0:
                 if (self.curr_step % self.checkpoint_interval) > (
-                    self.status_dict["current_step"] % self.checkpoint_interval
+                    self.status_dict["current_step"] % self.checkpoint_interval
                 ):
                     checkpoint_step = (
-                        self.status_dict["current_step"] // self.checkpoint_interval
-                    ) * self.checkpoint_interval
+                        self.status_dict["current_step"] // self.checkpoint_interval
+                    ) * self.checkpoint_interval
                     self.save_checkpoint(checkpoint_step)
             self.curr_step = self.status_dict["current_step"]
         if self.items_sent_this_step >= self.config.min_items_sent_before_logging:
@@ -1003,14 +1003,12 @@ class BaseEnv(ABC):
             max_num_workers = min(
                 max_num_workers,
                 (
-                    self.config.max_batches_offpolicy
-                    * self.derived_batch_size
-                    // self.config.group_size
+                    self.config.max_batches_offpolicy
+                    * self.derived_batch_size
+                    // self.config.group_size
                 )
-                - (self.status_dict["queue_size"]),
-            )
+                - (self.status_dict["self_queue_size"]),
             )
-            # now minimum num workers based on allocation
             # Now if we have a minimum batch allocation, we need to add workers to fill the self queue, in case of
             # overruns by other environments
             if self.config.min_batch_allocation is not None:
@@ -1018,31 +1016,31 @@
                 0,
                 math.ceil(
                     (
                         (
-                            (
-                                math.ceil(
-                                    self.config.min_batch_allocation
-                                    * self.config.batch_size
-                                    * self.config.max_batches_offpolicy
-                                    / self.status_dict["max_group_size"]
-                                )
-                                + (
-                                    self.status_dict["max_group_size"]
-                                    // self.config.group_size
+                            (
+                                math.ceil(
+                                    self.config.min_batch_allocation
+                                    * self.config.batch_size
+                                    * self.config.max_batches_offpolicy
+                                    / self.status_dict["max_group_size"]
+                                )
+                                + (
+                                    self.status_dict["max_group_size"]
+                                    // self.config.group_size
                                 )
                             )
-                            * self.status_dict["max_group_size"]
-                        )
-                        - (
-                            (
-                                self.status_dict["max_group_size"]
-                                * self.status_dict["self_queue_size"]
-                                // (
-                                    self.status_dict["max_group_size"]
-                                    / self.config.group_size
-                                )
-                            )
-                        )
+                            * self.status_dict["max_group_size"]
+                        )
+                        - (
+                            (
+                                self.status_dict["max_group_size"]
+                                * self.status_dict["self_queue_size"]
+                                // (
+                                    self.status_dict["max_group_size"]
+                                    / self.config.group_size
+                                )
+                            )
+                        )
                     )
                     / self.config.group_size
                 ),
@@ -1124,45 +1122,45 @@ class BaseEnv(ABC):
             await self.env_step_checks()
             logger.info(f"env_manager: Status dict: {self.status_dict}")
             if (
-                self.status_dict["current_step"]
-                + (
-                    self.status_dict["self_queue_size"]
-                    * self.config.group_size
-                    // self.config.batch_size
-                )
+                self.status_dict["current_step"]
+                + (
+                    self.status_dict["queue_size"]
+                    * self.config.group_size
+                    // self.config.batch_size
+                )
             ) > self.config.total_steps:
                 for worker in self.workers:
                     worker.cancel()
                 break
 
             if (
-                (
-                    self.status_dict["self_queue_size"] * self.config.group_size
-                    >= self.config.max_batches_offpolicy * self.config.batch_size
-                )
-                and (self.config.max_batches_offpolicy > 0)
-                and (
-                    (self.config.min_batch_allocation is None)
-                    or (
-                        (
-                            (
-                                (
-                                    math.ceil(
-                                        self.config.min_batch_allocation
-                                        * self.config.batch_size
-                                        * self.config.max_batches_offpolicy
-                                        / self.status_dict["max_group_size"]
-                                    )
-                                    * (
-                                        self.status_dict["max_group_size"]
-                                        // self.config.group_size
-                                    )
-                                )
-                            )
-                            - (self.status_dict["self_queue_size"])
-                        )
-                        <= 0
+                (
+                    self.status_dict["queue_size"] * self.config.group_size
+                    >= self.config.max_batches_offpolicy * self.config.batch_size
+                )
+                and (self.config.max_batches_offpolicy > 0)
+                and (
+                    (self.config.min_batch_allocation is None)
+                    or (
+                        (
+                            (
+                                (
+                                    math.ceil(
+                                        self.config.min_batch_allocation
+                                        * self.config.batch_size
+                                        * self.config.max_batches_offpolicy
+                                        / self.status_dict["max_group_size"]
+                                    )
+                                    * (
+                                        self.status_dict["max_group_size"]
+                                        // self.config.group_size
+                                    )
+                                )
+                            )
+                            - (self.status_dict["self_queue_size"])
+                        )
+                        <= 0
                     )
                 )
                 )
             ) or (self.derived_batch_size == -1):
                 # We have too many, lets cleanup the tasks and wait a bit
                 self.backlog.extend([x["item"] for x in self.running_items.values()])
@@ -1335,8 +1333,8 @@ class BaseEnv(ABC):
         # Note: This modifies the 'self' instance based on CLI args before full parsing.
         wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
         if (
-            getattr(self, wandb_name_attr, None) is None
-            and cls.name is not None
+            getattr(self, wandb_name_attr, None) is None
+            and cls.name is not None
         ):
             setattr(self, wandb_name_attr, cls.name)
@@ -1367,14 +1365,14 @@ class BaseEnv(ABC):
         )  # CLI args
         yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
         if isinstance(default_server_configs, ServerBaseline) and (
-            oai_cli_passed_args or yaml_oai_config
+            oai_cli_passed_args or yaml_oai_config
         ):
             raise ValueError(
                 "ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig."  # noqa: E501
             )
         if (
-            isinstance(default_server_configs, list)
-            and len(default_server_configs) == 1
+            isinstance(default_server_configs, list)
+            and len(default_server_configs) == 1
         ):
             # can't use the same var name because it shadows the class variable and we get an error
             default_openai_config_ = default_server_configs[0]
@@ -1383,7 +1381,7 @@ class BaseEnv(ABC):
         if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
             yaml_oai_config = yaml_oai_config[0]
         if isinstance(default_openai_config_, APIServerConfig) and isinstance(
-            yaml_oai_config, dict
+            yaml_oai_config, dict
         ):
             openai_config_dict = merge_dicts(
                 default_openai_config_.model_dump(),  # Default APIServerConfig (or from class init)
@@ -1487,7 +1485,7 @@ class BaseEnv(ABC):
         # If it's ServerBaseline, we use APIServerConfig type for CLI args to allow overrides.
         if isinstance(default_server_configs_from_init, list):
             if default_server_configs_from_init and isinstance(
-                default_server_configs_from_init[0], APIServerConfig
+                default_server_configs_from_init[0], APIServerConfig
             ):
                 openai_config_cls_for_cli = type(default_server_configs_from_init[0])
                 # Use the actual instance for default values later if it's a single config
@@ -1530,8 +1528,8 @@ class BaseEnv(ABC):
         # Set default wandb name if not provided and class has a name
         wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
         if (
-            getattr(self, wandb_name_attr, None) is None
-            and cls.name is not None
+            getattr(self, wandb_name_attr, None) is None
+            and cls.name is not None
         ):
             setattr(self, wandb_name_attr, cls.name)
@@ -1580,7 +1578,7 @@ class BaseEnv(ABC):
         )
 
         if isinstance(default_server_configs_from_init, ServerBaseline) and (
-            oai_cli_passed_args or yaml_oai_config
+            oai_cli_passed_args or yaml_oai_config
         ):
             # If config_init provided ServerBaseline, but CLI/YAML provides OpenAI specifics,
             # it implies an override intent for a single server.
@@ -1659,13 +1657,13 @@ class BaseEnv(ABC):
         if isinstance(final_openai_configs, list):
             for cfg in final_openai_configs:
                 if (
-                    isinstance(cfg, APIServerConfig)
-                    and cfg.base_url
-                    and (
+                    isinstance(cfg, APIServerConfig)
+                    and cfg.base_url
+                    and (
                         "localhost" in cfg.base_url
                         or "0.0.0.0" in cfg.base_url
                         or "127.0.0.1" in cfg.base_url
-                    )
+                    )
                 ):
                     warnings.warn(
                         "You are using a local Base URL for an OpenAI compatible server in 'process' mode. "
@@ -1674,13 +1672,13 @@ class BaseEnv(ABC):
                     )
                     break  # Warn once
         elif (
-            isinstance(final_openai_configs, APIServerConfig)
-            and final_openai_configs.base_url
-            and (
-                "localhost" in final_openai_configs.base_url
-                or "0.0.0.0" in final_openai_configs.base_url
-                or "127.0.0.1" in final_openai_configs.base_url
-            )
+            isinstance(final_openai_configs, APIServerConfig)
+            and final_openai_configs.base_url
+            and (
+                "localhost" in final_openai_configs.base_url
+                or "0.0.0.0" in final_openai_configs.base_url
+                or "127.0.0.1" in final_openai_configs.base_url
+            )
         ):
             warnings.warn(
                 "You are using a local Base URL for an OpenAI compatible server in 'process' mode. "
@@ -1756,7 +1754,7 @@ class BaseEnv(ABC):
         # If it's ServerBaseline, we use APIServerConfig type for CLI args to allow overrides.
         if isinstance(default_server_configs_from_init, list):
             if default_server_configs_from_init and isinstance(
-                default_server_configs_from_init[0], APIServerConfig
+                default_server_configs_from_init[0], APIServerConfig
             ):
                 openai_config_cls_for_cli = type(default_server_configs_from_init[0])
                 # Use the actual instance for default values later if it's a single config
@@ -1799,8 +1797,8 @@ class BaseEnv(ABC):
         # Set default wandb name if not provided and class has a name
         wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
         if (
-            getattr(self, wandb_name_attr, None) is None
-            and cls.name is not None
+            getattr(self, wandb_name_attr, None) is None
+            and cls.name is not None
         ):
             setattr(self, wandb_name_attr, cls.name)
@@ -1843,7 +1841,7 @@ class BaseEnv(ABC):
         )
 
         if isinstance(default_server_configs_from_init, ServerBaseline) and (
-            oai_cli_passed_args or yaml_oai_config
+            oai_cli_passed_args or yaml_oai_config
         ):
             # If config_init provided ServerBaseline, but CLI/YAML provides OpenAI specifics,
             # it implies an override intent for a single server.
@@ -1922,13 +1920,13 @@ class BaseEnv(ABC):
         if isinstance(final_openai_configs, list):
             for cfg in final_openai_configs:
                 if (
-                    isinstance(cfg, APIServerConfig)
-                    and cfg.base_url
-                    and (
+                    isinstance(cfg, APIServerConfig)
+                    and cfg.base_url
+                    and (
                         "localhost" in cfg.base_url
                         or "0.0.0.0" in cfg.base_url
                         or "127.0.0.1" in cfg.base_url
-                    )
+                    )
                 ):
                     warnings.warn(
                         "You are using a local Base URL for an OpenAI compatible server in 'evaluate' mode. "
@@ -1937,13 +1935,13 @@ class BaseEnv(ABC):
                     )
                     break  # Warn once
         elif (
-            isinstance(final_openai_configs, APIServerConfig)
-            and final_openai_configs.base_url
-            and (
-                "localhost" in final_openai_configs.base_url
-                or "0.0.0.0" in final_openai_configs.base_url
-                or "127.0.0.1" in final_openai_configs.base_url
-            )
+            isinstance(final_openai_configs, APIServerConfig)
+            and final_openai_configs.base_url
+            and (
+                "localhost" in final_openai_configs.base_url
+                or "0.0.0.0" in final_openai_configs.base_url
+                or "127.0.0.1" in final_openai_configs.base_url
+            )
         ):
             warnings.warn(
                 "You are using a local Base URL for an OpenAI compatible server in 'evaluate' mode. "