mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-25 17:10:42 +00:00
revert base.py
This commit is contained in:
parent
c483840f59
commit
17bb7bdf15
1 changed files with 146 additions and 148 deletions
|
|
@ -185,11 +185,11 @@ class BaseEnv(ABC):
|
|||
server_cls: APIServer = APIServer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BaseEnvConfig,
|
||||
server_configs: Union[ServerBaseline, List[APIServerConfig]],
|
||||
slurm=False,
|
||||
testing=False,
|
||||
self,
|
||||
config: BaseEnvConfig,
|
||||
server_configs: Union[ServerBaseline, List[APIServerConfig]],
|
||||
slurm=False,
|
||||
testing=False,
|
||||
):
|
||||
self.items_sent_this_step = 0
|
||||
self.eval_runner = None # type: Optional[asyncio.Task]
|
||||
|
|
@ -272,7 +272,7 @@ class BaseEnv(ABC):
|
|||
|
||||
@classmethod
|
||||
def config_init(
|
||||
cls,
|
||||
cls,
|
||||
) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[APIServerConfig]]]:
|
||||
"""
|
||||
Initialize the config
|
||||
|
|
@ -280,7 +280,7 @@ class BaseEnv(ABC):
|
|||
return cls.env_config_cls(), ServerBaseline()
|
||||
|
||||
async def collect_trajectory(
|
||||
self, item: Item
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
|
||||
raise NotImplementedError(
|
||||
"Handle env single method must be implemented in subclass "
|
||||
|
|
@ -339,8 +339,8 @@ class BaseEnv(ABC):
|
|||
return to_postprocess, backlog
|
||||
|
||||
async def postprocess_histories(
|
||||
self,
|
||||
trajectories: Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]],
|
||||
self,
|
||||
trajectories: Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]],
|
||||
) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
|
||||
"""
|
||||
Postprocess the histories, this is called after the collect_trajectories method
|
||||
|
|
@ -428,7 +428,7 @@ class BaseEnv(ABC):
|
|||
while self.wandb_project is None:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.config.rollout_server_url}/wandb_info"
|
||||
f"{self.config.rollout_server_url}/wandb_info"
|
||||
) as resp:
|
||||
data = await parse_http_response(resp, logger)
|
||||
self.wandb_group = data["group"]
|
||||
|
|
@ -462,14 +462,14 @@ class BaseEnv(ABC):
|
|||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.config.rollout_server_url}/register-env",
|
||||
json={
|
||||
"max_token_length": self.config.max_token_length,
|
||||
"desired_name": self.config.wandb_name,
|
||||
"weight": self.config.inference_weight,
|
||||
"min_batch_allocation": self.config.min_batch_allocation,
|
||||
"group_size": self.config.group_size,
|
||||
},
|
||||
f"{self.config.rollout_server_url}/register-env",
|
||||
json={
|
||||
"max_token_length": self.config.max_token_length,
|
||||
"desired_name": self.config.wandb_name,
|
||||
"weight": self.config.inference_weight,
|
||||
"min_batch_allocation": self.config.min_batch_allocation,
|
||||
"group_size": self.config.group_size,
|
||||
},
|
||||
) as resp:
|
||||
data = await parse_http_response(resp, logger)
|
||||
return data
|
||||
|
|
@ -577,9 +577,9 @@ class BaseEnv(ABC):
|
|||
return wandb_metrics
|
||||
|
||||
async def add_rollouts_for_wandb(
|
||||
self,
|
||||
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
|
||||
item: Item = None,
|
||||
self,
|
||||
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
|
||||
item: Item = None,
|
||||
):
|
||||
# Save rollout to trajectory
|
||||
num_keep = self.config.num_rollouts_per_group_for_logging
|
||||
|
|
@ -626,7 +626,7 @@ class BaseEnv(ABC):
|
|||
self.completion_lengths
|
||||
)
|
||||
wandb_metrics["train/completion_lengths_p95"] = (
|
||||
np.array(self.completion_lengths) > (0.95 * self.max_token_len)
|
||||
np.array(self.completion_lengths) > (0.95 * self.max_token_len)
|
||||
).mean()
|
||||
wandb_metrics = await self.create_rollout_table(wandb_metrics)
|
||||
wandb_metrics = self.perf_stats(wandb_metrics)
|
||||
|
|
@ -642,15 +642,15 @@ class BaseEnv(ABC):
|
|||
wandb.log(wandb_metrics, step=self.curr_step)
|
||||
|
||||
async def evaluate_log(
|
||||
self,
|
||||
metrics: Dict,
|
||||
task_name: Optional[str] = None,
|
||||
model_name: Optional[str] = None,
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None,
|
||||
generation_parameters: Optional[Dict] = None,
|
||||
samples: Optional[List[Dict]] = None,
|
||||
verbose: bool = True,
|
||||
self,
|
||||
metrics: Dict,
|
||||
task_name: Optional[str] = None,
|
||||
model_name: Optional[str] = None,
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None,
|
||||
generation_parameters: Optional[Dict] = None,
|
||||
samples: Optional[List[Dict]] = None,
|
||||
verbose: bool = True,
|
||||
):
|
||||
"""
|
||||
Log evaluation results to a JSON file in the format expected by nous-evals.
|
||||
|
|
@ -690,7 +690,7 @@ class BaseEnv(ABC):
|
|||
# Get model name from first server config
|
||||
first_server = self.server.servers[0]
|
||||
if hasattr(first_server, "config") and hasattr(
|
||||
first_server.config, "model_name"
|
||||
first_server.config, "model_name"
|
||||
):
|
||||
model_name = first_server.config.model_name
|
||||
if start_time is None:
|
||||
|
|
@ -767,8 +767,8 @@ class BaseEnv(ABC):
|
|||
)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url,
|
||||
json=scored_data,
|
||||
url,
|
||||
json=scored_data,
|
||||
) as resp:
|
||||
if resp.status >= 500:
|
||||
# Server errors (5xx) should trigger a retry
|
||||
|
|
@ -782,11 +782,11 @@ class BaseEnv(ABC):
|
|||
print(await resp.text())
|
||||
|
||||
async def handle_send_to_api(
|
||||
self,
|
||||
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
|
||||
item: Item = None,
|
||||
do_send_to_api: bool = True,
|
||||
abort_on_any_max_length_exceeded: bool = True,
|
||||
self,
|
||||
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
|
||||
item: Item = None,
|
||||
do_send_to_api: bool = True,
|
||||
abort_on_any_max_length_exceeded: bool = True,
|
||||
):
|
||||
"""
|
||||
Send the chats to the API with robust error handling and support for multiple ScoredDataGroups.
|
||||
|
|
@ -810,7 +810,7 @@ class BaseEnv(ABC):
|
|||
)
|
||||
|
||||
if not (
|
||||
(None not in group) and (len(group.get("tokens", [])) == group_size)
|
||||
(None not in group) and (len(group.get("tokens", [])) == group_size)
|
||||
):
|
||||
logger.warning(
|
||||
f"Group structure invalid, or token count mismatch (expected {group_size}), "
|
||||
|
|
@ -819,8 +819,8 @@ class BaseEnv(ABC):
|
|||
continue
|
||||
|
||||
if (
|
||||
self.config.ensure_scores_are_not_same
|
||||
and len(set(group["scores"])) == 1
|
||||
self.config.ensure_scores_are_not_same
|
||||
and len(set(group["scores"])) == 1
|
||||
):
|
||||
logger.warning("Scores are the same in a group, skipping...")
|
||||
continue
|
||||
|
|
@ -838,7 +838,7 @@ class BaseEnv(ABC):
|
|||
"ensure your trainer handles this appropriately."
|
||||
)
|
||||
elif abort_on_any_max_length_exceeded and any(
|
||||
[len(x) >= self.max_token_len for x in group["tokens"]]
|
||||
[len(x) >= self.max_token_len for x in group["tokens"]]
|
||||
):
|
||||
logger.warning("Token length is too long in a group, skipping...")
|
||||
continue
|
||||
|
|
@ -877,7 +877,7 @@ class BaseEnv(ABC):
|
|||
print(f"Failed to send {data_type_str} after retries: {e}")
|
||||
|
||||
async def handle_env(
|
||||
self, item_uuid: str
|
||||
self, item_uuid: str
|
||||
) -> Optional[Union[ScoredDataGroup, List[ScoredDataGroup]]]:
|
||||
"""
|
||||
Handle the rollout of an item
|
||||
|
|
@ -938,8 +938,8 @@ class BaseEnv(ABC):
|
|||
async def get_status(self):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.config.rollout_server_url}/status-env",
|
||||
json={"env_id": self.env_id},
|
||||
f"{self.config.rollout_server_url}/status-env",
|
||||
json={"env_id": self.env_id},
|
||||
) as resp:
|
||||
self.status_dict = await parse_http_response(resp, logger)
|
||||
new_weight = self.status_dict["env_weight"]
|
||||
|
|
@ -956,7 +956,7 @@ class BaseEnv(ABC):
|
|||
if self.curr_step != self.status_dict["current_step"]:
|
||||
if self.config.steps_per_eval > 0:
|
||||
if (self.curr_step % self.config.steps_per_eval) > (
|
||||
self.status_dict["current_step"] % self.config.steps_per_eval
|
||||
self.status_dict["current_step"] % self.config.steps_per_eval
|
||||
):
|
||||
if (self.eval_runner is None) or (self.eval_runner.done()):
|
||||
eval_task = asyncio.create_task(self.evaluate())
|
||||
|
|
@ -976,11 +976,11 @@ class BaseEnv(ABC):
|
|||
)
|
||||
if self.checkpoint_interval > 0:
|
||||
if (self.curr_step % self.checkpoint_interval) > (
|
||||
self.status_dict["current_step"] % self.checkpoint_interval
|
||||
self.status_dict["current_step"] % self.checkpoint_interval
|
||||
):
|
||||
checkpoint_step = (
|
||||
self.status_dict["current_step"] // self.checkpoint_interval
|
||||
) * self.checkpoint_interval
|
||||
self.status_dict["current_step"] // self.checkpoint_interval
|
||||
) * self.checkpoint_interval
|
||||
self.save_checkpoint(checkpoint_step)
|
||||
self.curr_step = self.status_dict["current_step"]
|
||||
if self.items_sent_this_step >= self.config.min_items_sent_before_logging:
|
||||
|
|
@ -1003,14 +1003,12 @@ class BaseEnv(ABC):
|
|||
max_num_workers = min(
|
||||
max_num_workers,
|
||||
(
|
||||
self.config.max_batches_offpolicy
|
||||
* self.derived_batch_size
|
||||
// self.config.group_size
|
||||
self.config.max_batches_offpolicy
|
||||
* self.derived_batch_size
|
||||
// self.config.group_size
|
||||
)
|
||||
- (self.status_dict["queue_size"]),
|
||||
)
|
||||
- (self.status_dict["self_queue_size"]),
|
||||
)
|
||||
# now minimum num workers based on allocation
|
||||
|
||||
# Now if we have a minimum batch allocation, we need to add workers to fill the self queue, in case of
|
||||
# overruns by other environments
|
||||
if self.config.min_batch_allocation is not None:
|
||||
|
|
@ -1018,31 +1016,31 @@ class BaseEnv(ABC):
|
|||
0,
|
||||
math.ceil(
|
||||
(
|
||||
(
|
||||
(
|
||||
math.ceil(
|
||||
self.config.min_batch_allocation
|
||||
* self.config.batch_size
|
||||
* self.config.max_batches_offpolicy
|
||||
/ self.status_dict["max_group_size"]
|
||||
)
|
||||
+ (
|
||||
self.status_dict["max_group_size"]
|
||||
// self.config.group_size
|
||||
(
|
||||
math.ceil(
|
||||
self.config.min_batch_allocation
|
||||
* self.config.batch_size
|
||||
* self.config.max_batches_offpolicy
|
||||
/ self.status_dict["max_group_size"]
|
||||
)
|
||||
+ (
|
||||
self.status_dict["max_group_size"]
|
||||
// self.config.group_size
|
||||
)
|
||||
)
|
||||
* self.status_dict["max_group_size"]
|
||||
)
|
||||
- (
|
||||
(
|
||||
self.status_dict["max_group_size"]
|
||||
* self.status_dict["self_queue_size"]
|
||||
// (
|
||||
self.status_dict["max_group_size"]
|
||||
/ self.config.group_size
|
||||
)
|
||||
)
|
||||
)
|
||||
* self.status_dict["max_group_size"]
|
||||
)
|
||||
- (
|
||||
(
|
||||
self.status_dict["max_group_size"]
|
||||
* self.status_dict["self_queue_size"]
|
||||
// (
|
||||
self.status_dict["max_group_size"]
|
||||
/ self.config.group_size
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
/ self.config.group_size
|
||||
),
|
||||
|
|
@ -1124,45 +1122,45 @@ class BaseEnv(ABC):
|
|||
await self.env_step_checks()
|
||||
logger.info(f"env_manager: Status dict: {self.status_dict}")
|
||||
if (
|
||||
self.status_dict["current_step"]
|
||||
+ (
|
||||
self.status_dict["self_queue_size"]
|
||||
* self.config.group_size
|
||||
// self.config.batch_size
|
||||
)
|
||||
self.status_dict["current_step"]
|
||||
+ (
|
||||
self.status_dict["queue_size"]
|
||||
* self.config.group_size
|
||||
// self.config.batch_size
|
||||
)
|
||||
) > self.config.total_steps:
|
||||
for worker in self.workers:
|
||||
worker.cancel()
|
||||
break
|
||||
if (
|
||||
(
|
||||
self.status_dict["self_queue_size"] * self.config.group_size
|
||||
>= self.config.max_batches_offpolicy * self.config.batch_size
|
||||
)
|
||||
and (self.config.max_batches_offpolicy > 0)
|
||||
and (
|
||||
(self.config.min_batch_allocation is None)
|
||||
or (
|
||||
(
|
||||
(
|
||||
(
|
||||
math.ceil(
|
||||
self.config.min_batch_allocation
|
||||
* self.config.batch_size
|
||||
* self.config.max_batches_offpolicy
|
||||
/ self.status_dict["max_group_size"]
|
||||
)
|
||||
* (
|
||||
self.status_dict["max_group_size"]
|
||||
// self.config.group_size
|
||||
)
|
||||
)
|
||||
)
|
||||
- (self.status_dict["self_queue_size"])
|
||||
)
|
||||
<= 0
|
||||
(
|
||||
self.status_dict["queue_size"] * self.config.group_size
|
||||
>= self.config.max_batches_offpolicy * self.config.batch_size
|
||||
)
|
||||
and (self.config.max_batches_offpolicy > 0)
|
||||
and (
|
||||
(self.config.min_batch_allocation is None)
|
||||
or (
|
||||
(
|
||||
(
|
||||
(
|
||||
math.ceil(
|
||||
self.config.min_batch_allocation
|
||||
* self.config.batch_size
|
||||
* self.config.max_batches_offpolicy
|
||||
/ self.status_dict["max_group_size"]
|
||||
)
|
||||
* (
|
||||
self.status_dict["max_group_size"]
|
||||
// self.config.group_size
|
||||
)
|
||||
)
|
||||
)
|
||||
- (self.status_dict["self_queue_size"])
|
||||
)
|
||||
<= 0
|
||||
)
|
||||
)
|
||||
)
|
||||
) or (self.derived_batch_size == -1):
|
||||
# We have too many, lets cleanup the tasks and wait a bit
|
||||
self.backlog.extend([x["item"] for x in self.running_items.values()])
|
||||
|
|
@ -1335,8 +1333,8 @@ class BaseEnv(ABC):
|
|||
# Note: This modifies the 'self' instance based on CLI args before full parsing.
|
||||
wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
|
||||
if (
|
||||
getattr(self, wandb_name_attr, None) is None
|
||||
and cls.name is not None
|
||||
getattr(self, wandb_name_attr, None) is None
|
||||
and cls.name is not None
|
||||
):
|
||||
setattr(self, wandb_name_attr, cls.name)
|
||||
|
||||
|
|
@ -1367,14 +1365,14 @@ class BaseEnv(ABC):
|
|||
) # CLI args
|
||||
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
|
||||
if isinstance(default_server_configs, ServerBaseline) and (
|
||||
oai_cli_passed_args or yaml_oai_config
|
||||
oai_cli_passed_args or yaml_oai_config
|
||||
):
|
||||
raise ValueError(
|
||||
"ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig." # noqa: E501
|
||||
)
|
||||
if (
|
||||
isinstance(default_server_configs, list)
|
||||
and len(default_server_configs) == 1
|
||||
isinstance(default_server_configs, list)
|
||||
and len(default_server_configs) == 1
|
||||
):
|
||||
# can't use the same var name because it shadows the class variable and we get an error
|
||||
default_openai_config_ = default_server_configs[0]
|
||||
|
|
@ -1383,7 +1381,7 @@ class BaseEnv(ABC):
|
|||
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
|
||||
yaml_oai_config = yaml_oai_config[0]
|
||||
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
|
||||
yaml_oai_config, dict
|
||||
yaml_oai_config, dict
|
||||
):
|
||||
openai_config_dict = merge_dicts(
|
||||
default_openai_config_.model_dump(), # Default APIServerConfig (or from class init)
|
||||
|
|
@ -1487,7 +1485,7 @@ class BaseEnv(ABC):
|
|||
# If it's ServerBaseline, we use APIServerConfig type for CLI args to allow overrides.
|
||||
if isinstance(default_server_configs_from_init, list):
|
||||
if default_server_configs_from_init and isinstance(
|
||||
default_server_configs_from_init[0], APIServerConfig
|
||||
default_server_configs_from_init[0], APIServerConfig
|
||||
):
|
||||
openai_config_cls_for_cli = type(default_server_configs_from_init[0])
|
||||
# Use the actual instance for default values later if it's a single config
|
||||
|
|
@ -1530,8 +1528,8 @@ class BaseEnv(ABC):
|
|||
# Set default wandb name if not provided and class has a name
|
||||
wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
|
||||
if (
|
||||
getattr(self, wandb_name_attr, None) is None
|
||||
and cls.name is not None
|
||||
getattr(self, wandb_name_attr, None) is None
|
||||
and cls.name is not None
|
||||
):
|
||||
setattr(self, wandb_name_attr, cls.name)
|
||||
|
||||
|
|
@ -1580,7 +1578,7 @@ class BaseEnv(ABC):
|
|||
)
|
||||
|
||||
if isinstance(default_server_configs_from_init, ServerBaseline) and (
|
||||
oai_cli_passed_args or yaml_oai_config
|
||||
oai_cli_passed_args or yaml_oai_config
|
||||
):
|
||||
# If config_init provided ServerBaseline, but CLI/YAML provides OpenAI specifics,
|
||||
# it implies an override intent for a single server.
|
||||
|
|
@ -1659,13 +1657,13 @@ class BaseEnv(ABC):
|
|||
if isinstance(final_openai_configs, list):
|
||||
for cfg in final_openai_configs:
|
||||
if (
|
||||
isinstance(cfg, APIServerConfig)
|
||||
and cfg.base_url
|
||||
and (
|
||||
isinstance(cfg, APIServerConfig)
|
||||
and cfg.base_url
|
||||
and (
|
||||
"localhost" in cfg.base_url
|
||||
or "0.0.0.0" in cfg.base_url
|
||||
or "127.0.0.1" in cfg.base_url
|
||||
)
|
||||
)
|
||||
):
|
||||
warnings.warn(
|
||||
"You are using a local Base URL for an OpenAI compatible server in 'process' mode. "
|
||||
|
|
@ -1674,13 +1672,13 @@ class BaseEnv(ABC):
|
|||
)
|
||||
break # Warn once
|
||||
elif (
|
||||
isinstance(final_openai_configs, APIServerConfig)
|
||||
and final_openai_configs.base_url
|
||||
and (
|
||||
"localhost" in final_openai_configs.base_url
|
||||
or "0.0.0.0" in final_openai_configs.base_url
|
||||
or "127.0.0.1" in final_openai_configs.base_url
|
||||
)
|
||||
isinstance(final_openai_configs, APIServerConfig)
|
||||
and final_openai_configs.base_url
|
||||
and (
|
||||
"localhost" in final_openai_configs.base_url
|
||||
or "0.0.0.0" in final_openai_configs.base_url
|
||||
or "127.0.0.1" in final_openai_configs.base_url
|
||||
)
|
||||
):
|
||||
warnings.warn(
|
||||
"You are using a local Base URL for an OpenAI compatible server in 'process' mode. "
|
||||
|
|
@ -1756,7 +1754,7 @@ class BaseEnv(ABC):
|
|||
# If it's ServerBaseline, we use APIServerConfig type for CLI args to allow overrides.
|
||||
if isinstance(default_server_configs_from_init, list):
|
||||
if default_server_configs_from_init and isinstance(
|
||||
default_server_configs_from_init[0], APIServerConfig
|
||||
default_server_configs_from_init[0], APIServerConfig
|
||||
):
|
||||
openai_config_cls_for_cli = type(default_server_configs_from_init[0])
|
||||
# Use the actual instance for default values later if it's a single config
|
||||
|
|
@ -1799,8 +1797,8 @@ class BaseEnv(ABC):
|
|||
# Set default wandb name if not provided and class has a name
|
||||
wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
|
||||
if (
|
||||
getattr(self, wandb_name_attr, None) is None
|
||||
and cls.name is not None
|
||||
getattr(self, wandb_name_attr, None) is None
|
||||
and cls.name is not None
|
||||
):
|
||||
setattr(self, wandb_name_attr, cls.name)
|
||||
|
||||
|
|
@ -1843,7 +1841,7 @@ class BaseEnv(ABC):
|
|||
)
|
||||
|
||||
if isinstance(default_server_configs_from_init, ServerBaseline) and (
|
||||
oai_cli_passed_args or yaml_oai_config
|
||||
oai_cli_passed_args or yaml_oai_config
|
||||
):
|
||||
# If config_init provided ServerBaseline, but CLI/YAML provides OpenAI specifics,
|
||||
# it implies an override intent for a single server.
|
||||
|
|
@ -1922,13 +1920,13 @@ class BaseEnv(ABC):
|
|||
if isinstance(final_openai_configs, list):
|
||||
for cfg in final_openai_configs:
|
||||
if (
|
||||
isinstance(cfg, APIServerConfig)
|
||||
and cfg.base_url
|
||||
and (
|
||||
isinstance(cfg, APIServerConfig)
|
||||
and cfg.base_url
|
||||
and (
|
||||
"localhost" in cfg.base_url
|
||||
or "0.0.0.0" in cfg.base_url
|
||||
or "127.0.0.1" in cfg.base_url
|
||||
)
|
||||
)
|
||||
):
|
||||
warnings.warn(
|
||||
"You are using a local Base URL for an OpenAI compatible server in 'evaluate' mode. "
|
||||
|
|
@ -1937,13 +1935,13 @@ class BaseEnv(ABC):
|
|||
)
|
||||
break # Warn once
|
||||
elif (
|
||||
isinstance(final_openai_configs, APIServerConfig)
|
||||
and final_openai_configs.base_url
|
||||
and (
|
||||
"localhost" in final_openai_configs.base_url
|
||||
or "0.0.0.0" in final_openai_configs.base_url
|
||||
or "127.0.0.1" in final_openai_configs.base_url
|
||||
)
|
||||
isinstance(final_openai_configs, APIServerConfig)
|
||||
and final_openai_configs.base_url
|
||||
and (
|
||||
"localhost" in final_openai_configs.base_url
|
||||
or "0.0.0.0" in final_openai_configs.base_url
|
||||
or "127.0.0.1" in final_openai_configs.base_url
|
||||
)
|
||||
):
|
||||
warnings.warn(
|
||||
"You are using a local Base URL for an OpenAI compatible server in 'evaluate' mode. "
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue