mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-26 17:13:09 +00:00
linting and local testing tidy up
This commit is contained in:
parent
141ab66792
commit
bfc967c4bd
3 changed files with 242 additions and 310 deletions
|
|
@ -1,14 +1,11 @@
|
|||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from openai import OpenAI
|
||||
|
||||
from atroposlib.envs.base import OpenaiConfig
|
||||
from atroposlib.utils.config_handler import ConfigHandler
|
||||
from environments.infinimath.infinimath_env import (
|
||||
InfiniteMathEnv,
|
||||
InfiniteMathEnvConfig,
|
||||
|
|
@ -20,182 +17,119 @@ logging.basicConfig(level=logging.INFO)
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_arguments(argv=None):
    """Parse command-line options for the InfiniteMath environment server.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse falls back to ``sys.argv[1:]``, which
            matches the behavior of calling this function with no arguments.

    Returns:
        argparse.Namespace: Parsed options. It carries a single ``config``
        attribute (str) that defaults to ``"infinimath"``.
    """
    parser = argparse.ArgumentParser(description="InfiniteMath environment server")
    parser.add_argument(
        "--config",
        type=str,
        default="infinimath",
        help="Configuration file name (without .yaml extension or path for configs/envs/ directory, or full path)",
    )
    # Passing argv=None keeps the default "read the process command line"
    # behavior while letting callers supply an explicit argument list.
    return parser.parse_args(argv)
|
||||
|
||||
|
||||
async def main():
|
||||
logger.info("Starting InfiniteMath environment server")
|
||||
logger.info("Starting InfiniteMath environment local runner")
|
||||
|
||||
# Parse command line arguments
|
||||
args = parse_arguments()
|
||||
config = InfiniteMathEnvConfig(
|
||||
tokenizer_name="NousResearch/Nous-Hermes-2-Yi-34B",
|
||||
group_size=1,
|
||||
use_wandb=False,
|
||||
max_num_workers=1,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1,
|
||||
batch_size=1,
|
||||
steps_per_eval=0,
|
||||
max_token_length=2048,
|
||||
wandb_name="infinite_math_local_debug",
|
||||
ensure_scores_are_not_same=False,
|
||||
starting_level=1,
|
||||
progress_threshold=0.8,
|
||||
min_evaluations=3,
|
||||
correct_reward=1.0,
|
||||
incorrect_reward=-0.5,
|
||||
think_block_bonus=0.1,
|
||||
boxed_answer_bonus=0.2,
|
||||
apply_length_penalty=False,
|
||||
length_threshold_ratio=0.6,
|
||||
temperature=0.3,
|
||||
top_p=0.9,
|
||||
)
|
||||
|
||||
# Initialize config handler and load configuration
|
||||
config_handler = ConfigHandler()
|
||||
|
||||
# Determine config path
|
||||
if (
|
||||
os.path.isabs(args.config)
|
||||
or "/" in args.config
|
||||
or args.config.endswith(".yaml")
|
||||
):
|
||||
config_path = args.config
|
||||
else:
|
||||
# short form that defaults to the envs directory
|
||||
config_path = os.path.join(
|
||||
config_handler.config_dir, f"envs/{args.config}.yaml"
|
||||
server_configs = [
|
||||
OpenaiConfig(
|
||||
model_name="NousResearch/Nous-Hermes-2-Yi-34B",
|
||||
base_url=os.getenv("OPENAI_BASE_URL", "http://localhost:9004/v1"),
|
||||
api_key=os.getenv("OPENAI_API_KEY", "dummy-key"),
|
||||
num_requests_for_eval=0,
|
||||
)
|
||||
]
|
||||
|
||||
logger.info("Using hardcoded debug configuration.")
|
||||
logger.debug(f"Env Config: {config}")
|
||||
logger.debug(f"Server Configs: {server_configs}")
|
||||
|
||||
logger.info(f"Loading configuration from: {config_path}")
|
||||
|
||||
try:
|
||||
with open(config_path, "r") as f:
|
||||
import yaml
|
||||
|
||||
raw_config = yaml.safe_load(f)
|
||||
logger.info(f"Loaded configuration successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading config directly: {e}")
|
||||
logger.info("Falling back to default config handler")
|
||||
raw_config = config_handler.load_config(args)
|
||||
|
||||
# Configure the InfiniteMath environment with values from config
|
||||
config = InfiniteMathEnvConfig(
|
||||
# Base environment parameters
|
||||
tokenizer_name=raw_config.get(
|
||||
"tokenizer_name", "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
|
||||
),
|
||||
group_size=raw_config.get("group_size", 1),
|
||||
use_wandb=raw_config.get("use_wandb", False),
|
||||
max_num_workers=raw_config.get("max_num_workers", 1),
|
||||
rollout_server_url=raw_config.get(
|
||||
"rollout_server_url", "http://localhost:8000"
|
||||
),
|
||||
total_steps=raw_config.get("total_steps", 1),
|
||||
batch_size=raw_config.get("batch_size", 1),
|
||||
steps_per_eval=raw_config.get("steps_per_eval", 2),
|
||||
max_token_length=raw_config.get("max_token_length", 4096),
|
||||
wandb_name=raw_config.get("wandb_name", "infinite_math_test"),
|
||||
ensure_scores_are_not_same=raw_config.get("ensure_scores_are_not_same", False),
|
||||
# InfiniteMath specific parameters
|
||||
starting_level=raw_config.get("infinimath", {}).get("starting_level", 1),
|
||||
progress_threshold=raw_config.get("infinimath", {}).get(
|
||||
"progress_threshold", 0.7
|
||||
),
|
||||
min_evaluations=raw_config.get("infinimath", {}).get("min_evaluations", 3),
|
||||
correct_reward=raw_config.get("infinimath", {}).get("correct_reward", 1.0),
|
||||
incorrect_reward=raw_config.get("infinimath", {}).get("incorrect_reward", -0.5),
|
||||
apply_length_penalty=raw_config.get("infinimath", {}).get(
|
||||
"apply_length_penalty", True
|
||||
),
|
||||
length_threshold_ratio=raw_config.get("infinimath", {}).get(
|
||||
"length_threshold_ratio", 0.6
|
||||
),
|
||||
temperature=raw_config.get("infinimath", {}).get("temperature", 0.7),
|
||||
top_p=raw_config.get("infinimath", {}).get("top_p", 0.9),
|
||||
reward_functions=raw_config.get("infinimath", {}).get(
|
||||
"reward_functions", ["accuracy", "format", "boxed"]
|
||||
),
|
||||
accuracy_reward_weight=raw_config.get("infinimath", {}).get(
|
||||
"accuracy_reward_weight", 1.0
|
||||
),
|
||||
format_reward_weight=raw_config.get("infinimath", {}).get(
|
||||
"format_reward_weight", 0.2
|
||||
),
|
||||
boxed_reward_weight=raw_config.get("infinimath", {}).get(
|
||||
"boxed_reward_weight", 0.3
|
||||
),
|
||||
)
|
||||
|
||||
# Server configuration from config file or defaults
|
||||
server_configs = []
|
||||
|
||||
if "server_configs" in raw_config:
|
||||
for server_config in raw_config["server_configs"]:
|
||||
api_key = server_config.get("api_key", os.environ.get("OPENAI_API_KEY"))
|
||||
# Handle environment variable references like ${OPENAI_API_KEY}
|
||||
if (
|
||||
isinstance(api_key, str)
|
||||
and api_key.startswith("${")
|
||||
and api_key.endswith("}")
|
||||
):
|
||||
env_var = api_key[2:-1]
|
||||
api_key = os.environ.get(env_var, "")
|
||||
|
||||
server_configs.append(
|
||||
OpenaiConfig(
|
||||
model_name=server_config.get("model_name", "gpt-4.1-nano"),
|
||||
base_url=server_config.get("base_url", None),
|
||||
api_key=api_key,
|
||||
num_requests_for_eval=server_config.get(
|
||||
"num_requests_for_eval", 70
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Default configuration if not specified in config file
|
||||
server_configs.append(
|
||||
OpenaiConfig(
|
||||
model_name="gpt-4.1-nano",
|
||||
base_url=None,
|
||||
api_key=os.environ.get("OPENAI_API_KEY"),
|
||||
num_requests_for_eval=70,
|
||||
)
|
||||
env = InfiniteMathEnv(
|
||||
config=config,
|
||||
server_configs=server_configs,
|
||||
slurm=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to initialize InfiniteMathEnv: {e}")
|
||||
return
|
||||
|
||||
# Create the environment
|
||||
env = InfiniteMathEnv(
|
||||
config=config,
|
||||
server_configs=server_configs,
|
||||
slurm=False,
|
||||
)
|
||||
|
||||
# Setup the environment
|
||||
logger.info("Setting up environment...")
|
||||
await env.setup()
|
||||
logger.info("Environment setup complete")
|
||||
logger.info("Environment setup complete.")
|
||||
|
||||
# Log the number of evaluation problems
|
||||
total_problems = sum(len(probs) for probs in env.eval_problems.values())
|
||||
logger.info(
|
||||
f"Using {total_problems} evaluation problems across {len(env.eval_problems)} difficulty levels"
|
||||
)
|
||||
|
||||
# Get a math problem
|
||||
logger.info("Getting a math problem...")
|
||||
item = await env.get_next_item()
|
||||
problem_prompt, solution, generator_id = item
|
||||
|
||||
logger.info(f"Problem: {dict(problem_prompt[0])['content']}")
|
||||
logger.info(f"Solution: {solution}")
|
||||
problem_content = dict(problem_prompt[0])['content']
|
||||
logger.info(f"Problem (ID: {generator_id}, Level: {env.curriculum.get_current_level()}): {problem_content}")
|
||||
logger.info(f"Expected Solution: {solution}")
|
||||
|
||||
# Collect trajectories
|
||||
logger.info("Collecting trajectories...")
|
||||
trajectories_data, backlog = await env.collect_trajectories(item)
|
||||
|
||||
if not trajectories_data:
|
||||
logger.error("No trajectories were collected.")
|
||||
return
|
||||
|
||||
# Score the collected trajectories
|
||||
logger.info(f"Collected {len(trajectories_data)} data points for scoring (should be 1 for group_size=1).")
|
||||
|
||||
logger.info("Scoring trajectories...")
|
||||
scored_data = await env.score(trajectories_data)
|
||||
|
||||
input("Press Enter to continue...")
|
||||
# Print scores
|
||||
logger.info(f"Scores: {scored_data['scores']}")
|
||||
logger.info("\n========== Trajectory Summary ==========")
|
||||
if scored_data and scored_data.get("messages") and scored_data.get("scores"):
|
||||
for i, messages_list in enumerate(scored_data["messages"]):
|
||||
assistant_response = ""
|
||||
if messages_list and messages_list[-1].get("role") == "assistant":
|
||||
assistant_response = messages_list[-1].get("content", "N/A")
|
||||
|
||||
logger.info(f"--- Attempt {i+1} ---")
|
||||
logger.info(f"Problem: {problem_content}")
|
||||
logger.info(f"Full Assistant Response:\\n{assistant_response}")
|
||||
logger.info(f"Score: {scored_data['scores'][i]}")
|
||||
is_correct_task = env.check_answer(assistant_response, solution)
|
||||
logger.info(f"Checked Correct by env.check_answer: {is_correct_task}")
|
||||
|
||||
# Log the correct/incorrect counts
|
||||
correct_count = sum(1 for score in scored_data["scores"] if score > 0)
|
||||
logger.info(f"Correct answers: {correct_count}/{len(scored_data['scores'])}")
|
||||
|
||||
# Test evaluation function specifically
|
||||
correct_count_buffer = sum(env.percent_correct_buffer)
|
||||
total_attempts_buffer = len(env.percent_correct_buffer)
|
||||
|
||||
logger.info("\n--- Overall for this run ---")
|
||||
logger.info(f"Expected Solution: {solution}")
|
||||
logger.info(f"Score(s) from env.score: {scored_data['scores']}")
|
||||
if total_attempts_buffer > 0:
|
||||
logger.info(f"Correct based on internal buffer: {correct_count_buffer}/{total_attempts_buffer}")
|
||||
else:
|
||||
logger.info("No attempts recorded in percent_correct_buffer.")
|
||||
|
||||
else:
|
||||
logger.error("Scored data is missing expected fields ('messages' or 'scores').")
|
||||
|
||||
logger.info("=======================================")
|
||||
|
||||
# Re-add curriculum and evaluation testing
|
||||
logger.info("\n=== Testing Evaluation Function ===")
|
||||
|
||||
# Record the current level
|
||||
initial_level = env.curriculum.get_current_level()
|
||||
logger.info(f"Current level before evaluation: {initial_level}")
|
||||
initial_level_eval = env.curriculum.get_current_level()
|
||||
logger.info(f"Current level before evaluation: {initial_level_eval}")
|
||||
logger.info(f"Level description: {env.curriculum.get_level_description()}")
|
||||
logger.info(f"Progress threshold: {env.curriculum.progress_threshold}")
|
||||
logger.info(f"Min evaluations needed: {env.curriculum.min_evaluations}")
|
||||
|
|
@ -205,65 +139,76 @@ async def main():
|
|||
|
||||
# Display evaluation results
|
||||
logger.info("Evaluation metrics:")
|
||||
for metric_name, metric_value in eval_metrics:
|
||||
logger.info(f" - {metric_name}: {metric_value}")
|
||||
if eval_metrics:
|
||||
for metric_name, metric_value in eval_metrics:
|
||||
logger.info(f" - {metric_name}: {metric_value}")
|
||||
else:
|
||||
logger.info(" No evaluation metrics returned.")
|
||||
|
||||
# Check if the level advanced
|
||||
new_level = env.curriculum.get_current_level()
|
||||
if new_level > initial_level:
|
||||
logger.info(f"Successfully advanced to level {new_level}!")
|
||||
new_level_eval = env.curriculum.get_current_level()
|
||||
if new_level_eval > initial_level_eval:
|
||||
logger.info(f"Successfully advanced from level {initial_level_eval} to level {new_level_eval} during evaluation!")
|
||||
logger.info(f"New level description: {env.curriculum.get_level_description()}")
|
||||
else:
|
||||
logger.info(f"Did not advance during evaluation. Remained at level {initial_level_eval}.")
|
||||
# Show current progress toward advancement
|
||||
current_level = env.curriculum.get_current_level()
|
||||
if current_level in env.curriculum.performance_history:
|
||||
history = env.curriculum.performance_history[current_level]
|
||||
current_level_desc = env.curriculum.get_current_level()
|
||||
if current_level_desc in env.curriculum.performance_history:
|
||||
history = env.curriculum.performance_history[current_level_desc]
|
||||
if len(history) >= env.curriculum.min_evaluations:
|
||||
recent_history = history[-env.curriculum.min_evaluations :]
|
||||
success_rate = sum(recent_history) / len(recent_history)
|
||||
logger.info(
|
||||
f"Current success rate: {success_rate:.2f} (need {env.curriculum.progress_threshold} to advance)"
|
||||
f"Current success rate for level {current_level_desc}: {success_rate:.2f} (need {env.curriculum.progress_threshold} to advance)"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Need more evaluations: {len(history)}/{env.curriculum.min_evaluations}"
|
||||
f"Need more evaluations for level {current_level_desc}: {len(history)}/{env.curriculum.min_evaluations}"
|
||||
)
|
||||
|
||||
# Show all levels and their performance history
|
||||
logger.info("\nPerformance history by level:")
|
||||
for level in sorted(env.curriculum.performance_history.keys()):
|
||||
history = env.curriculum.performance_history[level]
|
||||
if history:
|
||||
success_rate = sum(history) / len(history)
|
||||
# Show all levels and their performance history after evaluation
|
||||
logger.info("\nPerformance history by level (after evaluation run):")
|
||||
for level_hist_key in sorted(env.curriculum.performance_history.keys()):
|
||||
history_list = env.curriculum.performance_history[level_hist_key]
|
||||
if history_list:
|
||||
success_rate_hist = sum(history_list) / len(history_list)
|
||||
logger.info(
|
||||
f" Level {level}: {success_rate:.2f} ({sum(history)}/{len(history)} correct)"
|
||||
f" Level {level_hist_key}: {success_rate_hist:.2f} ({sum(history_list)}/{len(history_list)} correct)"
|
||||
)
|
||||
else:
|
||||
logger.info(f" Level {level}: No data")
|
||||
logger.info(f" Level {level_hist_key}: No data")
|
||||
|
||||
# Test curriculum advancement with simulated performance history
|
||||
logger.info("\n=== Testing Curriculum Advancement ===")
|
||||
logger.info("\n=== Testing Curriculum Advancement Manually ===")
|
||||
initial_level_manual_adv = env.curriculum.get_current_level()
|
||||
logger.info(f"Starting manual advancement test from level: {initial_level_manual_adv}")
|
||||
|
||||
# Simulate good performance at current level
|
||||
for _ in range(env.config.min_evaluations):
|
||||
# Get a problem from current level
|
||||
item = await env.get_next_item()
|
||||
_, _, generator_id = item
|
||||
# Ensure we don't try to get items if curriculum is already at max level from previous eval
|
||||
max_level_possible = max(env.curriculum.DIFFICULTY_LEVELS.keys())
|
||||
if initial_level_manual_adv < max_level_possible:
|
||||
logger.info(f"Simulating {config.min_evaluations} correct answers for level {initial_level_manual_adv}...")
|
||||
for _ in range(config.min_evaluations): # Use config for min_evaluations
|
||||
# Get a problem from current level to ensure generator_id is valid for the level
|
||||
# The level might have changed due to the previous env.evaluate() call
|
||||
problem_item_adv_test = await env.get_next_item()
|
||||
_, _, generator_id_adv_test = problem_item_adv_test
|
||||
env.curriculum.record_performance(generator_id_adv_test, True)
|
||||
|
||||
# Try to advance difficulty
|
||||
did_advance = env.curriculum.advance_difficulty()
|
||||
new_level_manual_adv = env.curriculum.get_current_level()
|
||||
|
||||
# Record positive performance
|
||||
env.curriculum.record_performance(generator_id, True)
|
||||
logger.info(f"Curriculum advancement test results:")
|
||||
logger.info(f" - Level before manual simulation: {initial_level_manual_adv}")
|
||||
logger.info(f" - Recorded {config.min_evaluations} correct answers manually.")
|
||||
logger.info(f" - Did advance: {did_advance}")
|
||||
logger.info(f" - Level after manual advancement attempt: {new_level_manual_adv}")
|
||||
else:
|
||||
logger.info(f"Skipping manual advancement simulation as current level {initial_level_manual_adv} is already max level {max_level_possible}.")
|
||||
|
||||
# Try to advance difficulty
|
||||
did_advance = env.curriculum.advance_difficulty()
|
||||
new_level = env.curriculum.get_current_level()
|
||||
|
||||
logger.info(f"Curriculum advancement test:")
|
||||
logger.info(f" - Starting level: {initial_level}")
|
||||
logger.info(f" - Recorded {env.config.min_evaluations} correct answers")
|
||||
logger.info(f" - Did advance: {did_advance}")
|
||||
logger.info(f" - New level: {new_level}")
|
||||
|
||||
logger.info("Test server completed successfully")
|
||||
logger.info("InfiniteMath local runner completed successfully.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue