mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-29 17:35:07 +00:00
word problem generation working
This commit is contained in:
parent
5c0c7f5b10
commit
fd5b87011d
2 changed files with 180 additions and 38 deletions
|
|
@ -43,6 +43,10 @@ async def main():
|
|||
length_threshold_ratio=0.6,
|
||||
temperature=0.3,
|
||||
top_p=0.9,
|
||||
word_problem_model_name="gpt-4.1-mini",
|
||||
word_problem_openai_api_key=os.getenv("OPENAI_API_KEY_WORD_PROBLEM")
|
||||
or os.getenv("OPENAI_API_KEY"),
|
||||
word_problem_openai_base_url=os.getenv("OPENAI_BASE_URL_WORD_PROBLEM"),
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
|
|
@ -53,12 +57,11 @@ async def main():
|
|||
timeout=600,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
logger.info("Using hardcoded debug configuration.")
|
||||
logger.debug(f"Env Config: {config}")
|
||||
logger.debug(f"Server Configs: {server_configs}")
|
||||
|
||||
|
||||
try:
|
||||
env = InfiniteMathEnv(
|
||||
config=config,
|
||||
|
|
@ -77,19 +80,23 @@ async def main():
|
|||
item = await env.get_next_item()
|
||||
problem_prompt, solution, generator_id = item
|
||||
|
||||
problem_content = dict(problem_prompt[0])['content']
|
||||
logger.info(f"Problem (ID: {generator_id}, Level: {env.curriculum.get_current_level()}): {problem_content}")
|
||||
problem_content = dict(problem_prompt[0])["content"]
|
||||
logger.info(
|
||||
f"Problem (ID: {generator_id}, Level: {env.curriculum.get_current_level()}): {problem_content}"
|
||||
)
|
||||
logger.info(f"Expected Solution: {solution}")
|
||||
|
||||
logger.info("Collecting trajectories...")
|
||||
trajectories_data, backlog = await env.collect_trajectories(item)
|
||||
|
||||
|
||||
if not trajectories_data:
|
||||
logger.error("No trajectories were collected.")
|
||||
return
|
||||
|
||||
logger.info(f"Collected {len(trajectories_data)} data points for scoring (should be 1 for group_size=1).")
|
||||
|
||||
logger.info(
|
||||
f"Collected {len(trajectories_data)} data points for scoring (should be 1 for group_size=1)."
|
||||
)
|
||||
|
||||
logger.info("Scoring trajectories...")
|
||||
scored_data = await env.score(trajectories_data)
|
||||
|
||||
|
|
@ -99,7 +106,7 @@ async def main():
|
|||
assistant_response = ""
|
||||
if messages_list and messages_list[-1].get("role") == "assistant":
|
||||
assistant_response = messages_list[-1].get("content", "N/A")
|
||||
|
||||
|
||||
logger.info(f"--- Attempt {i+1} ---")
|
||||
logger.info(f"Problem: {problem_content}")
|
||||
logger.info(f"Full Assistant Response:\\n{assistant_response}")
|
||||
|
|
@ -107,21 +114,22 @@ async def main():
|
|||
is_correct_task = env.check_answer(assistant_response, solution)
|
||||
logger.info(f"Checked Correct by env.check_answer: {is_correct_task}")
|
||||
|
||||
|
||||
correct_count_buffer = sum(env.percent_correct_buffer)
|
||||
total_attempts_buffer = len(env.percent_correct_buffer)
|
||||
|
||||
|
||||
logger.info("\n--- Overall for this run ---")
|
||||
logger.info(f"Expected Solution: {solution}")
|
||||
logger.info(f"Score(s) from env.score: {scored_data['scores']}")
|
||||
if total_attempts_buffer > 0:
|
||||
logger.info(f"Correct based on internal buffer: {correct_count_buffer}/{total_attempts_buffer}")
|
||||
logger.info(
|
||||
f"Correct based on internal buffer: {correct_count_buffer}/{total_attempts_buffer}"
|
||||
)
|
||||
else:
|
||||
logger.info("No attempts recorded in percent_correct_buffer.")
|
||||
|
||||
else:
|
||||
logger.error("Scored data is missing expected fields ('messages' or 'scores').")
|
||||
|
||||
|
||||
logger.info("=======================================")
|
||||
|
||||
# Re-add curriculum and evaluation testing
|
||||
|
|
@ -148,10 +156,14 @@ async def main():
|
|||
# Check if the level advanced
|
||||
new_level_eval = env.curriculum.get_current_level()
|
||||
if new_level_eval > initial_level_eval:
|
||||
logger.info(f"Successfully advanced from level {initial_level_eval} to level {new_level_eval} during evaluation!")
|
||||
logger.info(
|
||||
f"Successfully advanced from level {initial_level_eval} to level {new_level_eval} during evaluation!"
|
||||
)
|
||||
logger.info(f"New level description: {env.curriculum.get_level_description()}")
|
||||
else:
|
||||
logger.info(f"Did not advance during evaluation. Remained at level {initial_level_eval}.")
|
||||
logger.info(
|
||||
f"Did not advance during evaluation. Remained at level {initial_level_eval}."
|
||||
)
|
||||
# Show current progress toward advancement
|
||||
current_level_desc = env.curriculum.get_current_level()
|
||||
if current_level_desc in env.curriculum.performance_history:
|
||||
|
|
@ -182,20 +194,24 @@ async def main():
|
|||
# Test curriculum advancement with simulated performance history
|
||||
logger.info("\n=== Testing Curriculum Advancement Manually ===")
|
||||
initial_level_manual_adv = env.curriculum.get_current_level()
|
||||
logger.info(f"Starting manual advancement test from level: {initial_level_manual_adv}")
|
||||
logger.info(
|
||||
f"Starting manual advancement test from level: {initial_level_manual_adv}"
|
||||
)
|
||||
|
||||
# Simulate good performance at current level
|
||||
# Ensure we don't try to get items if curriculum is already at max level from previous eval
|
||||
max_level_possible = max(env.curriculum.DIFFICULTY_LEVELS.keys())
|
||||
if initial_level_manual_adv < max_level_possible:
|
||||
logger.info(f"Simulating {config.min_evaluations} correct answers for level {initial_level_manual_adv}...")
|
||||
for _ in range(config.min_evaluations): # Use config for min_evaluations
|
||||
logger.info(
|
||||
f"Simulating {config.min_evaluations} correct answers for level {initial_level_manual_adv}..."
|
||||
)
|
||||
for _ in range(config.min_evaluations): # Use config for min_evaluations
|
||||
# Get a problem from current level to ensure generator_id is valid for the level
|
||||
# The level might have changed due to the previous env.evaluate() call
|
||||
problem_item_adv_test = await env.get_next_item()
|
||||
problem_item_adv_test = await env.get_next_item()
|
||||
_, _, generator_id_adv_test = problem_item_adv_test
|
||||
env.curriculum.record_performance(generator_id_adv_test, True)
|
||||
|
||||
|
||||
# Try to advance difficulty
|
||||
did_advance = env.curriculum.advance_difficulty()
|
||||
new_level_manual_adv = env.curriculum.get_current_level()
|
||||
|
|
@ -204,9 +220,13 @@ async def main():
|
|||
logger.info(f" - Level before manual simulation: {initial_level_manual_adv}")
|
||||
logger.info(f" - Recorded {config.min_evaluations} correct answers manually.")
|
||||
logger.info(f" - Did advance: {did_advance}")
|
||||
logger.info(f" - Level after manual advancement attempt: {new_level_manual_adv}")
|
||||
logger.info(
|
||||
f" - Level after manual advancement attempt: {new_level_manual_adv}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Skipping manual advancement simulation as current level {initial_level_manual_adv} is already max level {max_level_possible}.")
|
||||
logger.info(
|
||||
f"Skipping manual advancement simulation as current level {initial_level_manual_adv} is already max level {max_level_possible}."
|
||||
)
|
||||
|
||||
logger.info("InfiniteMath local runner completed successfully.")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue