add wandb to eval

This commit is contained in:
balyan.sid@gmail.com 2026-01-09 16:51:19 +05:30
parent dda85430da
commit 636715bb08

View file

@@ -26,6 +26,7 @@ import verifiers as vf
from pydantic import Field from pydantic import Field
from tqdm.asyncio import tqdm_asyncio from tqdm.asyncio import tqdm_asyncio
import wandb
from atroposlib.envs.base import ( from atroposlib.envs.base import (
APIServerConfig, APIServerConfig,
BaseEnv, BaseEnv,
@@ -46,13 +47,10 @@ class VerifiersEvaluationConfig(BaseEnvConfig):
description="Additional arguments for verifiers environment", description="Additional arguments for verifiers environment",
) )
# Generation parameters
temperature: float = Field( temperature: float = Field(
default=0.0, description="Temperature for generation (0.0 for deterministic)" default=0.0, description="Temperature for generation (0.0 for deterministic)"
) )
max_tokens: int = Field(default=2048, description="Maximum tokens for generation")
# Retry and debug configuration
max_retries: int = Field( max_retries: int = Field(
default=3, description="Maximum retries for failed API calls" default=3, description="Maximum retries for failed API calls"
) )
@@ -64,16 +62,6 @@ class VerifiersEvaluationConfig(BaseEnvConfig):
) )
full_debug: bool = Field(default=False, description="Enable full debug output") full_debug: bool = Field(default=False, description="Enable full debug output")
# Override defaults for evaluation mode
group_size: int = 1
max_num_workers: int = 256
max_num_workers_per_node: int = 64
use_wandb: bool = True
rollout_server_url: str = "http://localhost:8000"
total_steps: int = 1
wandb_name: str = "verifiers_evaluation"
steps_per_eval: int = 1
class VerifiersEvaluationEnv(BaseEnv): class VerifiersEvaluationEnv(BaseEnv):
""" """
@@ -118,17 +106,11 @@ class VerifiersEvaluationEnv(BaseEnv):
"""Default configuration for evaluation.""" """Default configuration for evaluation."""
env_config = VerifiersEvaluationConfig( env_config = VerifiersEvaluationConfig(
vf_env_name="primeintellect/gsm8k", vf_env_name="primeintellect/gsm8k",
temperature=0.0,
max_tokens=2048,
use_wandb=True,
wandb_name="verifiers_evaluation",
) )
server_configs = [ server_configs = [
APIServerConfig( APIServerConfig(
model_name="gpt-4.1-nano", model_name="gpt-4.1-nano",
base_url=None,
api_key=os.getenv("OPENAI_API_KEY"), api_key=os.getenv("OPENAI_API_KEY"),
num_requests_for_eval=256,
), ),
] ]
return env_config, server_configs return env_config, server_configs
@@ -169,7 +151,7 @@ class VerifiersEvaluationEnv(BaseEnv):
kwargs = { kwargs = {
"messages": messages, "messages": messages,
"temperature": self.config.temperature, "temperature": self.config.temperature,
"max_tokens": self.config.max_tokens, "max_tokens": self.config.max_token_length,
"n": 1, "n": 1,
} }
@@ -323,23 +305,42 @@ class VerifiersEvaluationEnv(BaseEnv):
end_time=end_time, end_time=end_time,
generation_parameters={ generation_parameters={
"temperature": self.config.temperature, "temperature": self.config.temperature,
"max_tokens": self.config.max_tokens, "max_tokens": self.config.max_token_length,
}, },
) )
# Log to wandb
await self.wandb_log(metrics)
return metrics return metrics
async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None: async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None:
"""Log metrics to Weights & Biases.""" """Log metrics to Weights & Biases."""
if wandb_metrics is None: if not self.config.use_wandb or wandb_metrics is None:
wandb_metrics = {} return
# Add config info # Lazy init if wandb not already initialized
wandb_metrics["config/vf_env_name"] = self.config.vf_env_name if wandb.run is None:
wandb_metrics["config/temperature"] = self.config.temperature wandb.init(
wandb_metrics["config/max_tokens"] = self.config.max_tokens project="verifiers-eval",
name=self.config.wandb_name,
config=self.config.model_dump(),
)
await super().wandb_log(wandb_metrics) log_dict = {
"verifiers/accuracy": wandb_metrics.get("accuracy", 0),
"verifiers/avg_score": wandb_metrics.get("avg_score", 0),
"verifiers/total_evaluated": wandb_metrics.get("total_evaluated", 0),
"verifiers/total_correct": wandb_metrics.get("total_correct", 0),
}
# Add per-reward function metrics
reward_breakdown = wandb_metrics.get("reward_breakdown", {})
for func_name, data in reward_breakdown.items():
log_dict[f"verifiers/{func_name}_avg"] = data.get("avg", 0)
log_dict[f"verifiers/{func_name}_correct"] = data.get("correct", 0)
wandb.log(log_dict)
# Required abstract method implementations (stubs for evaluation-only mode) # Required abstract method implementations (stubs for evaluation-only mode)
async def get_next_item(self) -> Optional[Dict]: async def get_next_item(self) -> Optional[Dict]: