Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-19 12:57:58 +00:00
tidying up comments and methods
parent 137f8381ec
commit 04b32fd8f3
1 changed file with 16 additions and 112 deletions
@@ -45,25 +45,21 @@ Remember to format your final answer correctly as this is important for evaluati
 class InfiniteMathEnvConfig(BaseEnvConfig):
     """Configuration for the InfiniteMath environment."""
 
-    # Curriculum parameters
     starting_level: int = 1
     progress_threshold: float = 0.8
     min_evaluations: int = 5
 
-    # Environment parameters
     max_attempts_per_problem: int = 3
     correct_reward: float = 1.0
     incorrect_reward: float = -1.0
-    think_block_bonus: float = 0.2  # Bonus for a well-formed think block
-    boxed_answer_bonus: float = 0.2  # Bonus for a well-formed boxed answer
+    think_block_bonus: float = 0.2
+    boxed_answer_bonus: float = 0.2
 
-    # Length penalty parameters
     apply_length_penalty: bool = True
     length_threshold_ratio: float = (
-        0.5  # Percentage of max_token_length before penalties apply
+        0.5
     )
 
-    # Completion parameters
     temperature: float = 0.7
     top_p: float = 0.9
 
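Note: the four reward fields above are combined additively by the score method later in this diff. A minimal standalone sketch of that composition (composed_score is an illustrative name, and the default values are inlined rather than read from a real InfiniteMathEnvConfig):

# Sketch only: how the config fields above combine into a final score.
CORRECT_REWARD = 1.0
INCORRECT_REWARD = -1.0
THINK_BLOCK_BONUS = 0.2
BOXED_ANSWER_BONUS = 0.2

def composed_score(is_correct: bool, has_think: bool, has_boxed: bool) -> float:
    # Base reward for correctness, then additive format bonuses.
    score = CORRECT_REWARD if is_correct else INCORRECT_REWARD
    if has_think:
        score += THINK_BLOCK_BONUS
    if has_boxed:
        score += BOXED_ANSWER_BONUS
    return score

assert composed_score(True, True, True) == 1.4  # best case
assert composed_score(False, False, False) == -1.0  # worst case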
@@ -79,40 +75,34 @@ class InfiniteMathEnv(BaseEnv):
         testing=False,
     ):
         super().__init__(config, server_configs, slurm, testing)
-        self.config = config  # Override with our specific config class
+        self.config = config
 
-        # Initialize tracking metrics
         self.percent_correct_buffer = []
         self.level_correct_buffer = {
             i: [] for i in range(1, 8)
-        }  # Track correctness for each level
+        }
         self.eval_metrics = []
 
-        # Curriculum will be initialized in setup()
         self.curriculum = None
 
-        # Set the system prompt
         self.system_prompt = system_prompt
 
     async def setup(self):
         """Initialize the environment and curriculum."""
         logger.info("Setting up InfiniteMathEnv")
 
-        # Initialize curriculum
         self.curriculum = MathCurriculum(
             starting_level=self.config.starting_level,
             progress_threshold=self.config.progress_threshold,
             min_evaluations=self.config.min_evaluations,
         )
 
-        # Generate some test problems for each level for evaluation
         self.eval_problems = {}
         for level in range(1, 8):
             self.eval_problems[level] = []
             temp_curriculum = MathCurriculum(starting_level=level)
-            # Generate 10 test problems for each level
             attempts = 0
-            max_attempts_per_level = 20  # Try at most 20 problems to get 10 valid ones
+            max_attempts_per_level = 20
 
             while (
                 len(self.eval_problems[level]) < 10
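MathCurriculum itself is not part of this diff, so the exact meaning of progress_threshold and min_evaluations is assumed here: advance once at least min_evaluations results at the current level average at or above the threshold. A hedged sketch under that assumption (TinyCurriculum is illustrative and keyed by level rather than generator id):

from collections import defaultdict

class TinyCurriculum:
    # Assumed contract: advance once the rolling success rate at the
    # current level clears progress_threshold, with at least
    # min_evaluations data points.
    def __init__(self, starting_level=1, progress_threshold=0.8, min_evaluations=5):
        self.current_level = starting_level
        self.progress_threshold = progress_threshold
        self.min_evaluations = min_evaluations
        self.performance_history = defaultdict(list)  # level -> [1/0, ...]

    def record_performance(self, level: int, is_correct: bool) -> None:
        self.performance_history[level].append(1 if is_correct else 0)

    def advance_difficulty(self) -> bool:
        history = self.performance_history[self.current_level]
        if len(history) < self.min_evaluations:
            return False  # not enough evidence yet
        if sum(history) / len(history) >= self.progress_threshold:
            self.current_level += 1
            return True
        return False

cur = TinyCurriculum()
for _ in range(5):
    cur.record_performance(cur.current_level, True)
assert cur.advance_difficulty() and cur.current_level == 2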
@@ -120,9 +110,8 @@ class InfiniteMathEnv(BaseEnv):
             ):
                 try:
                     problem, solution, generator_id = temp_curriculum.get_problem()
-                    # Strip LaTeX delimiters
-                    problem = self.strip_latex_delimiters(problem)
-                    solution = self.strip_latex_delimiters(solution)
+                    problem = self._strip_latex_delimiters(problem)
+                    solution = self._strip_latex_delimiters(solution)
                     self.eval_problems[level].append((problem, solution, generator_id))
                 except Exception as e:
                     logger.warning(
@@ -134,7 +123,6 @@ class InfiniteMathEnv(BaseEnv):
                 f"Generated {len(self.eval_problems[level])} evaluation problems for level {level}"
             )
 
-        # If any levels have no problems, add a simple fallback
         for level in range(1, 8):
             if not self.eval_problems[level]:
                 logger.warning(
@@ -169,9 +157,8 @@ class InfiniteMathEnv(BaseEnv):
             else:
                 self.eval_problems[level].append(("What is |3 - 10|?", "7", 71))
 
-    def strip_latex_delimiters(self, text: str) -> str:
+    def _strip_latex_delimiters(self, text: str) -> str:
         """Strip LaTeX delimiters ($...$) from text."""
-        # Handle both inline expressions $...$ and expressions that make up the entire string
         return re.sub(r"\$(.*?)\$", r"\1", text)
 
     def save_checkpoint(self, step, data=None):
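For reference, the non-greedy \$(.*?)\$ pattern in the renamed helper unwraps each inline math span independently; a quick standalone check of that behavior:

import re

def strip_latex_delimiters(text: str) -> str:
    # Same one-liner as the method above: unwrap every $...$ span,
    # non-greedily, so "$a$ and $b$" becomes "a and b" rather than "a$ and $b".
    return re.sub(r"\$(.*?)\$", r"\1", text)

assert strip_latex_delimiters("Solve $x + 2 = 5$.") == "Solve x + 2 = 5."
assert strip_latex_delimiters("$a$ and $b$") == "a and b"
# An unpaired delimiter is left untouched:
assert strip_latex_delimiters("cost is $5") == "cost is $5"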
@@ -179,7 +166,6 @@ class InfiniteMathEnv(BaseEnv):
         if data is None:
             data = {}
 
-        # Save curriculum state
         data["curriculum_level"] = self.curriculum.get_current_level()
         data["performance_history"] = {
             str(k): v for k, v in self.curriculum.performance_history.items()
@@ -191,19 +177,16 @@ class InfiniteMathEnv(BaseEnv):
         """Load curriculum state from checkpoint."""
         super().load_checkpoint()
 
-        # Check if we have curriculum data in the checkpoint
         checkpoint_path = f"{self.checkpoint_dir}/env_checkpoints/{self.wandb_prepend}/step-{self.curr_step}.json"
         try:
             with open(checkpoint_path, "r") as f:
                 data = json.load(f)
 
-            # Restore curriculum state if available
             if "curriculum_level" in data:
                 level = data["curriculum_level"]
                 self.curriculum.current_level = level
 
             if "performance_history" in data:
-                # Convert string keys back to integers
                 self.curriculum.performance_history = {
                     int(k): v for k, v in data["performance_history"].items()
                 }
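The int(k) conversion above exists because JSON objects only allow string keys, so integer curriculum levels come back as strings after a save/load round trip; a minimal demonstration:

import json

# JSON stringifies dict keys, so integer levels come back as strings.
history = {1: [1, 0, 1], 2: [0]}
restored = json.loads(json.dumps(history))
assert list(restored) == ["1", "2"]

# Converting keys back, as load_checkpoint does:
restored = {int(k): v for k, v in restored.items()}
assert restored == history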
@@ -215,7 +198,6 @@ class InfiniteMathEnv(BaseEnv):
         if wandb_metrics is None:
             wandb_metrics = {}
 
-        # Log overall correct percentage
         try:
             wandb_metrics["train/percent_correct"] = sum(
                 self.percent_correct_buffer
@@ -223,7 +205,6 @@ class InfiniteMathEnv(BaseEnv):
         except ZeroDivisionError:
             pass
 
-        # Log per-level metrics
         for level, buffer in self.level_correct_buffer.items():
             if buffer:
                 wandb_metrics[f"train/level_{level}_correct"] = sum(buffer) / len(
@@ -231,7 +212,6 @@ class InfiniteMathEnv(BaseEnv):
                 )
                 wandb_metrics[f"train/level_{level}_count"] = len(buffer)
 
-        # Log current level and curriculum progress
         if self.curriculum:
             current_level = self.curriculum.get_current_level()
             max_level = max(self.curriculum.DIFFICULTY_LEVELS.keys())
@@ -242,12 +222,10 @@ class InfiniteMathEnv(BaseEnv):
                 current_level / max_level
             ) * 100
 
-            # Log level description
             wandb_metrics["curriculum/level_description"] = (
                 self.curriculum.get_level_description()
             )
 
-            # Log performance history for current level
             if current_level in self.curriculum.performance_history:
                 history = self.curriculum.performance_history[current_level]
                 if history:
@@ -266,49 +244,25 @@ class InfiniteMathEnv(BaseEnv):
                         0, self.curriculum.progress_threshold - success_rate
                     )
 
-        # Log reward function metrics
-        # REMOVED: Specific reward function config logging as it's not used anymore
-        # if hasattr(self, "reward_function") and self.wandb:
-        #     if hasattr(self.reward_function, "set_wandb_logger"):
-        #         self.reward_function.set_wandb_logger(self.wandb)
-
-        #     # Log the reward configurations
-        #     if isinstance(self.config.reward_functions, list) and self.config.reward_functions:
-        #         # Log the reward configuration
-        #         wandb_metrics["reward/format_reward_enabled"] = "format" in self.config.reward_functions
-        #         wandb_metrics["reward/boxed_reward_enabled"] = "boxed" in self.config.reward_functions
-
-        #     if hasattr(self.config, "format_reward_weight"):
-        #         wandb_metrics["reward/format_reward_weight"] = self.config.format_reward_weight
-
-        #     if hasattr(self.config, "boxed_reward_weight"):
-        #         wandb_metrics["reward/boxed_reward_weight"] = self.config.boxed_reward_weight
-
-        # Add eval metrics
         for item in self.eval_metrics:
             wandb_metrics[item[0]] = item[1]
 
-        # Reset buffers
         self.percent_correct_buffer = []
         for level in self.level_correct_buffer:
             self.level_correct_buffer[level] = []
         self.eval_metrics = []
 
-        # Call the parent method to handle remaining metrics
         await super().wandb_log(wandb_metrics)
 
     async def get_next_item(self):
         """Get the next problem based on current curriculum level."""
         problem, solution, generator_id = self.curriculum.get_problem()
 
-        # Strip LaTeX delimiters from problem and solution
-        problem = self.strip_latex_delimiters(problem)
-        solution = self.strip_latex_delimiters(solution)
+        problem = self._strip_latex_delimiters(problem)
+        solution = self._strip_latex_delimiters(solution)
 
-        # Create a message with the problem
         prompt = tuple([frozenset({"role": "user", "content": problem}.items())])
 
-        # Return the problem with metadata
         return (prompt, solution, generator_id)
 
     async def evaluate(self, *args, **kwargs):
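get_next_item packs the user message as a tuple of frozensets so the prompt is hashable (a plain dict is not); collect_trajectories later unpacks it with dict(problem_prompt[0]). A standalone illustration of the round trip:

problem = "What is 2 + 2?"

# Pack: dict -> frozenset of (key, value) pairs, wrapped in a tuple.
prompt = tuple([frozenset({"role": "user", "content": problem}.items())])
assert isinstance(hash(prompt), int)  # hashable, unlike a dict

# Unpack, as collect_trajectories does via dict(problem_prompt[0]).
message = dict(prompt[0])
assert message == {"role": "user", "content": problem}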
@@ -316,7 +270,6 @@ class InfiniteMathEnv(BaseEnv):
         current_level = self.curriculum.get_current_level()
         logger.info(f"Starting evaluation for curriculum level {current_level}")
 
-        # Only evaluate problems at the current level
         eval_tasks = []
         eval_generator_ids = []
         if current_level in self.eval_problems:
@@ -332,11 +285,9 @@ class InfiniteMathEnv(BaseEnv):
             )
             return []
 
-        # Run evaluation tasks
         logger.info(f"Evaluating {len(eval_tasks)} problems at level {current_level}")
         results = await asyncio.gather(*eval_tasks)
 
-        # Calculate accuracy for the current level
         correct_count = sum(1 for _, is_correct in results if is_correct)
         total_count = len(results)
         accuracy = correct_count / total_count if total_count > 0 else 0
@@ -345,23 +296,18 @@ class InfiniteMathEnv(BaseEnv):
             f"Level {current_level} accuracy: {accuracy:.2f} ({correct_count}/{total_count})"
         )
 
-        # Record metrics for the current level
         self.eval_metrics.append((f"eval/level_{current_level}_accuracy", accuracy))
         self.eval_metrics.append(("eval/current_level", current_level))
 
-        # Record the actual evaluation results in the curriculum's performance history
         for i, (_, is_correct) in enumerate(results):
             if i < len(eval_generator_ids):
-                # Record the actual result
                 self.curriculum.record_performance(eval_generator_ids[i], is_correct)
             else:
-                # Fallback if somehow the lists are different lengths
                 sample_generator_id = random.choice(
                     self.curriculum.DIFFICULTY_LEVELS[current_level]
                 )
                 self.curriculum.record_performance(sample_generator_id, is_correct)
 
-        # Try to advance to the next level
         advanced = self.curriculum.advance_difficulty()
         new_level = self.curriculum.get_current_level()
 
@@ -381,74 +327,62 @@ class InfiniteMathEnv(BaseEnv):
         try:
             logger.debug(f"Evaluating level {level} problem: {problem[:30]}...")
 
-            # Convert messages to a single prompt using the tokenizer
             messages = [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": problem},
             ]
             prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
 
-            # Add prefilled thinking starter
             prefill = "\n<think>\n"
             prefilled_prompt = prompt + prefill
 
-            # Generate completion using the prompt
             logger.debug(f"Requesting completion for problem: {problem[:30]}...")
             completion = await self.server.completion(
                 prompt=prefilled_prompt,
                 n=1,
                 max_tokens=self.config.max_token_length,
-                temperature=0.0,  # Use 0 temperature for deterministic results
+                temperature=0.0,
                 top_p=1.0,
                 split="eval",
             )
 
-            # Extract the completion text and prepend the thinking starter
             model_answer = prefill + (
                 completion.choices[0].text
                 if hasattr(completion.choices[0], "text")
                 else completion.choices[0].message.content
             )
 
-            # Check if the answer is correct
             is_correct = self.check_answer(model_answer, solution)
             logger.debug(f"Problem evaluated: level={level}, correct={is_correct}")
 
             return level, is_correct
         except Exception as e:
             logger.error(f"Error evaluating problem: {e}")
-            # Return a failed result in case of error
             return level, False
 
     def check_answer(self, model_answer: str, solution: str) -> bool:
         """Check if the model's answer matches the solution."""
-        # Extract the part after the thinking block
         after_think_part = (
             model_answer.split("</think>")[-1].strip()
             if "</think>" in model_answer
             else model_answer
         )
 
-        # Extract the boxed answer if present
         boxed_answer = self._extract_boxed_answer(after_think_part)
         if not boxed_answer:
-            # Try to find the answer in the last line
            lines = after_think_part.strip().split("\n")
            if lines:
                boxed_answer = lines[-1].strip()
 
-        # Clean up answers for comparison (remove spaces, convert to lowercase)
         model_clean = self._clean_for_comparison(
             boxed_answer if boxed_answer else after_think_part
         )
         solution_clean = self._clean_for_comparison(solution)
 
-        # Check if they match
         return model_clean == solution_clean
 
     def _extract_boxed_answer(self, text: str) -> Optional[str]:
         """Extract answer from a LaTeX boxed expression."""
-        # Try to find boxed content
         boxed_match = re.search(r"\\boxed{([^}]*)}", text)
         if boxed_match:
             return boxed_match.group(1)
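One caveat on _extract_boxed_answer: the [^}]* pattern stops at the first closing brace, so nested braces inside \boxed{...} are truncated. A standalone demonstration (extract_boxed_answer mirrors the method above):

import re
from typing import Optional

def extract_boxed_answer(text: str) -> Optional[str]:
    # Same pattern as the method above: capture up to the first "}".
    boxed_match = re.search(r"\\boxed{([^}]*)}", text)
    return boxed_match.group(1) if boxed_match else None

assert extract_boxed_answer(r"The answer is \boxed{42}.") == "42"
# Nested braces are cut short at the inner "}":
assert extract_boxed_answer(r"\boxed{\frac{1}{2}}") == r"\frac{1"
assert extract_boxed_answer("no box here") is None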
@@ -456,7 +390,6 @@ class InfiniteMathEnv(BaseEnv):
 
     def _clean_for_comparison(self, text: str) -> str:
         """Clean text for comparison."""
-        # Remove LaTeX commands, spaces, commas, and convert to lowercase
         cleaned = re.sub(r"\\[a-zA-Z]+", "", text)
         cleaned = re.sub(r"[,\s]", "", cleaned)
         cleaned = cleaned.lower()
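Together with check_answer, this normalization makes the string comparison tolerant of LaTeX commands, commas, whitespace, and case; a standalone check (clean_for_comparison mirrors the method above):

import re

def clean_for_comparison(text: str) -> str:
    # Mirrors the method above: drop \commands, then commas and
    # whitespace, then lowercase.
    cleaned = re.sub(r"\\[a-zA-Z]+", "", text)
    cleaned = re.sub(r"[,\s]", "", cleaned)
    return cleaned.lower()

assert clean_for_comparison(r"\approx 1,000") == clean_for_comparison("1000")
assert clean_for_comparison("X + Y") == clean_for_comparison("x+y")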
@@ -464,11 +397,8 @@ class InfiniteMathEnv(BaseEnv):
 
     async def collect_trajectories(self, item) -> Tuple[List, List]:
         """Collect trajectories for the current item."""
-        # Extract information from the item
         problem_prompt, solution, generator_id = item
 
-        # Create prompt using tokenizer's chat template
-        # Add prefilled thinking starter
         prefill = "\n<think>\n"
         messages = [
             {"role": "system", "content": self.system_prompt},
@@ -477,7 +407,6 @@ class InfiniteMathEnv(BaseEnv):
         ]
 
         prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
-        # Generate completions using completion API
         completions = await self.server.completion(
             prompt=prompt,
             n=self.config.group_size,
@@ -486,19 +415,15 @@ class InfiniteMathEnv(BaseEnv):
             top_p=self.config.top_p,
         )
 
-        # Prepare data for scoring
         to_score = []
 
-        # Track level for metrics
         level = None
         for lvl, generator_ids in self.curriculum.DIFFICULTY_LEVELS.items():
             if generator_id in generator_ids:
                 level = lvl
                 break
 
-        # Process each completion
         for i, completion in enumerate(completions.choices):
-            # Get the completion text and prepend the thinking starter
             model_answer = prefill + (
                 completion.text
                 if hasattr(completion, "text")
@@ -506,20 +431,14 @@ class InfiniteMathEnv(BaseEnv):
             )
             print("model_answer", model_answer)
 
-            # Build complete message sequence
             full_messages = [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": dict(problem_prompt[0])["content"]},
                 {"role": "assistant", "content": model_answer},
             ]
 
-            # Add to scoring list
             to_score.append((full_messages, solution, generator_id, level))
 
-            # Record performance in curriculum for each item we're scoring
-            # This will be called again after scoring, but that's fine
-
-        # No additional items for backlog
         backlog = []
 
         return to_score, backlog
@@ -532,57 +451,45 @@ class InfiniteMathEnv(BaseEnv):
         scored_data["scores"] = []
         scored_data["messages"] = []
 
-        # Process each item in the rollout data
         for i, (messages, solution, generator_id, level) in enumerate(rollout_group_data):
             model_answer = messages[-1]["content"]
             current_score = 0.0
 
-            # 1. Accuracy Check
             is_correct = self.check_answer(model_answer, solution)
             if is_correct:
                 current_score += self.config.correct_reward
             else:
                 current_score += self.config.incorrect_reward
 
-            # Record answer correctness for tracking and curriculum
             self.percent_correct_buffer.append(1 if is_correct else 0)
             if level is not None:
                 self.level_correct_buffer[level].append(1 if is_correct else 0)
             self.curriculum.record_performance(generator_id, is_correct)
 
-            # 2. Thinking Block Check
             think_match = re.search(r"<think>(.*?)</think>", model_answer, re.DOTALL)
             if think_match:
                 think_content = think_match.group(1).strip()
-                if think_content:  # Check if there's actual content
+                if think_content:
                     current_score += self.config.think_block_bonus
-                # else: penalty for empty think block, or neutral
-            # else: penalty for missing think block, or neutral
 
-            # 3. Boxed Answer Check
-            # Extract the part after the thinking block for boxed answer validation
             after_think_part = model_answer.split("</think>")[-1].strip() if "</think>" in model_answer else model_answer
             boxed_answer_content = self._extract_boxed_answer(after_think_part)
-            if boxed_answer_content is not None:  # Check if \boxed{} is present and has content
+            if boxed_answer_content is not None:
                 current_score += self.config.boxed_answer_bonus
-            # else: penalty for missing/malformed boxed answer, or neutral
 
             logger.info(f"Item {i}: Correct: {is_correct}, Think Bonus: {self.config.think_block_bonus if think_match and think_match.group(1).strip() else 0}, Boxed Bonus: {self.config.boxed_answer_bonus if boxed_answer_content is not None else 0}, Final Score: {current_score}")
 
-            # Tokenize for the trainer
             tokens_dict = tokenize_for_trainer(
                 self.tokenizer,
-                messages,  # These are the full messages including system, user, assistant
-                None,  # Not used by this tokenizer function apparently
+                messages,
+                None,
             )
 
-            # Add to scored data
             scored_data["tokens"].append(tokens_dict["tokens"])
             scored_data["masks"].append(tokens_dict["masks"])
             scored_data["scores"].append(current_score)
             scored_data["messages"].append(messages)
 
-        # Advance difficulty if criteria met
         self.curriculum.advance_difficulty()
 
         return scored_data
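The think-block bonus gate in the hunk above reduces to a small predicate: a <think>...</think> block must exist and contain non-empty text, with re.DOTALL letting it span lines. A standalone sketch (think_bonus_applies is an illustrative name):

import re

def think_bonus_applies(model_answer: str) -> bool:
    # Mirrors the check above: the bonus requires a <think>...</think>
    # block with non-empty content (re.DOTALL lets "." match newlines).
    match = re.search(r"<think>(.*?)</think>", model_answer, re.DOTALL)
    return bool(match and match.group(1).strip())

assert think_bonus_applies("<think>\nwork it out\n</think>\n\\boxed{4}")
assert not think_bonus_applies("<think></think>\\boxed{4}")  # empty block
assert not think_bonus_applies("\\boxed{4}")  # missing block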
@@ -605,7 +512,6 @@ if __name__ == "__main__":
             inference_weight=1.0,
             wandb_name="infinite_math",
             data_path_to_save_groups="data/infinite_math_groups.jsonl",
-            # InfiniteMath specific config
             starting_level=1,
             progress_threshold=0.8,
             min_evaluations=10,
@@ -613,10 +519,8 @@ if __name__ == "__main__":
             incorrect_reward=-0.5,
             apply_length_penalty=True,
             length_threshold_ratio=0.6,
-            # Completion parameters
             temperature=0.7,
             top_p=0.9,
-            # Reward function configuration - use name directly
             reward_functions=["accuracy", "format", "boxed"],
             accuracy_reward_weight=1.0,
             format_reward_weight=0.2,