mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
fix some bugs
This commit is contained in:
parent
830a129655
commit
bcfbd647e3
2 changed files with 96 additions and 31 deletions
|
|
def extract_all_boxed(latex_str: str) -> List[str]:
    """
    Extract all \\boxed{} contents from a LaTeX string.

    Handles arbitrarily nested braces by counting brace depth rather than
    using a regular expression (a regex can only match a fixed nesting
    level, so answers such as nested fractions would be missed).

    Args:
        latex_str: LaTeX string

    Returns:
        List of contents from all \\boxed{} occurrences, in order of
        appearance, each stripped of surrounding whitespace. A \\boxed{
        whose closing brace is never found is ignored. A \\boxed{} nested
        inside another one is returned only as part of the outer content.
    """
    if not latex_str:
        return []

    marker = "\\boxed{"
    length = len(latex_str)
    results: List[str] = []
    i = 0

    while i < length:
        # Find next \boxed{
        pos = latex_str.find(marker, i)
        if pos == -1:
            break

        # Content starts right after the opening brace of \boxed{
        start = pos + len(marker)
        depth = 1
        j = start

        # Count braces to find the matching closing brace
        while j < length and depth > 0:
            ch = latex_str[j]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
            j += 1

        if depth == 0:
            # j is one past the matching '}', so the content ends at j - 1
            results.append(latex_str[start:j - 1].strip())

        # Resume after the consumed span; j > pos, so the loop always advances
        i = j

    return results
|
||||
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ Supports thinking mode with <think></think> tags for extended reasoning.
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
|
@ -54,6 +55,7 @@ from atroposlib.envs.base import (
|
|||
APIServerConfig,
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
EvalHandlingEnum,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -136,7 +138,7 @@ class PHYBenchEvalConfig(BaseEnvConfig):
|
|||
|
||||
# Thinking mode configuration
|
||||
thinking_mode: bool = Field(
|
||||
default=True,
|
||||
default=False,
|
||||
description="Whether to use thinking mode with <think></think> tags",
|
||||
)
|
||||
custom_thinking_prompt: Optional[str] = Field(
|
||||
|
|
@ -174,16 +176,6 @@ class PHYBenchEvalConfig(BaseEnvConfig):
|
|||
description="Enable full debug output",
|
||||
)
|
||||
|
||||
# Override defaults for eval-only environment
|
||||
group_size: int = 1
|
||||
max_num_workers: int = 1024
|
||||
max_eval_workers: int = 256
|
||||
max_num_workers_per_node: int = 128
|
||||
use_wandb: bool = True
|
||||
rollout_server_url: str = "http://localhost:8000"
|
||||
total_steps: int = 1
|
||||
wandb_name: str = "phybench_eval"
|
||||
steps_per_eval: int = 1
|
||||
|
||||
|
||||
class PHYBenchEvalEnv(BaseEnv):
|
||||
|
|
@ -195,15 +187,16 @@ class PHYBenchEvalEnv(BaseEnv):
|
|||
"""
|
||||
|
||||
name = "phybench_eval"
|
||||
env_config_cls = PHYBenchEvalConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PHYBenchEvalConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm_job_id: Optional[str] = None,
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm_job_id, testing)
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.config: PHYBenchEvalConfig = config
|
||||
self.eval_items: List[Dict] = []
|
||||
self._dataset_loaded = False
|
||||
|
|
@ -219,13 +212,45 @@ class PHYBenchEvalEnv(BaseEnv):
|
|||
)
|
||||
|
||||
@classmethod
def config_cls(cls) -> type:
    """Return the configuration class used by this environment."""
    return PHYBenchEvalConfig

@classmethod
def config_init(cls) -> Tuple[PHYBenchEvalConfig, List[APIServerConfig]]:
    """Initialize default configuration for the environment.

    Returns:
        A tuple ``(env_config, server_configs)`` holding the default
        PHYBench evaluation config and a single-entry list of API server
        configs.
    """
    env_config = PHYBenchEvalConfig(
        tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
        group_size=1,
        use_wandb=True,
        max_num_workers_per_node=128,
        rollout_server_url="http://localhost:8000",
        total_steps=1,
        batch_size=1,
        steps_per_eval=1,
        inference_weight=1.0,
        wandb_name="phybench_eval",
        eval_handling=EvalHandlingEnum.STOP_TRAIN,
        max_eval_workers=256,
        max_num_workers=1024,
        # PHYBench specific defaults
        dataset_name="Eureka-Lab/PHYBench",
        eval_split="train",
        eval_temperature=0.6,
        eval_max_tokens=0,  # Use model default
        thinking_mode=False,
        compute_eed_score=True,
    )

    server_configs = [
        APIServerConfig(
            model_name="gpt-4.1",
            base_url="https://api.openai.com/v1",
            # Falls back to the placeholder "none" when the variable is unset
            api_key=os.getenv("OPENAI_API_KEY", "none"),
            num_max_requests_at_once=32,
            num_requests_for_eval=1024,
        ),
    ]

    return env_config, server_configs
|
||||
|
||||
async def setup(self) -> None:
|
||||
"""Initialize the environment and load the dataset."""
|
||||
await super().setup()
|
||||
|
||||
if not self._dataset_loaded:
|
||||
await self._load_dataset()
|
||||
|
||||
|
|
@ -262,9 +287,10 @@ class PHYBenchEvalEnv(BaseEnv):
|
|||
|
||||
split_data = dataset[self.config.eval_split]
|
||||
|
||||
# Process items
|
||||
# Process items (deduplicate by content - dataset has duplicates)
|
||||
self.eval_items = []
|
||||
tag_counts: Dict[str, int] = {}
|
||||
seen_content: set = set()
|
||||
|
||||
for item in split_data:
|
||||
problem_id = item.get("id", "")
|
||||
|
|
@ -277,6 +303,11 @@ class PHYBenchEvalEnv(BaseEnv):
|
|||
if not content or not answer:
|
||||
continue
|
||||
|
||||
# Skip duplicates (dataset contains each question twice)
|
||||
if content in seen_content:
|
||||
continue
|
||||
seen_content.add(content)
|
||||
|
||||
# Apply tag filter if specified
|
||||
if self.config.tags_filter and tag not in self.config.tags_filter:
|
||||
continue
|
||||
|
|
@ -453,12 +484,10 @@ class PHYBenchEvalEnv(BaseEnv):
|
|||
|
||||
return result
|
||||
|
||||
async def rollout_and_score_eval(
|
||||
self,
|
||||
item: Dict,
|
||||
server: APIServerConfig,
|
||||
) -> Optional[Dict]:
|
||||
async def rollout_and_score_eval(self, item: Dict) -> Optional[Dict]:
|
||||
"""Run evaluation on a single item and return the result."""
|
||||
if self.config.full_debug:
|
||||
print(f"[DEBUG] Starting eval for item: {item.get('id', 'unknown')}", flush=True)
|
||||
prompt = self._format_prompt(item)
|
||||
system_content = self._create_system_content()
|
||||
|
||||
|
|
@ -469,9 +498,10 @@ class PHYBenchEvalEnv(BaseEnv):
|
|||
|
||||
# Build API call parameters
|
||||
kwargs = {
|
||||
"model": server.model_name,
|
||||
"messages": messages,
|
||||
"n": 1,
|
||||
"temperature": self.config.eval_temperature,
|
||||
"split": "eval",
|
||||
}
|
||||
if self.config.eval_max_tokens > 0:
|
||||
kwargs["max_tokens"] = self.config.eval_max_tokens
|
||||
|
|
@ -479,6 +509,11 @@ class PHYBenchEvalEnv(BaseEnv):
|
|||
response_text = ""
|
||||
for attempt in range(self.config.max_retries):
|
||||
try:
|
||||
if self.config.full_debug:
|
||||
print(f" Making API request (attempt {attempt + 1}/{self.config.max_retries})...", flush=True)
|
||||
print(f" Temperature: {self.config.eval_temperature}", flush=True)
|
||||
print(f" Max tokens: {self.config.eval_max_tokens if self.config.eval_max_tokens > 0 else 'model default'}", flush=True)
|
||||
|
||||
response = await self.server.chat_completion(**kwargs)
|
||||
response_text = response.choices[0].message.content or ""
|
||||
|
||||
|
|
@ -560,13 +595,12 @@ class PHYBenchEvalEnv(BaseEnv):
|
|||
print(f"{'='*60}\n")
|
||||
|
||||
# Create evaluation tasks
|
||||
async def eval_task(item):
|
||||
return await self.rollout_and_score_eval(item, self.server_configs[0])
|
||||
|
||||
tasks = [eval_task(item) for item in self.eval_items]
|
||||
eval_tasks = [
|
||||
self.rollout_and_score_eval(item) for item in self.eval_items
|
||||
]
|
||||
|
||||
# Run with progress bar
|
||||
results = await tqdm_asyncio.gather(*tasks, desc="Evaluating PHYBench")
|
||||
results = await tqdm_asyncio.gather(*eval_tasks, desc="Evaluating PHYBench")
|
||||
|
||||
# Filter out failed results
|
||||
valid_results = [r for r in results if r is not None]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue