Merge branch 'main' into blackjack2-env

This commit is contained in:
Shannon Sands 2025-05-14 17:27:44 -07:00
commit 00dd120067
34 changed files with 1620 additions and 386 deletions

View file

@ -40,7 +40,8 @@ from atroposlib.utils.metrics import get_std_min_max_avg
from ..type_definitions import Item, Message
from .server_handling.server_manager import (
OpenaiConfig,
APIServer,
APIServerConfig,
ServerBaseline,
ServerManager,
ServerManagerConfig,
@ -163,13 +164,14 @@ class BaseEnvConfig(BaseModel):
class BaseEnv(ABC):
name = None
env_config_cls = BaseEnvConfig
name: Optional[str] = None
env_config_cls: BaseEnvConfig = BaseEnvConfig
server_cls: APIServer = APIServer
def __init__(
self,
config: BaseEnvConfig,
server_configs: Union[ServerBaseline, List[OpenaiConfig]],
server_configs: Union[ServerBaseline, List[APIServerConfig]],
slurm=False,
testing=False,
):
@ -184,7 +186,9 @@ class BaseEnv(ABC):
self.last_loop_time = None
self.last_completed_item = None
self.config = config
self.server = ServerManager(server_configs, slurm=slurm, testing=testing)
self.server = ServerManager(
server_configs, slurm=slurm, testing=testing, server_class=self.server_cls
)
self.workers = set()
self.eval_workers = set()
self.backlog = []
@ -234,7 +238,7 @@ class BaseEnv(ABC):
@classmethod
def config_init(
cls,
) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[OpenaiConfig]]]:
) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[APIServerConfig]]]:
"""
Initialize the config
"""
@ -1020,7 +1024,6 @@ class BaseEnv(ABC):
Returns:
type: The CliServeConfig class for serving commands.
"""
# Get the default configurations defined by the specific environment class
default_env_config, default_server_configs = cls.config_init()
@ -1032,8 +1035,8 @@ class BaseEnv(ABC):
class CliServeConfig(
get_prefixed_pydantic_model(type(default_env_config), env_full_prefix),
get_prefixed_pydantic_model(
OpenaiConfig, openai_full_prefix
), # Use OpenaiConfig for CLI args
APIServerConfig, openai_full_prefix
), # Use APIServerConfig for CLI args
ServerManagerConfig, # ServerManager args are not namespaced by default
Cmd,
):
@ -1089,7 +1092,7 @@ class BaseEnv(ABC):
oai_cli_passed_args or yaml_oai_config
):
raise ValueError(
"ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use OpenaiConfig." # noqa: E501
"ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig." # noqa: E501
)
if (
isinstance(default_server_configs, list)
@ -1101,11 +1104,11 @@ class BaseEnv(ABC):
default_openai_config_ = default_server_configs
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
yaml_oai_config = yaml_oai_config[0]
if isinstance(default_openai_config_, OpenaiConfig) and isinstance(
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
yaml_oai_config, dict
):
openai_config_dict = merge_dicts(
default_openai_config_.model_dump(), # Default OpenaiConfig (or from class init)
default_openai_config_.model_dump(), # Default APIServerConfig (or from class init)
yaml_oai_config,
oai_cli_passed_args,
)
@ -1189,7 +1192,7 @@ class BaseEnv(ABC):
data_path_to_save_groups=f"data/{cls.name or 'groups'}.jsonl",
use_wandb=True,
)
PROCESS_MODE_OPENAI_DEFAULT_CONFIG = OpenaiConfig(
PROCESS_MODE_OPENAI_DEFAULT_CONFIG = APIServerConfig(
model_name="gpt-4.1-nano",
base_url=None,
api_key=None,
@ -1200,10 +1203,7 @@ class BaseEnv(ABC):
)
# Get the base default configurations from the specific environment class
(
default_env_config,
default_server_configs,
) = cls.config_init()
default_env_config, default_server_configs = cls.config_init()
# Define namespace prefixes
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
@ -1215,8 +1215,7 @@ class BaseEnv(ABC):
type(default_env_config), PROCESS_MODE_ENV_DEFAULT_CONFIG
)
openai_config_cls_new_defaults = adjust_model_defaults(
OpenaiConfig,
PROCESS_MODE_OPENAI_DEFAULT_CONFIG,
APIServerConfig, PROCESS_MODE_OPENAI_DEFAULT_CONFIG
)
server_manager_config_cls_new_defaults = adjust_model_defaults(
ServerManagerConfig,
@ -1283,7 +1282,7 @@ class BaseEnv(ABC):
oai_cli_passed_args or yaml_oai_config
):
raise ValueError(
"ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use OpenaiConfig." # noqa: E501
"ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig." # noqa: E501
)
if (
@ -1296,11 +1295,11 @@ class BaseEnv(ABC):
default_openai_config_ = default_server_configs
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
yaml_oai_config = yaml_oai_config[0]
if isinstance(default_openai_config_, OpenaiConfig) and isinstance(
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
yaml_oai_config, dict
):
openai_config_dict = merge_dicts(
default_openai_config_.model_dump(), # Default OpenaiConfig (or from class init)
default_openai_config_.model_dump(), # Default APIServerConfig (or from class init)
PROCESS_MODE_OPENAI_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults
yaml_oai_config,
oai_cli_passed_args,