added cli & config init

This commit is contained in:
Shannon Sands 2025-05-12 07:49:51 +10:00
parent 04b32fd8f3
commit bdcc3cb88f
2 changed files with 31 additions and 267 deletions

View file

@ -494,12 +494,11 @@ class InfiniteMathEnv(BaseEnv):
return scored_data
if __name__ == "__main__":
import asyncio
async def main():
config = InfiniteMathEnvConfig(
@classmethod
def config_init(cls) -> Tuple[InfiniteMathEnvConfig, List[OpenaiConfig]]:
"""Initialize environment and OpenAI configurations with default values."""
env_config = InfiniteMathEnvConfig(
# BaseEnvConfig fields
tokenizer_name="NousResearch/Nous-Hermes-2-Yi-34B",
group_size=8,
use_wandb=True,
@ -512,34 +511,37 @@ if __name__ == "__main__":
inference_weight=1.0,
wandb_name="infinite_math",
data_path_to_save_groups="data/infinite_math_groups.jsonl",
# InfiniteMathEnvConfig specific fields
starting_level=1,
progress_threshold=0.8,
min_evaluations=10,
correct_reward=1.0,
incorrect_reward=-0.5,
apply_length_penalty=True,
length_threshold_ratio=0.6,
temperature=0.7,
top_p=0.9,
reward_functions=["accuracy", "format", "boxed"],
accuracy_reward_weight=1.0,
format_reward_weight=0.2,
boxed_reward_weight=0.3,
max_attempts_per_problem=3, # Default from class, not in old main
correct_reward=1.0, # As in old main
incorrect_reward=-0.5, # As in old main (class default was -1.0)
think_block_bonus=0.2, # As per previous update
boxed_answer_bonus=0.2, # As per previous update
apply_length_penalty=True, # As in old main
length_threshold_ratio=0.6, # As in old main (class default was 0.5)
temperature=0.7, # As in old main
top_p=0.9 # As in old main
)
openai_config = OpenaiConfig(
model_name="NousResearch/Nous-Hermes-2-Yi-34B",
base_url="http://localhost:9004/v1",
api_key="x",
num_requests_for_eval=64,
)
server_configs = [
OpenaiConfig(
model_name="NousResearch/Nous-Hermes-2-Yi-34B",
base_url="http://localhost:9004/v1",
api_key="x",
num_requests_for_eval=64,
)
]
return env_config, server_configs
env = InfiniteMathEnv(
config=config,
server_configs=[openai_config],
slurm=False,
)
@classmethod
def cli(cls):
    """Run the environment's command-line interface.

    This override adds no behaviour of its own: it hands control straight
    to the parent class's ``cli`` implementation.
    """
    # Zero-argument super() is valid inside a classmethod; it resolves the
    # parent implementation relative to the enclosing class and ``cls``.
    super().cli()
await env.env_manager()
asyncio.run(main())
# Script entry point: when this module is executed directly (not imported),
# start the environment through its class-level CLI runner.
if __name__ == "__main__":
    InfiniteMathEnv.cli()