mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-29 17:35:07 +00:00
parent
9a28777332
commit
bba93552f5
8 changed files with 3231 additions and 27 deletions
|
|
@ -9,12 +9,7 @@ from typing import Dict, List, Optional, Tuple, Union
|
|||
from dotenv import load_dotenv
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
from .curriculum import MathCurriculum
|
||||
|
|
@ -81,7 +76,7 @@ class InfiniteMathEnv(BaseEnv):
|
|||
def __init__(
|
||||
self,
|
||||
config: InfiniteMathEnvConfig,
|
||||
server_configs: Union[List[APIServerConfig], APIServerConfig],
|
||||
server_configs: Union[List[OpenaiConfig], OpenaiConfig],
|
||||
slurm=True,
|
||||
testing=False,
|
||||
):
|
||||
|
|
@ -616,10 +611,10 @@ class InfiniteMathEnv(BaseEnv):
|
|||
return scored_data
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[InfiniteMathEnvConfig, List[APIServerConfig]]:
|
||||
def config_init(cls) -> Tuple[InfiniteMathEnvConfig, List[OpenaiConfig]]:
|
||||
"""Initialize environment and OpenAI configurations with default values."""
|
||||
env_config = InfiniteMathEnvConfig(
|
||||
tokenizer_name="NousResearch/Nous-Hermes-3-Llama-3-8B-Preview",
|
||||
tokenizer_name="NousResearch/Nous-Hermes-2-Yi-34B",
|
||||
group_size=8,
|
||||
use_wandb=True,
|
||||
max_num_workers=64,
|
||||
|
|
@ -649,8 +644,8 @@ class InfiniteMathEnv(BaseEnv):
|
|||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="NousResearch/Nous-Hermes-3-Llama-3-8B-Preview",
|
||||
OpenaiConfig(
|
||||
model_name="NousResearch/Nous-Hermes-2-Yi-34B",
|
||||
base_url="http://localhost:9004/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=64,
|
||||
|
|
@ -658,6 +653,11 @@ class InfiniteMathEnv(BaseEnv):
|
|||
]
|
||||
return env_config, server_configs
|
||||
|
||||
@classmethod
|
||||
def cli(cls):
|
||||
"""Command Line Interface runner for the environment."""
|
||||
super().cli()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
InfiniteMathEnv.cli()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue