rework server and eval for rl rollout. add in asyncmanagedserver for

verifiers
This commit is contained in:
balyan.sid@gmail.com 2026-01-10 14:55:08 +05:30
parent 3449a4c23d
commit cf636595d2
4 changed files with 652 additions and 91 deletions

View file

@ -18,15 +18,15 @@ Usage:
"""
import asyncio
import inspect
import os
import time
from typing import Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
import verifiers as vf
import wandb
from pydantic import Field
from tqdm.asyncio import tqdm_asyncio
import wandb
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
@ -34,6 +34,36 @@ from atroposlib.envs.base import (
)
# Patch math_verify timeout to work in async context
# The signal-based timeout doesn't work in non-main threads (asyncio event loop)
def _no_signal_timeout(timeout_seconds: int):
"""Replacement timeout decorator that doesn't use signals."""
def decorator(func):
def wrapper(*args, **kwargs):
# Just call the function without timeout
# This is safe because we're in an async context with our own timeouts
# timeout_seconds is intentionally unused - we're replacing the timeout logic
return func(*args, **kwargs)
return wrapper
return decorator
try:
import math_verify.grader
import math_verify.parser
import math_verify.utils
# Patch all modules that use the timeout decorator
math_verify.utils.timeout = _no_signal_timeout
math_verify.parser.timeout = _no_signal_timeout
math_verify.grader.timeout = _no_signal_timeout
except ImportError:
pass # math_verify not installed
class VerifiersEvaluationConfig(BaseEnvConfig):
"""Configuration for Verifiers evaluation environment."""
@ -91,13 +121,30 @@ class VerifiersEvaluationEnv(BaseEnv):
self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args)
self.rubric = self.vf_env.rubric
# Extract rubric components
# Extract rubric components from RubricGroup
# RubricGroup.funcs is empty - need to collect from individual rubrics
self.parser = self.rubric.parser
self.reward_funcs = self.rubric.funcs
self.reward_weights = self.rubric.weights
self.reward_scales = [
weight / sum(self.reward_weights) for weight in self.reward_weights
]
self.reward_funcs: List[Callable] = []
self.reward_weights: List[float] = []
self.rubric_class_objects: List[Dict[str, Any]] = [] # class_objects per func
if hasattr(self.rubric, "rubrics"):
# RubricGroup: collect from all individual rubrics
for rubric in self.rubric.rubrics:
class_objects = getattr(rubric, "class_objects", {})
for func, weight in zip(rubric.funcs, rubric.weights):
self.reward_funcs.append(func)
self.reward_weights.append(weight)
self.rubric_class_objects.append(class_objects)
else:
# Single Rubric
self.reward_funcs = self.rubric.funcs
self.reward_weights = self.rubric.weights
class_objects = getattr(self.rubric, "class_objects", {})
self.rubric_class_objects = [class_objects] * len(self.rubric.funcs)
total_weight = sum(self.reward_weights) if self.reward_weights else 1.0
self.reward_scales = [weight / total_weight for weight in self.reward_weights]
self.system_prompt = self.vf_env.system_prompt
# Tracking
@ -192,14 +239,40 @@ class VerifiersEvaluationEnv(BaseEnv):
# Parse answer
answer_parsed = self.parser.parse_answer(completion=response_text)
# Score using reward funcs
# Score using reward funcs (async functions need await)
# Use signature introspection to pass only required params (like verifiers does)
rewards = []
for func in self.reward_funcs:
reward = func(
parser=self.parser,
completion=completion_messages,
answer=answer,
)
for i, func in enumerate(self.reward_funcs):
try:
# Build merged dict of all possible parameters
class_objects = self.rubric_class_objects[i]
merged = {
"completion": completion_messages,
"answer": answer,
"prompt": question,
}
merged.update(class_objects) # Adds parser, etc.
# Filter to only params the function accepts
sig = inspect.signature(func)
if any(p.kind == p.VAR_KEYWORD for p in sig.parameters.values()):
# Function accepts **kwargs, pass everything
kwargs = merged
else:
# Only pass params in signature
kwargs = {k: v for k, v in merged.items() if k in sig.parameters}
result = func(**kwargs)
# Reward functions may be async coroutines
if asyncio.iscoroutine(result):
reward = await result
else:
reward = result
reward = float(reward)
except Exception as e:
if self.config.full_debug:
print(f" Reward func {func.__name__} error: {e}")
reward = 0.0
rewards.append(reward)
weighted_rewards = [r * self.reward_scales[j] for j, r in enumerate(rewards)]
@ -235,11 +308,14 @@ class VerifiersEvaluationEnv(BaseEnv):
start_time = time.time()
# Create evaluation tasks
tasks = [self.rollout_and_score(item) for item in self.eval_items]
# Run sequentially to avoid signal/threading issues with math_verify parser
# The parser uses signals for timeouts which only work in main thread
from tqdm import tqdm
# Run with progress bar
results = await tqdm_asyncio.gather(*tasks, desc="Evaluating")
results = []
for item in tqdm(self.eval_items, desc="Evaluating"):
result = await self.rollout_and_score(item)
results.append(result)
# Filter out failed results
valid_results = [r for r in results if r is not None]