Mirror of https://github.com/NousResearch/atropos.git
Synced 2026-04-19 12:57:58 +00:00

rework server and eval for rl rollout. add in asyncmanagedserver for verifiers

parent 3449a4c23d
commit cf636595d2

4 changed files with 652 additions and 91 deletions
@@ -18,15 +18,15 @@ Usage:
 """
 
 import asyncio
+import inspect
 import os
 import time
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import verifiers as vf
+import wandb
 from pydantic import Field
 from tqdm.asyncio import tqdm_asyncio
 
-import wandb
 from atroposlib.envs.base import (
     APIServerConfig,
     BaseEnv,
@@ -34,6 +34,36 @@ from atroposlib.envs.base import (
 )
 
 
+# Patch math_verify timeout to work in async context
+# The signal-based timeout doesn't work in non-main threads (asyncio event loop)
+def _no_signal_timeout(timeout_seconds: int):
+    """Replacement timeout decorator that doesn't use signals."""
+
+    def decorator(func):
+        def wrapper(*args, **kwargs):
+            # Just call the function without timeout
+            # This is safe because we're in an async context with our own timeouts
+            # timeout_seconds is intentionally unused - we're replacing the timeout logic
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+try:
+    import math_verify.grader
+    import math_verify.parser
+    import math_verify.utils
+
+    # Patch all modules that use the timeout decorator
+    math_verify.utils.timeout = _no_signal_timeout
+    math_verify.parser.timeout = _no_signal_timeout
+    math_verify.grader.timeout = _no_signal_timeout
+except ImportError:
+    pass  # math_verify not installed
+
+
 class VerifiersEvaluationConfig(BaseEnvConfig):
     """Configuration for Verifiers evaluation environment."""
 
@@ -91,13 +121,30 @@ class VerifiersEvaluationEnv(BaseEnv):
         self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args)
         self.rubric = self.vf_env.rubric
 
-        # Extract rubric components
+        # Extract rubric components from RubricGroup
+        # RubricGroup.funcs is empty - need to collect from individual rubrics
         self.parser = self.rubric.parser
-        self.reward_funcs = self.rubric.funcs
-        self.reward_weights = self.rubric.weights
-        self.reward_scales = [
-            weight / sum(self.reward_weights) for weight in self.reward_weights
-        ]
+        self.reward_funcs: List[Callable] = []
+        self.reward_weights: List[float] = []
+        self.rubric_class_objects: List[Dict[str, Any]] = []  # class_objects per func
+
+        if hasattr(self.rubric, "rubrics"):
+            # RubricGroup: collect from all individual rubrics
+            for rubric in self.rubric.rubrics:
+                class_objects = getattr(rubric, "class_objects", {})
+                for func, weight in zip(rubric.funcs, rubric.weights):
+                    self.reward_funcs.append(func)
+                    self.reward_weights.append(weight)
+                    self.rubric_class_objects.append(class_objects)
+        else:
+            # Single Rubric
+            self.reward_funcs = self.rubric.funcs
+            self.reward_weights = self.rubric.weights
+            class_objects = getattr(self.rubric, "class_objects", {})
+            self.rubric_class_objects = [class_objects] * len(self.rubric.funcs)
+
+        total_weight = sum(self.reward_weights) if self.reward_weights else 1.0
+        self.reward_scales = [weight / total_weight for weight in self.reward_weights]
         self.system_prompt = self.vf_env.system_prompt
 
         # Tracking
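The flatten-and-normalize step above can be illustrated standalone. The dicts below are hypothetical stand-ins for verifiers Rubric objects, keeping only the funcs/weights fields the loop actually touches:

# Each "rubric" contributes its own reward funcs and weights.
rubrics = [
    {"funcs": ["exact_match", "format_check"], "weights": [1.0, 0.2]},
    {"funcs": ["length_penalty"], "weights": [0.5]},
]

reward_funcs, reward_weights = [], []
for rubric in rubrics:
    reward_funcs.extend(rubric["funcs"])
    reward_weights.extend(rubric["weights"])

# Normalize weights into scales summing to 1, guarding the empty case.
total_weight = sum(reward_weights) if reward_weights else 1.0
reward_scales = [w / total_weight for w in reward_weights]
print(reward_scales)  # [0.5882..., 0.1176..., 0.2941...]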
@@ -192,14 +239,40 @@ class VerifiersEvaluationEnv(BaseEnv):
         # Parse answer
         answer_parsed = self.parser.parse_answer(completion=response_text)
 
-        # Score using reward funcs
+        # Score using reward funcs (async functions need await)
+        # Use signature introspection to pass only required params (like verifiers does)
         rewards = []
-        for func in self.reward_funcs:
-            reward = func(
-                parser=self.parser,
-                completion=completion_messages,
-                answer=answer,
-            )
+        for i, func in enumerate(self.reward_funcs):
+            try:
+                # Build merged dict of all possible parameters
+                class_objects = self.rubric_class_objects[i]
+                merged = {
+                    "completion": completion_messages,
+                    "answer": answer,
+                    "prompt": question,
+                }
+                merged.update(class_objects)  # Adds parser, etc.
+
+                # Filter to only params the function accepts
+                sig = inspect.signature(func)
+                if any(p.kind == p.VAR_KEYWORD for p in sig.parameters.values()):
+                    # Function accepts **kwargs, pass everything
+                    kwargs = merged
+                else:
+                    # Only pass params in signature
+                    kwargs = {k: v for k, v in merged.items() if k in sig.parameters}
+
+                result = func(**kwargs)
+                # Reward functions may be async coroutines
+                if asyncio.iscoroutine(result):
+                    reward = await result
+                else:
+                    reward = result
+                reward = float(reward)
+            except Exception as e:
+                if self.config.full_debug:
+                    print(f"  Reward func {func.__name__} error: {e}")
+                reward = 0.0
             rewards.append(reward)
 
         weighted_rewards = [r * self.reward_scales[j] for j, r in enumerate(rewards)]
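The dispatch pattern in this hunk can be exercised in isolation. The two reward functions below are hypothetical stand-ins, but the signature filtering and coroutine handling mirror the loop above:

import asyncio
import inspect

def exact_match(completion, answer):
    # Sync reward func that only declares the params it needs.
    return float(completion == answer)

async def llm_judge(completion, **kwargs):
    # Async reward func that accepts any extra context via **kwargs.
    return 0.5

async def score(funcs, merged):
    rewards = []
    for func in funcs:
        sig = inspect.signature(func)
        if any(p.kind == p.VAR_KEYWORD for p in sig.parameters.values()):
            kwargs = merged  # takes **kwargs: pass everything
        else:
            kwargs = {k: v for k, v in merged.items() if k in sig.parameters}
        result = func(**kwargs)
        rewards.append(await result if asyncio.iscoroutine(result) else result)
    return rewards

merged = {"completion": "42", "answer": "42", "prompt": "6 * 7 = ?"}
print(asyncio.run(score([exact_match, llm_judge], merged)))  # [1.0, 0.5]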
@@ -235,11 +308,14 @@ class VerifiersEvaluationEnv(BaseEnv):
 
         start_time = time.time()
 
-        # Create evaluation tasks
-        tasks = [self.rollout_and_score(item) for item in self.eval_items]
+        # Run sequentially to avoid signal/threading issues with math_verify parser
+        # The parser uses signals for timeouts which only work in main thread
+        from tqdm import tqdm
 
-        # Run with progress bar
-        results = await tqdm_asyncio.gather(*tasks, desc="Evaluating")
+        results = []
+        for item in tqdm(self.eval_items, desc="Evaluating"):
+            result = await self.rollout_and_score(item)
+            results.append(result)
 
         # Filter out failed results
         valid_results = [r for r in results if r is not None]
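The switch from tqdm_asyncio.gather to a sequential loop trades throughput for correctness: with signal-based timeouts only working on the main thread, running rollouts one at a time sidesteps the threading issue at the cost of wall-clock time. If the _no_signal_timeout patch proves sufficient on its own, one possible middle ground (not part of this commit) is gather with a concurrency bound:

import asyncio

async def gather_bounded(coros, limit=8):
    # Run awaitables concurrently, but at most `limit` at a time.
    sem = asyncio.Semaphore(limit)

    async def _run(coro):
        async with sem:
            return await coro

    return await asyncio.gather(*(_run(c) for c in coros))

# Hypothetical usage inside evaluate():
#   results = await gather_bounded(
#       (self.rollout_and_score(item) for item in self.eval_items), limit=8)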