diff --git a/atroposlib/tests/test_env_imports.py b/atroposlib/tests/test_env_imports.py new file mode 100644 index 00000000..103da511 --- /dev/null +++ b/atroposlib/tests/test_env_imports.py @@ -0,0 +1,22 @@ +"""Regression tests for environment module imports. + +Ensures every environment module can be imported without errors +(e.g. no stale references to renamed symbols like OpenaiConfig). +""" + +import importlib + +import pytest + + +@pytest.mark.parametrize( + "module_path", + [ + "environments.sft_loader_server", + "environments.community.ufc_prediction_env.ufc_server", + "environments.community.ufc_prediction_env.ufc_image_env", + ], +) +def test_environment_module_imports(module_path): + """Each environment module should import without ImportError.""" + importlib.import_module(module_path) diff --git a/environments/community/ufc_prediction_env/ufc_image_env.py b/environments/community/ufc_prediction_env/ufc_image_env.py index a63f21db..8e3959eb 100644 --- a/environments/community/ufc_prediction_env/ufc_image_env.py +++ b/environments/community/ufc_prediction_env/ufc_image_env.py @@ -10,7 +10,7 @@ from typing import List, Optional, Tuple from PIL import Image from pydantic import Field -from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup +from atroposlib.envs.base import BaseEnv, BaseEnvConfig, APIServerConfig, ScoredDataGroup from atroposlib.type_definitions import GameHistory, Item from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer @@ -44,7 +44,7 @@ class UFCImageEnv(BaseEnv): def __init__( self, config: UFCImageEnvConfig, - server_configs: List[OpenaiConfig], + server_configs: List[APIServerConfig], slurm=True, testing=False, ): @@ -323,7 +323,7 @@ class UFCImageEnv(BaseEnv): return @classmethod - def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: + def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: """Initialize configuration for the environment""" if not os.environ.get("OPENAI_API_KEY"): print("ERROR: OPENAI_API_KEY environment variable is not set!") @@ -343,7 +343,7 @@ class UFCImageEnv(BaseEnv): ) server_configs = [ - OpenaiConfig( + APIServerConfig( model_name="gpt-4o", base_url=None, api_key=os.environ.get("OPENAI_API_KEY"), diff --git a/environments/community/ufc_prediction_env/ufc_server.py b/environments/community/ufc_prediction_env/ufc_server.py index f127449d..adc0d1c4 100644 --- a/environments/community/ufc_prediction_env/ufc_server.py +++ b/environments/community/ufc_prediction_env/ufc_server.py @@ -7,7 +7,7 @@ from typing import List, Optional, Tuple from pydantic import Field -from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup +from atroposlib.envs.base import BaseEnv, BaseEnvConfig, APIServerConfig, ScoredDataGroup from atroposlib.type_definitions import GameHistory, Item from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer @@ -37,7 +37,7 @@ class UFCEnv(BaseEnv): def __init__( self, config: UFCEnvConfig, - server_configs: List[OpenaiConfig], + server_configs: List[APIServerConfig], slurm=True, testing=False, ): diff --git a/environments/sft_loader_server.py b/environments/sft_loader_server.py index af0811d0..a8c00863 100644 --- a/environments/sft_loader_server.py +++ b/environments/sft_loader_server.py @@ -5,7 +5,7 @@ from typing import Dict, List, Optional, Tuple from datasets import load_dataset from pydantic import Field -from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup +from atroposlib.envs.base import BaseEnv, BaseEnvConfig, APIServerConfig, ScoredDataGroup from atroposlib.type_definitions import Item @@ -58,7 +58,7 @@ class SFTEnv(BaseEnv): def __init__( self, config: SFTConfig, - server_configs: List[OpenaiConfig], + server_configs: List[APIServerConfig], slurm=True, testing=False, ): @@ -72,7 +72,7 @@ class SFTEnv(BaseEnv): self.last_step = -1 @classmethod - def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: + def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: env_config = SFTConfig( tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", group_size=8, @@ -91,7 +91,7 @@ class SFTEnv(BaseEnv): max_sft_per_step=8, ) server_configs = [ - OpenaiConfig( + APIServerConfig( model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", base_url="http://localhost:9001/v1", api_key="x", @@ -235,7 +235,7 @@ async def checkout_formatting(): dataset_column_name="conversations", ), server_configs=[ - OpenaiConfig( + APIServerConfig( model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", base_url="http://localhost:9001/v1", api_key="x",