Mirror of https://github.com/NousResearch/atropos.git, last synced 2026-04-24 17:04:55 +00:00.
fix env_args, dataset/prompt loading
This commit is contained in:
parent 7907ffd0ad
commit a1d1e7d7fe
1 changed file with 59 additions and 33 deletions
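The commit makes two independent fixes. First, VfEnvConfig.env_args becomes a JSON string parsed on demand, since a Dict-typed pydantic field cannot be populated from a single command-line flag. Second, dataset and prompt loading now pass through the columns the verifiers dataset already provides (prompt, answer, info, example_id, task) instead of filtering to a question column and rebuilding messages by hand. A minimal, self-contained sketch of the JSON-string config pattern (VfEnvConfigSketch is hypothetical; only the field and method names mirror the diff below):

    import json
    from typing import Any, Dict

    from pydantic import BaseModel


    class VfEnvConfigSketch(BaseModel):
        # Hypothetical stand-in for VfEnvConfig; illustration only.
        env_args: str = "{}"  # arrives as one CLI token, e.g. '{"max_turns": 6}'

        def get_env_args(self) -> Dict[str, Any]:
            # Accept a dict set programmatically, parse strings from the CLI.
            if isinstance(self.env_args, dict):
                return self.env_args
            return json.loads(self.env_args)


    cfg = VfEnvConfigSketch(env_args='{"max_turns": 6}')
    assert cfg.get_env_args() == {"max_turns": 6}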
@@ -8,13 +8,13 @@ Supports TWO modes:
 Usage:
     # RL Training (requires local vLLM/SGLang server)
     python verifiers_server.py serve \
-        --env.vf_env_name "will/wordle" \
+        --env.vf_env_name "primeintellect/alphabet-sort" \
         --openai.base_url http://localhost:9001/v1 \
         --slurm false
 
     # SFT Data Generation with OpenAI GPT-4o
     python verifiers_server.py process \
-        --env.vf_env_name "will/wordle" \
+        --env.vf_env_name "primeintellect/alphabet-sort" \
         --env.data_path_to_save_groups gpt4o_sft_data.jsonl \
         --env.total_steps 100 \
         --env.group_size 4 \
@@ -23,30 +23,30 @@ Usage:
 
     # SFT Data Generation with local server
     python verifiers_server.py process \
-        --env.vf_env_name "will/wordle" \
+        --env.vf_env_name "primeintellect/alphabet-sort" \
         --env.data_path_to_save_groups local_sft_data.jsonl \
         --openai.base_url http://localhost:9001/v1
 
     # Evaluation (uses ManagedServer by default, falls back to direct API in process mode)
     python verifiers_server.py evaluate \
-        --env.vf_env_name "will/wordle" \
+        --env.vf_env_name "primeintellect/alphabet-sort" \
         --openai.base_url http://localhost:9001/v1
 
 To install a Verifiers/Prime environment:
     1. uv tool install prime
     2. prime login
-    3. prime env install will/wordle (or any owner/environment)
+    3. prime env install primeintellect/alphabet-sort (or any owner/environment)
 Docs: https://docs.primeintellect.ai/tutorials-environments/install
 """
 
 import asyncio
+import json
 import logging
 from collections import defaultdict
 from typing import Any, Dict, List, Optional, Tuple
 
 import verifiers as vf
 from openai import AsyncOpenAI
-from pydantic import Field
 
 from atroposlib.envs.base import (
     APIServerConfig,
@@ -61,7 +61,13 @@ logger = logging.getLogger(__name__)
 
 class VfEnvConfig(BaseEnvConfig):
     vf_env_name: str = ""
-    env_args: Dict[str, Any] = Field(default_factory=dict)
+    env_args: str = "{}"
+
+    def get_env_args(self) -> Dict[str, Any]:
+        """Parse env_args JSON string into dict."""
+        if isinstance(self.env_args, dict):
+            return self.env_args
+        return json.loads(self.env_args)
 
 
 class VerifiersEnv(BaseEnv):
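With env_args a plain string, the whole dict rides a single CLI flag. A hypothetical invocation (the --env.env_args flag name is assumed from the --env.* convention in the docstring; the keys are illustrative):

    import json

    # Build the value for a single flag, e.g.:
    #   python verifiers_server.py serve --env.env_args '{"max_turns": 6}' ...
    env_args = {"max_turns": 6, "use_think": True}
    flag_value = json.dumps(env_args)

    # get_env_args() recovers the original dict unchanged.
    assert json.loads(flag_value) == env_args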
@@ -85,7 +91,10 @@ class VerifiersEnv(BaseEnv):
         self.groups_total: int = 0
 
         logger.info("Loading verifiers environment: %s", config.vf_env_name)
-        self.vf_env = vf.load_environment(config.vf_env_name, **config.env_args)
+        env_args = config.get_env_args()
+        if env_args:
+            logger.info("Environment args: %s", env_args)
+        self.vf_env = vf.load_environment(config.vf_env_name, **env_args)
         self.rubric = self.vf_env.rubric
         self.system_prompt = self.vf_env.system_prompt
 
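The parsed dict is splatted into verifiers' load_environment, which forwards keyword arguments to the installed environment's loader. A sketch, assuming the environment is installed via prime env install and that it accepts a max_turns argument (illustrative only):

    import verifiers as vf

    # Equivalent to what __init__ now does with the parsed env_args dict;
    # max_turns is a hypothetical argument for illustration.
    env = vf.load_environment("primeintellect/alphabet-sort", max_turns=6)
    print(type(env).__name__)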
@@ -93,6 +102,10 @@ class VerifiersEnv(BaseEnv):
         self.reward_func_names = self.rubric._get_reward_func_names()
         logger.info("Reward functions: %s", self.reward_func_names)
 
+        # Log multi-turn config if available
+        if hasattr(self.vf_env, "max_turns"):
+            logger.info("Max turns: %d", self.vf_env.max_turns)
+
     @classmethod
     def config_init(cls) -> Tuple[VfEnvConfig, List[APIServerConfig]]:
         env_config = VfEnvConfig(
@@ -172,10 +185,9 @@ class VerifiersEnv(BaseEnv):
         await super().wandb_log(wandb_metrics)
 
     async def setup(self):
+        # Dataset already has: prompt, answer, info, example_id, task
         train_data = self.vf_env.get_dataset()
-        columns_to_keep = ["question", "answer", "info"]
-        available_columns = [c for c in columns_to_keep if c in train_data.column_names]
-        self.train = train_data.select_columns(available_columns).to_list()
+        self.train = train_data.to_list()
         self.iter = 0
 
     def save_checkpoint(self, step, data=None):
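This is the dataset half of the fix: verifiers datasets carry a prompt column (a chat message list), not a bare question string, so the old ["question", "answer", "info"] filter silently dropped the columns the rollout needs. A self-contained repro with a miniature, hypothetical dataset:

    from datasets import Dataset

    # Hypothetical miniature of a verifiers dataset row.
    ds = Dataset.from_list([
        {
            "prompt": [{"role": "user", "content": "Sort: b a c"}],
            "answer": "a b c",
            "info": {"seed": 1},
            "example_id": 0,
            "task": "primeintellect/alphabet-sort",
        }
    ])

    # Old behavior: keep only ["question", "answer", "info"]; with no
    # `question` column, `prompt` and `example_id` are silently dropped.
    keep = [c for c in ["question", "answer", "info"] if c in ds.column_names]
    assert "prompt" not in ds.select_columns(keep).column_names

    # Fixed behavior: keep every column.
    train = ds.to_list()
    assert train[0]["prompt"][0]["content"] == "Sort: b a c"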
@@ -193,13 +205,6 @@ class VerifiersEnv(BaseEnv):
         """No-op. Use environments/eval_environments/verifiers_eval.py for evaluation."""
         return {}
 
-    def _build_initial_messages(self, question: str) -> List[Dict[str, str]]:
-        messages: List[Dict[str, str]] = []
-        if self.system_prompt:
-            messages.append({"role": "system", "content": self.system_prompt})
-        messages.append({"role": "user", "content": question})
-        return messages
-
     async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]:
         is_process_mode = getattr(self, "process_mode", False)
         if is_process_mode:
@@ -217,16 +222,16 @@ class VerifiersEnv(BaseEnv):
             timeout=server_config.timeout,
         )
 
-        initial_messages = self._build_initial_messages(item["question"])
+        # item already has prompt, answer, example_id, task, info from dataset
         inputs = [
             {
-                "prompt": initial_messages,
-                "answer": item["answer"],
-                "example_id": i,
-                "task": self.config.vf_env_name,
+                "prompt": item["prompt"],
+                "answer": item.get("answer", ""),
+                "example_id": item["example_id"],
+                "task": item.get("task", self.config.vf_env_name),
+                "info": item.get("info", {}),
             }
-            for i in range(self.config.group_size)
+            for _ in range(self.config.group_size)
         ]
 
         results = await self.vf_env.generate(
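Spelled out with a hypothetical item, the rewritten inputs construction yields group_size copies that share the dataset's example_id, where the old code wrongly used the loop index as the id:

    # Hypothetical item, as it now arrives straight from the dataset list.
    item = {
        "prompt": [{"role": "user", "content": "Sort: b a c"}],
        "answer": "a b c",
        "example_id": 7,
        "task": "primeintellect/alphabet-sort",
        "info": {"seed": 1},
    }
    group_size = 2  # stands in for self.config.group_size
    vf_env_name = "primeintellect/alphabet-sort"  # stands in for self.config.vf_env_name

    inputs = [
        {
            "prompt": item["prompt"],
            "answer": item.get("answer", ""),
            "example_id": item["example_id"],
            "task": item.get("task", vf_env_name),
            "info": item.get("info", {}),
        }
        for _ in range(group_size)
    ]

    # All rollouts in the group now share the true example_id.
    assert [x["example_id"] for x in inputs] == [7, 7]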
@@ -279,7 +284,9 @@ class VerifiersEnv(BaseEnv):
 
         # Capture metrics for wandb logging
         self.reward_buffer.append(reward)
-        self.num_turns_buffer.append(len(trajectory))
+        num_turns = len(trajectory)
+        self.num_turns_buffer.append(num_turns)
+        logger.debug("Rollout: %d turns, reward=%.3f", num_turns, reward)
 
         # Extract per-function metrics from verifiers state
         state_metrics = state.get("metrics", {})
@@ -288,6 +295,15 @@ class VerifiersEnv(BaseEnv):
             if isinstance(metric_value, (int, float)):
                 self.metrics_buffer[metric_name].append(float(metric_value))
 
+        # Log group summary
+        turns = [len(s.get("trajectory", [])) for s in results["state"]]
+        logger.info(
+            "Group: %d rollouts, turns=%s, rewards=%s",
+            len(results["state"]),
+            turns,
+            [f"{s:.3f}" for s in scored_data["scores"]],
+        )
+
         # Track group-level identical scores for debugging
         self.groups_total += 1
         if len(set(scored_data["scores"])) == 1:
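One plausible reason identical scores are counted (an assumption about the downstream trainer, not stated in the diff): with group-relative advantages, a group whose rollouts all score the same contributes zero learning signal.

    # Sketch: group-relative advantages vanish when all scores are equal.
    scores = [0.5, 0.5, 0.5, 0.5]
    mean = sum(scores) / len(scores)
    advantages = [s - mean for s in scores]
    assert all(a == 0.0 for a in advantages)
    assert len(set(scores)) == 1  # the condition the new counter tracks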
@@ -307,7 +323,6 @@ class VerifiersEnv(BaseEnv):
             AtroposManagedClient,
         )
 
-        initial_messages = self._build_initial_messages(item["question"])
         sampling_args = {
             "temperature": 1.0,
             "max_completion_tokens": self.config.max_token_length,
@@ -315,17 +330,17 @@ class VerifiersEnv(BaseEnv):
         score_sem = asyncio.Semaphore(self.config.group_size)
         model = self.server.servers[0].config.model_name
 
-        async def run_rollout(example_id: int) -> Dict[str, Any]:
+        async def run_rollout() -> Dict[str, Any]:
             """Run a single rollout and return full state for metrics extraction."""
             async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
                 client = AtroposManagedClient(managed_server=managed, model=model)
 
                 state = await self.vf_env.rollout(
                     input={
-                        "prompt": initial_messages,
-                        "answer": item["answer"],
-                        "example_id": example_id,
-                        "task": self.config.vf_env_name,
+                        "prompt": item["prompt"],
+                        "answer": item.get("answer", ""),
+                        "example_id": item["example_id"],
+                        "task": item.get("task", self.config.vf_env_name),
+                        "info": item.get("info", {}),
                     },
                     client=client,
@@ -336,7 +351,7 @@ class VerifiersEnv(BaseEnv):
             return state
 
         states = await asyncio.gather(
-            *[run_rollout(i) for i in range(self.config.group_size)]
+            *[run_rollout() for _ in range(self.config.group_size)]
        )
 
         scored_data = ScoredDataGroup()
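The fan-out itself is unchanged; only the closure's signature shrinks. A runnable miniature of the pattern (run_rollout here is a stub standing in for the managed-server rollout):

    import asyncio
    from typing import Any, Dict


    async def main() -> None:
        async def run_rollout() -> Dict[str, Any]:
            # Stub for the real rollout; no per-call id is needed anymore.
            await asyncio.sleep(0)
            return {"trajectory": [], "reward": 0.0}

        group_size = 4
        states = await asyncio.gather(*[run_rollout() for _ in range(group_size)])
        assert len(states) == group_size


    asyncio.run(main())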
@@ -356,7 +371,9 @@ class VerifiersEnv(BaseEnv):
         self.reward_buffer.append(reward)
 
         trajectory = state.get("trajectory", [])
-        self.num_turns_buffer.append(len(trajectory))
+        num_turns = len(trajectory)
+        self.num_turns_buffer.append(num_turns)
+        logger.debug("Rollout: %d turns, reward=%.3f", num_turns, reward)
 
         # Extract per-function metrics from verifiers state
         state_metrics = state.get("metrics", {})
@@ -365,6 +382,15 @@ class VerifiersEnv(BaseEnv):
             if isinstance(metric_value, (int, float)):
                 self.metrics_buffer[metric_name].append(float(metric_value))
 
+        # Log group summary
+        turns = [len(s.get("trajectory", [])) for s in states]
+        logger.info(
+            "Group: %d rollouts, turns=%s, rewards=%s",
+            len(states),
+            turns,
+            [f"{s:.3f}" for s in scored_data["scores"]],
+        )
+
         # Track group-level identical scores for debugging
         self.groups_total += 1
         if len(set(scored_data["scores"])) == 1: