mirror of
https://github.com/NousResearch/atropos.git
synced 2026-05-03 17:53:17 +00:00
linting
This commit is contained in:
parent
bdcc3cb88f
commit
e96970f82e
3 changed files with 84 additions and 323 deletions
|
|
@ -1,18 +1,18 @@
|
|||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import argparse
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from openai import OpenAI
|
||||
|
||||
from atroposlib.envs.base import OpenaiConfig
|
||||
from atroposlib.utils.config_handler import ConfigHandler
|
||||
from environments.infinimath.infinimath_env import (
|
||||
InfiniteMathEnv,
|
||||
InfiniteMathEnvConfig,
|
||||
)
|
||||
from atroposlib.envs.base import OpenaiConfig
|
||||
from atroposlib.utils.config_handler import ConfigHandler
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
|
@ -33,27 +33,32 @@ def parse_arguments():
|
|||
|
||||
async def main():
|
||||
logger.info("Starting InfiniteMath environment server")
|
||||
|
||||
|
||||
# Parse command line arguments
|
||||
args = parse_arguments()
|
||||
|
||||
|
||||
# Initialize config handler and load configuration
|
||||
config_handler = ConfigHandler()
|
||||
|
||||
|
||||
# Determine config path
|
||||
if os.path.isabs(args.config) or "/" in args.config or args.config.endswith(".yaml"):
|
||||
if (
|
||||
os.path.isabs(args.config)
|
||||
or "/" in args.config
|
||||
or args.config.endswith(".yaml")
|
||||
):
|
||||
config_path = args.config
|
||||
else:
|
||||
# short form that defaults to the envs directory
|
||||
config_path = os.path.join(
|
||||
config_handler.config_dir, f"envs/{args.config}.yaml"
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"Loading configuration from: {config_path}")
|
||||
|
||||
|
||||
try:
|
||||
with open(config_path, "r") as f:
|
||||
import yaml
|
||||
|
||||
raw_config = yaml.safe_load(f)
|
||||
logger.info(f"Loaded configuration successfully")
|
||||
except Exception as e:
|
||||
|
|
@ -64,51 +69,74 @@ async def main():
|
|||
# Configure the InfiniteMath environment with values from config
|
||||
config = InfiniteMathEnvConfig(
|
||||
# Base environment parameters
|
||||
tokenizer_name=raw_config.get("tokenizer_name", "NousResearch/DeepHermes-3-Llama-3-8B-Preview"),
|
||||
tokenizer_name=raw_config.get(
|
||||
"tokenizer_name", "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
|
||||
),
|
||||
group_size=raw_config.get("group_size", 1),
|
||||
use_wandb=raw_config.get("use_wandb", False),
|
||||
max_num_workers=raw_config.get("max_num_workers", 1),
|
||||
rollout_server_url=raw_config.get("rollout_server_url", "http://localhost:8000"),
|
||||
rollout_server_url=raw_config.get(
|
||||
"rollout_server_url", "http://localhost:8000"
|
||||
),
|
||||
total_steps=raw_config.get("total_steps", 1),
|
||||
batch_size=raw_config.get("batch_size", 1),
|
||||
steps_per_eval=raw_config.get("steps_per_eval", 2),
|
||||
max_token_length=raw_config.get("max_token_length", 4096),
|
||||
wandb_name=raw_config.get("wandb_name", "infinite_math_test"),
|
||||
ensure_scores_are_not_same=raw_config.get("ensure_scores_are_not_same", False),
|
||||
|
||||
# InfiniteMath specific parameters
|
||||
starting_level=raw_config.get("infinimath", {}).get("starting_level", 1),
|
||||
progress_threshold=raw_config.get("infinimath", {}).get("progress_threshold", 0.7),
|
||||
progress_threshold=raw_config.get("infinimath", {}).get(
|
||||
"progress_threshold", 0.7
|
||||
),
|
||||
min_evaluations=raw_config.get("infinimath", {}).get("min_evaluations", 3),
|
||||
correct_reward=raw_config.get("infinimath", {}).get("correct_reward", 1.0),
|
||||
incorrect_reward=raw_config.get("infinimath", {}).get("incorrect_reward", -0.5),
|
||||
apply_length_penalty=raw_config.get("infinimath", {}).get("apply_length_penalty", True),
|
||||
length_threshold_ratio=raw_config.get("infinimath", {}).get("length_threshold_ratio", 0.6),
|
||||
apply_length_penalty=raw_config.get("infinimath", {}).get(
|
||||
"apply_length_penalty", True
|
||||
),
|
||||
length_threshold_ratio=raw_config.get("infinimath", {}).get(
|
||||
"length_threshold_ratio", 0.6
|
||||
),
|
||||
temperature=raw_config.get("infinimath", {}).get("temperature", 0.7),
|
||||
top_p=raw_config.get("infinimath", {}).get("top_p", 0.9),
|
||||
reward_functions=raw_config.get("infinimath", {}).get("reward_functions", ["accuracy", "format", "boxed"]),
|
||||
accuracy_reward_weight=raw_config.get("infinimath", {}).get("accuracy_reward_weight", 1.0),
|
||||
format_reward_weight=raw_config.get("infinimath", {}).get("format_reward_weight", 0.2),
|
||||
boxed_reward_weight=raw_config.get("infinimath", {}).get("boxed_reward_weight", 0.3),
|
||||
reward_functions=raw_config.get("infinimath", {}).get(
|
||||
"reward_functions", ["accuracy", "format", "boxed"]
|
||||
),
|
||||
accuracy_reward_weight=raw_config.get("infinimath", {}).get(
|
||||
"accuracy_reward_weight", 1.0
|
||||
),
|
||||
format_reward_weight=raw_config.get("infinimath", {}).get(
|
||||
"format_reward_weight", 0.2
|
||||
),
|
||||
boxed_reward_weight=raw_config.get("infinimath", {}).get(
|
||||
"boxed_reward_weight", 0.3
|
||||
),
|
||||
)
|
||||
|
||||
# Server configuration from config file or defaults
|
||||
server_configs = []
|
||||
|
||||
|
||||
if "server_configs" in raw_config:
|
||||
for server_config in raw_config["server_configs"]:
|
||||
api_key = server_config.get("api_key", os.environ.get("OPENAI_API_KEY"))
|
||||
# Handle environment variable references like ${OPENAI_API_KEY}
|
||||
if isinstance(api_key, str) and api_key.startswith("${") and api_key.endswith("}"):
|
||||
if (
|
||||
isinstance(api_key, str)
|
||||
and api_key.startswith("${")
|
||||
and api_key.endswith("}")
|
||||
):
|
||||
env_var = api_key[2:-1]
|
||||
api_key = os.environ.get(env_var, "")
|
||||
|
||||
|
||||
server_configs.append(
|
||||
OpenaiConfig(
|
||||
model_name=server_config.get("model_name", "gpt-4.1-nano"),
|
||||
base_url=server_config.get("base_url", None),
|
||||
api_key=api_key,
|
||||
num_requests_for_eval=server_config.get("num_requests_for_eval", 70),
|
||||
num_requests_for_eval=server_config.get(
|
||||
"num_requests_for_eval", 70
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
|
@ -149,11 +177,11 @@ async def main():
|
|||
# Collect trajectories
|
||||
logger.info("Collecting trajectories...")
|
||||
trajectories_data, backlog = await env.collect_trajectories(item)
|
||||
|
||||
|
||||
# Score the collected trajectories
|
||||
logger.info("Scoring trajectories...")
|
||||
scored_data = await env.score(trajectories_data)
|
||||
|
||||
|
||||
input("Press Enter to continue...")
|
||||
# Print scores
|
||||
logger.info(f"Scores: {scored_data['scores']}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue