Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-28 17:29:30 +00:00
add wandb to eval

commit 636715bb08
parent dda85430da
1 changed file with 29 additions and 28 deletions
@@ -26,6 +26,7 @@ import verifiers as vf
 from pydantic import Field
 from tqdm.asyncio import tqdm_asyncio
 
+import wandb
 from atroposlib.envs.base import (
     APIServerConfig,
     BaseEnv,
@@ -46,13 +47,10 @@ class VerifiersEvaluationConfig(BaseEnvConfig):
         description="Additional arguments for verifiers environment",
     )
 
-    # Generation parameters
     temperature: float = Field(
         default=0.0, description="Temperature for generation (0.0 for deterministic)"
     )
-    max_tokens: int = Field(default=2048, description="Maximum tokens for generation")
 
-    # Retry and debug configuration
     max_retries: int = Field(
         default=3, description="Maximum retries for failed API calls"
     )
@@ -64,16 +62,6 @@ class VerifiersEvaluationConfig(BaseEnvConfig):
     )
     full_debug: bool = Field(default=False, description="Enable full debug output")
 
-    # Override defaults for evaluation mode
-    group_size: int = 1
-    max_num_workers: int = 256
-    max_num_workers_per_node: int = 64
-    use_wandb: bool = True
-    rollout_server_url: str = "http://localhost:8000"
-    total_steps: int = 1
-    wandb_name: str = "verifiers_evaluation"
-    steps_per_eval: int = 1
-
 
 class VerifiersEvaluationEnv(BaseEnv):
     """
@@ -118,17 +106,11 @@ class VerifiersEvaluationEnv(BaseEnv):
         """Default configuration for evaluation."""
         env_config = VerifiersEvaluationConfig(
             vf_env_name="primeintellect/gsm8k",
-            temperature=0.0,
-            max_tokens=2048,
-            use_wandb=True,
-            wandb_name="verifiers_evaluation",
         )
         server_configs = [
             APIServerConfig(
                 model_name="gpt-4.1-nano",
-                base_url=None,
                 api_key=os.getenv("OPENAI_API_KEY"),
-                num_requests_for_eval=256,
             ),
         ]
         return env_config, server_configs
@@ -169,7 +151,7 @@ class VerifiersEvaluationEnv(BaseEnv):
         kwargs = {
             "messages": messages,
             "temperature": self.config.temperature,
-            "max_tokens": self.config.max_tokens,
+            "max_tokens": self.config.max_token_length,
             "n": 1,
         }
 
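This hunk and the next both replace the removed max_tokens config field with max_token_length, which is assumed here to be inherited from BaseEnvConfig. A minimal sketch, under that assumption, of how the generation kwargs end up being built (the _CfgSketch class and its defaults are illustrative, not part of the commit):

# Sketch only: _CfgSketch stands in for VerifiersEvaluationConfig; the
# max_token_length field and its default of 2048 are assumptions.
from pydantic import BaseModel, Field

class _CfgSketch(BaseModel):
    temperature: float = Field(default=0.0)
    max_token_length: int = Field(default=2048)

cfg = _CfgSketch()
kwargs = {
    "messages": [{"role": "user", "content": "2 + 2 = ?"}],
    "temperature": cfg.temperature,
    "max_tokens": cfg.max_token_length,  # was cfg.max_tokens before this commit
    "n": 1,
}
print(kwargs["max_tokens"])  # -> 2048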
@@ -323,23 +305,42 @@ class VerifiersEvaluationEnv(BaseEnv):
             end_time=end_time,
             generation_parameters={
                 "temperature": self.config.temperature,
-                "max_tokens": self.config.max_tokens,
+                "max_tokens": self.config.max_token_length,
             },
         )
 
+        # Log to wandb
+        await self.wandb_log(metrics)
+
         return metrics
 
     async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None:
         """Log metrics to Weights & Biases."""
-        if wandb_metrics is None:
-            wandb_metrics = {}
+        if not self.config.use_wandb or wandb_metrics is None:
+            return
 
-        # Add config info
-        wandb_metrics["config/vf_env_name"] = self.config.vf_env_name
-        wandb_metrics["config/temperature"] = self.config.temperature
-        wandb_metrics["config/max_tokens"] = self.config.max_tokens
+        # Lazy init if wandb not already initialized
+        if wandb.run is None:
+            wandb.init(
+                project="verifiers-eval",
+                name=self.config.wandb_name,
+                config=self.config.model_dump(),
+            )
 
-        await super().wandb_log(wandb_metrics)
+        log_dict = {
+            "verifiers/accuracy": wandb_metrics.get("accuracy", 0),
+            "verifiers/avg_score": wandb_metrics.get("avg_score", 0),
+            "verifiers/total_evaluated": wandb_metrics.get("total_evaluated", 0),
+            "verifiers/total_correct": wandb_metrics.get("total_correct", 0),
+        }
+
+        # Add per-reward function metrics
+        reward_breakdown = wandb_metrics.get("reward_breakdown", {})
+        for func_name, data in reward_breakdown.items():
+            log_dict[f"verifiers/{func_name}_avg"] = data.get("avg", 0)
+            log_dict[f"verifiers/{func_name}_correct"] = data.get("correct", 0)
+
+        wandb.log(log_dict)
 
     # Required abstract method implementations (stubs for evaluation-only mode)
     async def get_next_item(self) -> Optional[Dict]:
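Taken together, the added lines give wandb_log a lazy-init-then-log shape: start a run only if one is not already active, then emit a single flat dictionary of "verifiers/..." keys. A minimal standalone sketch of that pattern; the helper name, defaults, and example call are illustrative, not part of the commit:

import wandb

def log_eval_metrics(metrics: dict, run_name: str = "verifiers_evaluation") -> None:
    # Lazily start a run only if none exists yet, mirroring the wandb.run check above.
    if wandb.run is None:
        wandb.init(project="verifiers-eval", name=run_name, config={})
    log_dict = {
        "verifiers/accuracy": metrics.get("accuracy", 0),
        "verifiers/avg_score": metrics.get("avg_score", 0),
        "verifiers/total_evaluated": metrics.get("total_evaluated", 0),
        "verifiers/total_correct": metrics.get("total_correct", 0),
    }
    # Flatten per-reward-function stats into the same "verifiers/" namespace.
    for func_name, data in metrics.get("reward_breakdown", {}).items():
        log_dict[f"verifiers/{func_name}_avg"] = data.get("avg", 0)
        log_dict[f"verifiers/{func_name}_correct"] = data.get("correct", 0)
    wandb.log(log_dict)

# Example: log_eval_metrics({"accuracy": 0.91, "total_evaluated": 256})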