Revert "merged latest"

This reverts commit d768ad68aa.
This commit is contained in:
Shannon Sands 2025-05-15 12:11:05 -07:00
parent 9a28777332
commit bba93552f5
8 changed files with 3231 additions and 27 deletions

View file

@ -9,12 +9,7 @@ from typing import Dict, List, Optional, Tuple, Union
from dotenv import load_dotenv
from openai import AsyncOpenAI
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
)
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from .curriculum import MathCurriculum
@ -81,7 +76,7 @@ class InfiniteMathEnv(BaseEnv):
def __init__(
self,
config: InfiniteMathEnvConfig,
server_configs: Union[List[APIServerConfig], APIServerConfig],
server_configs: Union[List[OpenaiConfig], OpenaiConfig],
slurm=True,
testing=False,
):
@ -616,10 +611,10 @@ class InfiniteMathEnv(BaseEnv):
return scored_data
@classmethod
def config_init(cls) -> Tuple[InfiniteMathEnvConfig, List[APIServerConfig]]:
def config_init(cls) -> Tuple[InfiniteMathEnvConfig, List[OpenaiConfig]]:
"""Initialize environment and OpenAI configurations with default values."""
env_config = InfiniteMathEnvConfig(
tokenizer_name="NousResearch/Nous-Hermes-3-Llama-3-8B-Preview",
tokenizer_name="NousResearch/Nous-Hermes-2-Yi-34B",
group_size=8,
use_wandb=True,
max_num_workers=64,
@ -649,8 +644,8 @@ class InfiniteMathEnv(BaseEnv):
)
server_configs = [
APIServerConfig(
model_name="NousResearch/Nous-Hermes-3-Llama-3-8B-Preview",
OpenaiConfig(
model_name="NousResearch/Nous-Hermes-2-Yi-34B",
base_url="http://localhost:9004/v1",
api_key="x",
num_requests_for_eval=64,
@ -658,6 +653,11 @@ class InfiniteMathEnv(BaseEnv):
]
return env_config, server_configs
@classmethod
def cli(cls):
"""Command Line Interface runner for the environment."""
super().cli()
if __name__ == "__main__":
InfiniteMathEnv.cli()