fix: improve verifiers environments consistency and correctness

- verifiers_server.py: consistent dataset column selection for train/test,
  remove redundant comments, preserve float precision for scores
- verifiers_eval.py: add env_config_cls, fix constructor signature to match
  BaseEnv (slurm bool), make stub methods raise NotImplementedError
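
For context, the verifiers_eval.py items above amount to aligning the evaluation environment with the BaseEnv interface. A rough sketch of that shape follows; it is illustrative only, the class/config names and default argument values are assumptions, not the actual file contents:

    from typing import List, Tuple

    from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig

    class VerifiersEvalEnv(BaseEnv):
        # Assumed: BaseEnv looks up its config class via this attribute.
        env_config_cls = BaseEnvConfig

        def __init__(
            self,
            config: BaseEnvConfig,
            server_configs: List[APIServerConfig],
            slurm: bool = True,   # positional bool, matching BaseEnv's signature
            testing: bool = False,
        ):
            super().__init__(config, server_configs, slurm, testing)

        async def collect_trajectories(self, item) -> Tuple:
            # Eval-only environment: training hooks fail loudly instead of silently no-op'ing.
            raise NotImplementedError("This environment only supports evaluation")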
balyan.sid@gmail.com 2026-01-09 16:18:46 +05:30
parent b62c416130
commit 9d5cd2b593
2 changed files with 441 additions and 50 deletions

@@ -1,14 +1,21 @@
# Verifiers Training Environment for Atropos
#
# NOTE: This environment requires a LOCAL inference server (vLLM, SGLang, TRL)
# for ALL modes (serve, process, evaluate) because it uses ManagedServer for
# token/logprob tracking. For evaluation with OpenAI API, use:
# environments/eval_environments/verifiers_eval.py
#
# To install a Verifiers/Prime environment:
# 1. uv tool install prime
# 2. prime login
# 3. prime env install will/wordle (or any owner/environment)
# Docs: https://docs.primeintellect.ai/tutorials-environments/install
import os
import time
from typing import List, Tuple
from typing import Any, Dict, List, Optional, Tuple
import verifiers as vf
from pydantic import Field
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
@@ -20,8 +27,12 @@ from atroposlib.envs.base import (
class VfEnvConfig(BaseEnvConfig):
"""
Configuration for the Verifiers environments.
"""
vf_env_name: str = ""
env_args: dict = {}
env_args: Dict[str, Any] = Field(default_factory=dict)
class VerifiersEnv(BaseEnv):
@@ -36,6 +47,7 @@ class VerifiersEnv(BaseEnv):
testing=False,
):
super().__init__(config, server_configs, slurm, testing)
self.percent_correct_buffer = list()
self.eval_metrics = list()
self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args)
self.rubric = self.vf_env.rubric
@@ -51,31 +63,59 @@ class VerifiersEnv(BaseEnv):
@classmethod
def config_init(cls) -> Tuple[VfEnvConfig, List[APIServerConfig]]:
env_config = VfEnvConfig(
tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct",
group_size=8,
use_wandb=False,
use_wandb=True,
rollout_server_url="http://localhost:8000",
total_steps=10,
total_steps=1000,
batch_size=4,
steps_per_eval=1,
steps_per_eval=100,
max_token_length=2048,
wandb_name="verifiers",
)
# Requires local inference server (vLLM, SGLang, TRL)
# For evaluation with OpenAI, use eval_environments/verifiers_eval.py
server_configs = [
APIServerConfig(
model_name="gpt-4.1-nano",
base_url=None,
api_key=os.getenv("OPENAI_API_KEY"),
model_name="Qwen/Qwen2.5-1.5B-Instruct",
base_url="http://localhost:9001/v1",
api_key="x",
num_requests_for_eval=4,
),
]
return env_config, server_configs
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
if wandb_metrics is None:
wandb_metrics = {}
# Calculate percent_correct from buffer
if self.percent_correct_buffer:
wandb_metrics["train/percent_correct"] = sum(
self.percent_correct_buffer
) / len(self.percent_correct_buffer)
self.percent_correct_buffer = list()
for item in self.eval_metrics:
wandb_metrics[item[0]] = item[1]
self.eval_metrics = list()
await super().wandb_log(wandb_metrics)
async def setup(self):
self.train = self.vf_env.get_dataset()
train_data = self.vf_env.get_dataset()
self.train = train_data.select_columns(["question", "answer"]).to_list()
test_data = self.vf_env.get_eval_dataset()
self.test = test_data.select_columns(["question", "answer"]).to_list()
self.iter = 0
def save_checkpoint(self, step, data=None):
if data is None:
data = {}
data["iter"] = self.iter
super().save_checkpoint(step, data)
async def rollout_and_score_eval(
self, question: str, answer: str, **kwargs
) -> dict:
@@ -85,12 +125,13 @@ class VerifiersEnv(BaseEnv):
{"role": "user", "content": question},
]
completion = await self.server.chat_completion(
messages=messages,
n=1,
max_tokens=self.config.max_token_length,
temperature=0.0,
)
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
completion = await managed.chat_completion(
messages=messages,
n=1,
max_tokens=self.config.max_token_length,
temperature=0.0,
)
response_content = completion.choices[0].message.content or ""
messages.append({"role": "assistant", "content": response_content})
@@ -114,7 +155,7 @@ class VerifiersEnv(BaseEnv):
"question": question,
"gold_answer": answer,
"model_parsed": str(answer_parsed) if answer_parsed else None,
"score": int(score),
"score": score,
"correct": bool(score),
"finish_reason": completion.choices[0].finish_reason,
}
@@ -171,38 +212,25 @@ class VerifiersEnv(BaseEnv):
{"role": "user", "content": question},
]
completions = await self.server.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=self.config.max_token_length,
temperature=1.0,
)
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
completions = await managed.chat_completion(
messages=messages,
n=self.config.group_size,
max_tokens=self.config.max_token_length,
temperature=1.0,
)
state = managed.get_state()
nodes = state["nodes"]
prompt_text = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=False)
prompt_len = len(prompt_tokens)
scored_data = ScoredDataGroup()
scored_data["tokens"] = list()
scored_data["masks"] = list()
scored_data["scores"] = list()
scored_data["inference_logprobs"] = list()
scores: ScoredDataGroup = {
"tokens": [],
"masks": [],
"scores": [],
"inference_logprobs": [],
}
for choice in completions.choices:
for i, choice in enumerate(completions.choices):
response = choice.message.content or ""
# Tokenize full sequence (prompt + completion)
full_text = prompt_text + response
full_tokens = self.tokenizer.encode(full_text, add_special_tokens=False)
# Create masks: -100 for prompt, actual tokens for completion
masks = [-100] * prompt_len + full_tokens[prompt_len:]
logprobs = [1.0] * prompt_len + [0.0] * (len(full_tokens) - prompt_len)
# Score using reward funcs
completion_messages = messages + [
{"role": "assistant", "content": response}
@@ -220,12 +248,18 @@ class VerifiersEnv(BaseEnv):
]
score = sum(weighted_rewards)
scores["tokens"].append(full_tokens)
scores["masks"].append(masks)
scores["inference_logprobs"].append(logprobs)
scores["scores"].append(score)
# Use ManagedServer's properly aligned tokens/masks/logprobs
node = nodes[i]
scored_data["tokens"].append(node.tokens)
scored_data["masks"].append(node.masked_tokens)
scored_data["inference_logprobs"].append(node.logprobs)
scored_data["scores"].append(score)
return scores, []
# Track scores for wandb logging
for score in scored_data["scores"]:
self.percent_correct_buffer.append(max(score, 0))
return scored_data, []
if __name__ == "__main__":
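
The core of the change to the collect/score path is the switch from manual prompt tokenization to ManagedServer's aligned state. Condensed from the hunks above (a sketch, not a copy of the file; the weighted rubric scoring is elided):

    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        completions = await managed.chat_completion(
            messages=messages,
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
            temperature=1.0,
        )
        # ManagedServer records token/mask/logprob alignment per choice.
        nodes = managed.get_state()["nodes"]

    scored_data = ScoredDataGroup()
    scored_data["tokens"] = list()
    scored_data["masks"] = list()
    scored_data["scores"] = list()
    scored_data["inference_logprobs"] = list()

    for i, choice in enumerate(completions.choices):
        score = ...  # sum of weighted rubric rewards, as in the hunk above
        node = nodes[i]
        scored_data["tokens"].append(node.tokens)
        scored_data["masks"].append(node.masked_tokens)
        scored_data["inference_logprobs"].append(node.logprobs)
        scored_data["scores"].append(score)

    # Feed the wandb percent_correct metric logged in wandb_log().
    for score in scored_data["scores"]:
        self.percent_correct_buffer.append(max(score, 0))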