Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-24 17:04:55 +00:00)
fix: improve verifiers environments consistency and correctness
- verifiers_server.py: consistent dataset column selection for train/test, remove redundant comments, preserve float precision for scores
- verifiers_eval.py: add env_config_cls, fix constructor signature to match BaseEnv (slurm bool), make stub methods raise NotImplementedError
parent b62c416130 · commit 9d5cd2b593
2 changed files with 441 additions and 50 deletions
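The train/test column selection mentioned above now follows one pattern for both splits. As a standalone illustration only — a minimal sketch assuming the Hugging Face datasets API behind get_dataset()/get_eval_dataset(), with hypothetical toy rows:

from datasets import Dataset

# Hypothetical stand-in for vf_env.get_dataset() / vf_env.get_eval_dataset()
raw = Dataset.from_dict(
    {
        "question": ["2 + 2 = ?", "Capital of France?"],
        "answer": ["4", "Paris"],
        "info": [{"difficulty": 1}, {"difficulty": 2}],  # extra column to drop
    }
)

# Same treatment for train and test: keep only the columns the env consumes,
# then materialize a list of dicts for indexed access during rollouts.
rows = raw.select_columns(["question", "answer"]).to_list()
print(rows[0])  # {'question': '2 + 2 = ?', 'answer': '4'}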
verifiers_server.py
@@ -1,14 +1,21 @@
 # Verifiers Training Environment for Atropos
+#
+# NOTE: This environment requires a LOCAL inference server (vLLM, SGLang, TRL)
+# for ALL modes (serve, process, evaluate) because it uses ManagedServer for
+# token/logprob tracking. For evaluation with OpenAI API, use:
+# environments/eval_environments/verifiers_eval.py
+#
 # To install a Verifiers/Prime environment:
 # 1. uv tool install prime
 # 2. prime login
 # 3. prime env install will/wordle (or any owner/environment)
 # Docs: https://docs.primeintellect.ai/tutorials-environments/install

 import os
 import time
-from typing import List, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import verifiers as vf
+from pydantic import Field
 from tqdm.asyncio import tqdm_asyncio

 from atroposlib.envs.base import (
@@ -20,8 +27,12 @@ from atroposlib.envs.base import (


 class VfEnvConfig(BaseEnvConfig):
+    """
+    Configuration for the Verifiers environments.
+    """
+
     vf_env_name: str = ""
-    env_args: dict = {}
+    env_args: Dict[str, Any] = Field(default_factory=dict)


 class VerifiersEnv(BaseEnv):
@@ -36,6 +47,7 @@ class VerifiersEnv(BaseEnv):
         testing=False,
     ):
         super().__init__(config, server_configs, slurm, testing)
+        self.percent_correct_buffer = list()
         self.eval_metrics = list()
         self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args)
         self.rubric = self.vf_env.rubric
@@ -51,31 +63,59 @@
     @classmethod
     def config_init(cls) -> Tuple[VfEnvConfig, List[APIServerConfig]]:
         env_config = VfEnvConfig(
             tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct",
             group_size=8,
-            use_wandb=False,
+            use_wandb=True,
             rollout_server_url="http://localhost:8000",
-            total_steps=10,
+            total_steps=1000,
             batch_size=4,
-            steps_per_eval=1,
+            steps_per_eval=100,
             max_token_length=2048,
             wandb_name="verifiers",
         )
+        # Requires local inference server (vLLM, SGLang, TRL)
+        # For evaluation with OpenAI, use eval_environments/verifiers_evaluation.py
         server_configs = [
             APIServerConfig(
-                model_name="gpt-4.1-nano",
-                base_url=None,
-                api_key=os.getenv("OPENAI_API_KEY"),
+                model_name="Qwen/Qwen2.5-1.5B-Instruct",
+                base_url="http://localhost:9001/v1",
+                api_key="x",
+                num_requests_for_eval=4,
             ),
         ]
         return env_config, server_configs

+    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
+        if wandb_metrics is None:
+            wandb_metrics = {}
+
+        # Calculate percent_correct from buffer
+        if self.percent_correct_buffer:
+            wandb_metrics["train/percent_correct"] = sum(
+                self.percent_correct_buffer
+            ) / len(self.percent_correct_buffer)
+
+        self.percent_correct_buffer = list()
+
+        for item in self.eval_metrics:
+            wandb_metrics[item[0]] = item[1]
+        self.eval_metrics = list()
+
+        await super().wandb_log(wandb_metrics)
+
     async def setup(self):
-        self.train = self.vf_env.get_dataset()
+        train_data = self.vf_env.get_dataset()
+        self.train = train_data.select_columns(["question", "answer"]).to_list()
         test_data = self.vf_env.get_eval_dataset()
         self.test = test_data.select_columns(["question", "answer"]).to_list()
         self.iter = 0

+    def save_checkpoint(self, step, data=None):
+        if data is None:
+            data = {}
+        data["iter"] = self.iter
+        super().save_checkpoint(step, data)
+
     async def rollout_and_score_eval(
         self, question: str, answer: str, **kwargs
     ) -> dict:
@@ -85,12 +125,13 @@
             {"role": "user", "content": question},
         ]

-        completion = await self.server.chat_completion(
-            messages=messages,
-            n=1,
-            max_tokens=self.config.max_token_length,
-            temperature=0.0,
-        )
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            completion = await managed.chat_completion(
+                messages=messages,
+                n=1,
+                max_tokens=self.config.max_token_length,
+                temperature=0.0,
+            )

         response_content = completion.choices[0].message.content or ""
         messages.append({"role": "assistant", "content": response_content})
@@ -114,7 +155,7 @@
             "question": question,
             "gold_answer": answer,
             "model_parsed": str(answer_parsed) if answer_parsed else None,
-            "score": int(score),
+            "score": score,
             "correct": bool(score),
             "finish_reason": completion.choices[0].finish_reason,
         }
@@ -171,38 +212,25 @@
             {"role": "user", "content": question},
         ]

-        completions = await self.server.chat_completion(
-            messages=messages,
-            n=self.config.group_size,
-            max_tokens=self.config.max_token_length,
-            temperature=1.0,
-        )
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            completions = await managed.chat_completion(
+                messages=messages,
+                n=self.config.group_size,
+                max_tokens=self.config.max_token_length,
+                temperature=1.0,
+            )
+            state = managed.get_state()
+            nodes = state["nodes"]

-        prompt_text = self.tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
-        prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=False)
-        prompt_len = len(prompt_tokens)
+        scored_data = ScoredDataGroup()
+        scored_data["tokens"] = list()
+        scored_data["masks"] = list()
+        scored_data["scores"] = list()
+        scored_data["inference_logprobs"] = list()

-        scores: ScoredDataGroup = {
-            "tokens": [],
-            "masks": [],
-            "scores": [],
-            "inference_logprobs": [],
-        }

-        for choice in completions.choices:
+        for i, choice in enumerate(completions.choices):
             response = choice.message.content or ""
-
-            # Tokenize full sequence (prompt + completion)
-            full_text = prompt_text + response
-            full_tokens = self.tokenizer.encode(full_text, add_special_tokens=False)
-
-            # Create masks: -100 for prompt, actual tokens for completion
-            masks = [-100] * prompt_len + full_tokens[prompt_len:]
-
-            logprobs = [1.0] * prompt_len + [0.0] * (len(full_tokens) - prompt_len)
-
             # Score using reward funcs
             completion_messages = messages + [
                 {"role": "assistant", "content": response}
@@ -220,12 +248,18 @@
             ]
             score = sum(weighted_rewards)

-            scores["tokens"].append(full_tokens)
-            scores["masks"].append(masks)
-            scores["inference_logprobs"].append(logprobs)
-            scores["scores"].append(score)
+            # Use ManagedServer's properly aligned tokens/masks/logprobs
+            node = nodes[i]
+            scored_data["tokens"].append(node.tokens)
+            scored_data["masks"].append(node.masked_tokens)
+            scored_data["inference_logprobs"].append(node.logprobs)
+            scored_data["scores"].append(score)

-        return scores, []
+        # Track scores for wandb logging
+        for score in scored_data["scores"]:
+            self.percent_correct_buffer.append(max(score, 0))
+
+        return scored_data, []


 if __name__ == "__main__":
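A worked note on the "score": int(score) → "score": score change above: rubric rewards are weighted floats, so truncating the sum throws away most of the eval signal. A minimal sketch with hypothetical reward values and weights (not the actual rubric API):

# Hypothetical per-function rewards and their weights, as a weighted rubric might yield.
rewards = [1.0, 0.5]   # e.g. correctness reward, format reward
weights = [1.0, 0.5]

weighted_rewards = [r * w for r, w in zip(rewards, weights)]
score = sum(weighted_rewards)

print(int(score))  # 1    -- old behaviour: fractional part dropped
print(score)       # 1.25 -- new behaviour: full precision kept in eval records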