fix some bugs

This commit is contained in:
teknium 2025-12-28 04:09:34 +00:00
parent 830a129655
commit bcfbd647e3
2 changed files with 96 additions and 31 deletions

View file

@@ -959,12 +959,43 @@ def extract_all_boxed(latex_str: str) -> List[str]:
"""
Extract all \\boxed{} contents from a LaTeX string.
Handles arbitrarily nested braces by counting brace depth.
Args:
latex_str: LaTeX string
Returns:
List of contents from all \\boxed{} occurrences
"""
pattern = r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}"
return re.findall(pattern, latex_str)
results = []
i = 0
boxed_pattern = "\\boxed{"
while i < len(latex_str):
# Find next \boxed{
pos = latex_str.find(boxed_pattern, i)
if pos == -1:
break
# Start after \boxed{
start = pos + len(boxed_pattern)
depth = 1
j = start
# Count braces to find matching closing brace
while j < len(latex_str) and depth > 0:
if latex_str[j] == '{':
depth += 1
elif latex_str[j] == '}':
depth -= 1
j += 1
if depth == 0:
# Extract content between braces
content = latex_str[start:j-1].strip()
results.append(content)
i = j
return results

View file

@@ -32,6 +32,7 @@ Supports thinking mode with <think></think> tags for extended reasoning.
"""
import asyncio
import os
import random
import re
from typing import Dict, List, Optional, Tuple
@@ -54,6 +55,7 @@ from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
)
@@ -136,7 +138,7 @@ class PHYBenchEvalConfig(BaseEnvConfig):
# Thinking mode configuration
thinking_mode: bool = Field(
default=True,
default=False,
description="Whether to use thinking mode with <think></think> tags",
)
custom_thinking_prompt: Optional[str] = Field(
@@ -174,16 +176,6 @@ class PHYBenchEvalConfig(BaseEnvConfig):
description="Enable full debug output",
)
# Override defaults for eval-only environment
group_size: int = 1
max_num_workers: int = 1024
max_eval_workers: int = 256
max_num_workers_per_node: int = 128
use_wandb: bool = True
rollout_server_url: str = "http://localhost:8000"
total_steps: int = 1
wandb_name: str = "phybench_eval"
steps_per_eval: int = 1
class PHYBenchEvalEnv(BaseEnv):
@@ -195,15 +187,16 @@ class PHYBenchEvalEnv(BaseEnv):
"""
name = "phybench_eval"
env_config_cls = PHYBenchEvalConfig
def __init__(
self,
config: PHYBenchEvalConfig,
server_configs: List[APIServerConfig],
slurm_job_id: Optional[str] = None,
slurm: bool = False,
testing: bool = False,
):
super().__init__(config, server_configs, slurm_job_id, testing)
super().__init__(config, server_configs, slurm, testing)
self.config: PHYBenchEvalConfig = config
self.eval_items: List[Dict] = []
self._dataset_loaded = False
@@ -219,13 +212,45 @@ class PHYBenchEvalEnv(BaseEnv):
)
@classmethod
def config_cls(cls) -> type:
return PHYBenchEvalConfig
def config_init(cls) -> Tuple[PHYBenchEvalConfig, List[APIServerConfig]]:
"""Initialize default configuration for the environment.

Returns the (env_config, server_configs) pair used when no explicit
configuration is supplied: an eval-only PHYBench config plus a single
OpenAI-backed API server entry.
"""
# NOTE(review): intended as a @classmethod replacing the old
# config_cls(); the decorator sits on the unchanged context line of
# this diff hunk, just above.
# Eval-only defaults: one rollout per prompt, a single step, and
# evaluation configured to stop training (EvalHandlingEnum.STOP_TRAIN).
# These were previously hard-coded as class attributes on the config
# (removed in the hunk above) and now live here instead.
env_config = PHYBenchEvalConfig(
tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
group_size=1,  # one completion per item — this env only evaluates
use_wandb=True,
max_num_workers_per_node=128,
rollout_server_url="http://localhost:8000",
total_steps=1,
batch_size=1,
steps_per_eval=1,
inference_weight=1.0,
wandb_name="phybench_eval",
eval_handling=EvalHandlingEnum.STOP_TRAIN,
max_eval_workers=256,
max_num_workers=1024,
# PHYBench specific defaults
dataset_name="Eureka-Lab/PHYBench",
eval_split="train",  # eval items appear to ship in the "train" split — TODO confirm
eval_temperature=0.6,
eval_max_tokens=0,  # Use model default
thinking_mode=False,  # plain answers; no <think></think> reasoning scaffold
compute_eed_score=True,
)
# Default inference backend: OpenAI's API. The key is read from the
# environment; "none" is a placeholder when OPENAI_API_KEY is unset.
server_configs = [
APIServerConfig(
model_name="gpt-4.1",
base_url="https://api.openai.com/v1",
api_key=os.getenv("OPENAI_API_KEY", "none"),
num_max_requests_at_once=32,
num_requests_for_eval=1024,
),
]
return env_config, server_configs
async def setup(self) -> None:
"""Initialize the environment and load the dataset."""
await super().setup()
if not self._dataset_loaded:
await self._load_dataset()
@@ -262,9 +287,10 @@ class PHYBenchEvalEnv(BaseEnv):
split_data = dataset[self.config.eval_split]
# Process items
# Process items (deduplicate by content - dataset has duplicates)
self.eval_items = []
tag_counts: Dict[str, int] = {}
seen_content: set = set()
for item in split_data:
problem_id = item.get("id", "")
@@ -277,6 +303,11 @@ class PHYBenchEvalEnv(BaseEnv):
if not content or not answer:
continue
# Skip duplicates (dataset contains each question twice)
if content in seen_content:
continue
seen_content.add(content)
# Apply tag filter if specified
if self.config.tags_filter and tag not in self.config.tags_filter:
continue
@@ -453,12 +484,10 @@ class PHYBenchEvalEnv(BaseEnv):
return result
async def rollout_and_score_eval(
self,
item: Dict,
server: APIServerConfig,
) -> Optional[Dict]:
async def rollout_and_score_eval(self, item: Dict) -> Optional[Dict]:
"""Run evaluation on a single item and return the result."""
if self.config.full_debug:
print(f"[DEBUG] Starting eval for item: {item.get('id', 'unknown')}", flush=True)
prompt = self._format_prompt(item)
system_content = self._create_system_content()
@@ -469,9 +498,10 @@ class PHYBenchEvalEnv(BaseEnv):
# Build API call parameters
kwargs = {
"model": server.model_name,
"messages": messages,
"n": 1,
"temperature": self.config.eval_temperature,
"split": "eval",
}
if self.config.eval_max_tokens > 0:
kwargs["max_tokens"] = self.config.eval_max_tokens
@@ -479,6 +509,11 @@ class PHYBenchEvalEnv(BaseEnv):
response_text = ""
for attempt in range(self.config.max_retries):
try:
if self.config.full_debug:
print(f" Making API request (attempt {attempt + 1}/{self.config.max_retries})...", flush=True)
print(f" Temperature: {self.config.eval_temperature}", flush=True)
print(f" Max tokens: {self.config.eval_max_tokens if self.config.eval_max_tokens > 0 else 'model default'}", flush=True)
response = await self.server.chat_completion(**kwargs)
response_text = response.choices[0].message.content or ""
@@ -560,13 +595,12 @@ class PHYBenchEvalEnv(BaseEnv):
print(f"{'='*60}\n")
# Create evaluation tasks
async def eval_task(item):
return await self.rollout_and_score_eval(item, self.server_configs[0])
tasks = [eval_task(item) for item in self.eval_items]
eval_tasks = [
self.rollout_and_score_eval(item) for item in self.eval_items
]
# Run with progress bar
results = await tqdm_asyncio.gather(*tasks, desc="Evaluating PHYBench")
results = await tqdm_asyncio.gather(*eval_tasks, desc="Evaluating PHYBench")
# Filter out failed results
valid_results = [r for r in results if r is not None]