mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
Merge branch 'main' into blackjack2-env
This commit is contained in:
commit
00dd120067
34 changed files with 1620 additions and 386 deletions
|
|
@ -40,7 +40,8 @@ from atroposlib.utils.metrics import get_std_min_max_avg
|
|||
|
||||
from ..type_definitions import Item, Message
|
||||
from .server_handling.server_manager import (
|
||||
OpenaiConfig,
|
||||
APIServer,
|
||||
APIServerConfig,
|
||||
ServerBaseline,
|
||||
ServerManager,
|
||||
ServerManagerConfig,
|
||||
|
|
@ -163,13 +164,14 @@ class BaseEnvConfig(BaseModel):
|
|||
|
||||
class BaseEnv(ABC):
|
||||
|
||||
name = None
|
||||
env_config_cls = BaseEnvConfig
|
||||
name: Optional[str] = None
|
||||
env_config_cls: BaseEnvConfig = BaseEnvConfig
|
||||
server_cls: APIServer = APIServer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BaseEnvConfig,
|
||||
server_configs: Union[ServerBaseline, List[OpenaiConfig]],
|
||||
server_configs: Union[ServerBaseline, List[APIServerConfig]],
|
||||
slurm=False,
|
||||
testing=False,
|
||||
):
|
||||
|
|
@ -184,7 +186,9 @@ class BaseEnv(ABC):
|
|||
self.last_loop_time = None
|
||||
self.last_completed_item = None
|
||||
self.config = config
|
||||
self.server = ServerManager(server_configs, slurm=slurm, testing=testing)
|
||||
self.server = ServerManager(
|
||||
server_configs, slurm=slurm, testing=testing, server_class=self.server_cls
|
||||
)
|
||||
self.workers = set()
|
||||
self.eval_workers = set()
|
||||
self.backlog = []
|
||||
|
|
@ -234,7 +238,7 @@ class BaseEnv(ABC):
|
|||
@classmethod
|
||||
def config_init(
|
||||
cls,
|
||||
) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[OpenaiConfig]]]:
|
||||
) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[APIServerConfig]]]:
|
||||
"""
|
||||
Initialize the config
|
||||
"""
|
||||
|
|
@ -1020,7 +1024,6 @@ class BaseEnv(ABC):
|
|||
Returns:
|
||||
type: The CliServeConfig class for serving commands.
|
||||
"""
|
||||
|
||||
# Get the default configurations defined by the specific environment class
|
||||
default_env_config, default_server_configs = cls.config_init()
|
||||
|
||||
|
|
@ -1032,8 +1035,8 @@ class BaseEnv(ABC):
|
|||
class CliServeConfig(
|
||||
get_prefixed_pydantic_model(type(default_env_config), env_full_prefix),
|
||||
get_prefixed_pydantic_model(
|
||||
OpenaiConfig, openai_full_prefix
|
||||
), # Use OpenaiConfig for CLI args
|
||||
APIServerConfig, openai_full_prefix
|
||||
), # Use APIServerConfig for CLI args
|
||||
ServerManagerConfig, # ServerManager args are not namespaced by default
|
||||
Cmd,
|
||||
):
|
||||
|
|
@ -1089,7 +1092,7 @@ class BaseEnv(ABC):
|
|||
oai_cli_passed_args or yaml_oai_config
|
||||
):
|
||||
raise ValueError(
|
||||
"ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use OpenaiConfig." # noqa: E501
|
||||
"ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig." # noqa: E501
|
||||
)
|
||||
if (
|
||||
isinstance(default_server_configs, list)
|
||||
|
|
@ -1101,11 +1104,11 @@ class BaseEnv(ABC):
|
|||
default_openai_config_ = default_server_configs
|
||||
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
|
||||
yaml_oai_config = yaml_oai_config[0]
|
||||
if isinstance(default_openai_config_, OpenaiConfig) and isinstance(
|
||||
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
|
||||
yaml_oai_config, dict
|
||||
):
|
||||
openai_config_dict = merge_dicts(
|
||||
default_openai_config_.model_dump(), # Default OpenaiConfig (or from class init)
|
||||
default_openai_config_.model_dump(), # Default APIServerConfig (or from class init)
|
||||
yaml_oai_config,
|
||||
oai_cli_passed_args,
|
||||
)
|
||||
|
|
@ -1189,7 +1192,7 @@ class BaseEnv(ABC):
|
|||
data_path_to_save_groups=f"data/{cls.name or 'groups'}.jsonl",
|
||||
use_wandb=True,
|
||||
)
|
||||
PROCESS_MODE_OPENAI_DEFAULT_CONFIG = OpenaiConfig(
|
||||
PROCESS_MODE_OPENAI_DEFAULT_CONFIG = APIServerConfig(
|
||||
model_name="gpt-4.1-nano",
|
||||
base_url=None,
|
||||
api_key=None,
|
||||
|
|
@ -1200,10 +1203,7 @@ class BaseEnv(ABC):
|
|||
)
|
||||
|
||||
# Get the base default configurations from the specific environment class
|
||||
(
|
||||
default_env_config,
|
||||
default_server_configs,
|
||||
) = cls.config_init()
|
||||
default_env_config, default_server_configs = cls.config_init()
|
||||
|
||||
# Define namespace prefixes
|
||||
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
|
||||
|
|
@ -1215,8 +1215,7 @@ class BaseEnv(ABC):
|
|||
type(default_env_config), PROCESS_MODE_ENV_DEFAULT_CONFIG
|
||||
)
|
||||
openai_config_cls_new_defaults = adjust_model_defaults(
|
||||
OpenaiConfig,
|
||||
PROCESS_MODE_OPENAI_DEFAULT_CONFIG,
|
||||
APIServerConfig, PROCESS_MODE_OPENAI_DEFAULT_CONFIG
|
||||
)
|
||||
server_manager_config_cls_new_defaults = adjust_model_defaults(
|
||||
ServerManagerConfig,
|
||||
|
|
@ -1283,7 +1282,7 @@ class BaseEnv(ABC):
|
|||
oai_cli_passed_args or yaml_oai_config
|
||||
):
|
||||
raise ValueError(
|
||||
"ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use OpenaiConfig." # noqa: E501
|
||||
"ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig." # noqa: E501
|
||||
)
|
||||
|
||||
if (
|
||||
|
|
@ -1296,11 +1295,11 @@ class BaseEnv(ABC):
|
|||
default_openai_config_ = default_server_configs
|
||||
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
|
||||
yaml_oai_config = yaml_oai_config[0]
|
||||
if isinstance(default_openai_config_, OpenaiConfig) and isinstance(
|
||||
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
|
||||
yaml_oai_config, dict
|
||||
):
|
||||
openai_config_dict = merge_dicts(
|
||||
default_openai_config_.model_dump(), # Default OpenaiConfig (or from class init)
|
||||
default_openai_config_.model_dump(), # Default APIServerConfig (or from class init)
|
||||
PROCESS_MODE_OPENAI_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults
|
||||
yaml_oai_config,
|
||||
oai_cli_passed_args,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue