#!/usr/bin/env python3
"""
InternBootcamp RL Environment for Atropos

This environment integrates InternBootcamp's verifiable reasoning tasks with the Atropos
RL training framework. It supports training on single tasks, with plans for multi-task
and curriculum learning modes.
"""

import asyncio
import logging
from typing import Any, Dict, List, Optional, Tuple, Union

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)

from .bootcamp_registry import create_bootcamp, get_available_bootcamps

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# System prompt for reasoning tasks
SYSTEM_PROMPT = (
    "You are a deep thinking AI with strong reasoning abilities. You may use "
    "extremely long chains of thought to deeply consider the problem and "
    "deliberate with yourself via systematic reasoning processes to help come "
    "to a correct solution.\n\n"
    "You should enclose your thoughts and internal monologue inside <think> "
    "</think> tags, and then provide your solution or response to the problem. "
    "Please think in English, even if the problem is presented in another "
    "language.\n\n"
    "When solving problems:\n"
    "1. Think step by step through the problem inside <think> tags\n"
    "2. Show your work clearly in your thinking\n"
    "3. Verify your answer before finalizing\n"
    "4. Follow the specific answer format requested in the problem\n\n"
    "Pay close attention to how the problem asks you to format your answer - "
    "some may require specific tags, notations, or formats."
)
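
# The prompt above should elicit responses of roughly this shape (illustrative):
#   <think> step-by-step reasoning, self-verification ... </think>
#   final answer, in whatever format the individual problem requests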


class InternBootcampEnvConfig(BaseEnvConfig):
    """Configuration for the InternBootcamp environment."""

    # Task selection
    task_name: str = "RandomTask"  # Random task selection mode

    # Task-specific parameters
    task_params: Dict[str, Any] = {}

    # Reward configuration
    correct_reward: float = 1.0
    incorrect_reward: float = -0.5
    format_bonus: float = 0.2

    # Training parameters
    require_reasoning: bool = True
    min_reasoning_length: int = 50
    temperature: float = 0.7
    top_p: float = 0.9
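
# Example (illustrative; valid names come from get_available_bootcamps(), and
# "Game24bootcamp" is just a representative InternBootcamp task name):
#   config = InternBootcampEnvConfig(
#       tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
#       task_name="Game24bootcamp",
#   )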


class InternBootcampEnv(BaseEnv):
    """Environment for training on InternBootcamp reasoning tasks."""

    name = "intern_bootcamp"

    def __init__(
        self,
        config: InternBootcampEnvConfig,
        server_configs: Union[List[APIServerConfig], APIServerConfig],
        slurm=True,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.config = config

        # Task tracking
        self.bootcamp_instance = None
        self.current_task_name = config.task_name

        # Performance tracking
        self.task_correct_buffer = []
        self.format_correct_buffer = []
        self.eval_metrics = []

        self.system_prompt = SYSTEM_PROMPT

    async def setup(self):
        """Initialize the environment and bootcamp task."""
        logger.info(f"Setting up InternBootcampEnv with task: {self.config.task_name}")

        # Log available bootcamps
        available = get_available_bootcamps()
        logger.info(f"Found {len(available)} available bootcamp tasks")
        logger.debug(f"Available tasks (first 20): {available[:20]}")

        # Initialize the bootcamp task
        self._initialize_bootcamp()

        # Generate some test problems to verify setup
        try:
            for i in range(3):
                identity = self.bootcamp_instance.case_generator()
                prompt = self.bootcamp_instance.prompt_func(identity)
                logger.info(f"Test problem {i+1}: {prompt[:100]}...")
        except Exception as e:
            logger.error(f"Failed to generate test problems: {e}")
            raise
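
    # The bootcamp object is assumed to follow the InternBootcamp interface used
    # throughout this file: case_generator() returns a task-specific "identity"
    # (the ground truth), prompt_func(identity) renders it as a problem statement,
    # and verify_score(response, identity, ...) grades a model response.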

    def _initialize_bootcamp(self):
        """Initialize the bootcamp instance based on the configured task name."""
        try:
            # Create a bootcamp instance using the registry
            self.bootcamp_instance = create_bootcamp(
                self.config.task_name, **self.config.task_params
            )
            logger.info(
                f"Initialized {self.config.task_name} with params: {self.config.task_params}"
            )
        except ValueError as e:
            # If the task is not found, list available tasks
            available = get_available_bootcamps()
            logger.error(f"Task '{self.config.task_name}' not found!")
            logger.error(f"Available tasks (showing first 20): {available[:20]}")
            raise e
        except Exception as e:
            logger.error(f"Failed to initialize bootcamp: {e}")
            raise

    async def get_next_item(self) -> Tuple[Any, Dict]:
        """Get the next problem from the bootcamp."""
        # Generate a new problem
        identity = self.bootcamp_instance.case_generator()
        prompt = self.bootcamp_instance.prompt_func(identity)

        # Log which bootcamp was chosen when running in RandomTask mode
        if (
            self.config.task_name == "RandomTask"
            and isinstance(identity, dict)
            and "_bootcamp_name" in identity
        ):
            logger.info(f"RandomTask selected: {identity['_bootcamp_name']}")

        # Create the message format expected by Atropos
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": prompt},
        ]

        # Return the item with the metadata needed for scoring
        return (
            messages,
            {
                "identity": identity,
                "task_name": self.current_task_name,
                "raw_prompt": prompt,
            },
        )
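
    # A returned item looks roughly like this (the identity dict is task-specific
    # and treated as opaque here; for RandomTask it may carry "_bootcamp_name"):
    #   ([{"role": "system", ...}, {"role": "user", ...}],
    #    {"identity": {...}, "task_name": "RandomTask", "raw_prompt": "..."})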

    async def collect_trajectories(self, item) -> Tuple[List, List]:
        """Collect and score trajectories for the current item."""
        messages, metadata = item
        logger.info(f"Collecting trajectories for item: {messages}")

        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            # Get a group of completions from the model using chat_completion
            completions = await managed.chat_completion(
                messages=messages,
                n=self.config.group_size,
                max_tokens=self.config.max_token_length,
                temperature=self.config.temperature,
                top_p=self.config.top_p,
            )

            state = managed.get_state()
            nodes = state["nodes"]

        to_score = []

        for i, completion in enumerate(completions.choices):
            model_response = completion.message.content

            # Create the full conversation for scoring
            full_messages = messages + [
                {"role": "assistant", "content": model_response}
            ]

            to_score.append(
                {
                    "messages": full_messages,
                    "metadata": metadata,
                    "model_response": model_response,
                    "tokens": nodes[i].tokens,
                    "masks": nodes[i].masked_tokens,
                    "logprobs": nodes[i].logprobs,
                }
            )

        # Score the trajectories immediately and return a ScoredDataGroup
        scored_data = await self.score(to_score)
        backlog = []  # No backlog items for now

        return scored_data, backlog
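
    # Each call samples config.group_size completions for a single problem and
    # scores them together as one ScoredDataGroup, so every trajectory in the
    # group shares the same prompt and identity.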

    async def score(self, rollout_group_data) -> ScoredDataGroup:
        """Score the collected trajectories using bootcamp verification."""
        scored_data = ScoredDataGroup()
        scored_data["tokens"] = []
        scored_data["masks"] = []
        scored_data["scores"] = []
        scored_data["messages"] = []
        scored_data["inference_logprobs"] = []

        for rollout_item in rollout_group_data:
            messages = rollout_item["messages"]
            metadata = rollout_item["metadata"]
            model_response = rollout_item["model_response"]
            tokens = rollout_item["tokens"]
            masks = rollout_item["masks"]
            logprobs = rollout_item["logprobs"]

            # Verify the response using the bootcamp
            identity = metadata["identity"]

            # Calculate the base score from bootcamp verification. verify_score
            # is assumed to return >= 1.0 for a correct answer, a small positive
            # format bonus for a well-formatted but incorrect answer, and <= 0
            # otherwise; the branches below map that onto the configured rewards.
            base_score = self.bootcamp_instance.verify_score(
                model_response,
                identity,
                format_score=self.config.format_bonus,
                short_penalty=self.config.require_reasoning,
                short_threshold=self.config.min_reasoning_length,
            )

            # Apply reward scaling
            if base_score >= 1.0:
                # Correct answer with correct format
                final_score = self.config.correct_reward
                self.task_correct_buffer.append(1)
                self.format_correct_buffer.append(1)
            elif base_score > 0:
                # Correct format but wrong answer
                final_score = self.config.incorrect_reward + base_score
                self.task_correct_buffer.append(0)
                self.format_correct_buffer.append(1)
            else:
                # Wrong answer and/or wrong format
                final_score = self.config.incorrect_reward
                self.task_correct_buffer.append(0)
                self.format_correct_buffer.append(0)

            # Log the scoring details
            logger.debug(
                f"Scored response: base_score={base_score}, "
                f"final_score={final_score}, "
                f"identity={identity}"
            )

            scored_data["tokens"].append(tokens)
            scored_data["masks"].append(masks)
            scored_data["inference_logprobs"].append(logprobs)
            scored_data["scores"].append(final_score)
            scored_data["messages"].append(messages)

        return scored_data
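
    # With the default config values (correct_reward=1.0, incorrect_reward=-0.5,
    # format_bonus=0.2) the mapping works out to: correct -> +1.0;
    # well-formatted but wrong -> -0.5 + 0.2 = -0.3; no usable format -> -0.5.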

    async def evaluate(self, *args, **kwargs):
        """Evaluate the model on freshly generated test problems."""
        logger.info(f"Starting evaluation for {self.current_task_name}")

        eval_tasks = []
        num_eval_problems = 20  # Number of problems to evaluate on

        # Generate evaluation problems
        for _ in range(num_eval_problems):
            eval_tasks.append(self.evaluate_single_problem())

        # Run evaluations in parallel
        results = await asyncio.gather(*eval_tasks)

        # Calculate metrics
        correct_count = sum(1 for is_correct, _ in results if is_correct)
        format_count = sum(1 for _, has_format in results if has_format)
        total_count = len(results)

        accuracy = correct_count / total_count if total_count > 0 else 0
        format_rate = format_count / total_count if total_count > 0 else 0

        logger.info(
            f"Evaluation complete: accuracy={accuracy:.2%}, "
            f"format_rate={format_rate:.2%} "
            f"({correct_count}/{total_count} correct)"
        )

        # Store metrics for wandb logging
        self.eval_metrics.append((f"eval/{self.current_task_name}_accuracy", accuracy))
        self.eval_metrics.append(
            (f"eval/{self.current_task_name}_format_rate", format_rate)
        )
        self.eval_metrics.append(("eval/overall_accuracy", accuracy))

        return self.eval_metrics

    async def evaluate_single_problem(self) -> Tuple[bool, bool]:
        """Evaluate a single problem; returns (is_correct, has_format)."""
        try:
            # Generate a problem
            identity = self.bootcamp_instance.case_generator()
            prompt = self.bootcamp_instance.prompt_func(identity)

            # Create messages
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": prompt},
            ]

            async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
                # Get a single model response using chat_completion
                completion = await managed.chat_completion(
                    messages=messages,
                    n=1,
                    max_tokens=self.config.max_token_length,
                    temperature=0.0,  # Deterministic for evaluation
                    top_p=1.0,
                    split="eval",
                )

            model_response = completion.choices[0].message.content

            # Score the response
            score = self.bootcamp_instance.verify_score(
                model_response,
                identity,
                format_score=self.config.format_bonus,
                short_penalty=False,  # Don't penalize short responses in eval
            )

            is_correct = score >= 1.0
            has_format = score > 0

            return is_correct, has_format

        except Exception as e:
            logger.error(f"Error evaluating problem: {e}")
            return False, False

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Log metrics to wandb."""
        if wandb_metrics is None:
            wandb_metrics = {}

        # Add training metrics
        if self.task_correct_buffer:
            wandb_metrics[f"train/{self.current_task_name}_accuracy"] = sum(
                self.task_correct_buffer
            ) / len(self.task_correct_buffer)

        if self.format_correct_buffer:
            wandb_metrics[f"train/{self.current_task_name}_format_rate"] = sum(
                self.format_correct_buffer
            ) / len(self.format_correct_buffer)

        # Add evaluation metrics
        for metric_name, value in self.eval_metrics:
            wandb_metrics[metric_name] = value

        # Clear buffers
        self.task_correct_buffer = []
        self.format_correct_buffer = []
        self.eval_metrics = []

        await super().wandb_log(wandb_metrics)

    @classmethod
    def config_init(cls) -> Tuple[InternBootcampEnvConfig, List[APIServerConfig]]:
        """Initialize default environment and server configurations."""
        env_config = InternBootcampEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
            group_size=8,
            use_wandb=True,
            max_num_workers_per_node=16,
            rollout_server_url="http://localhost:8000",
            total_steps=10000,
            batch_size=1024,
            steps_per_eval=100,
            max_token_length=8192,
            inference_weight=1.0,
            wandb_name="intern_bootcamp_random_tasks",
            data_path_to_save_groups="data/intern_bootcamp_random_tasks.jsonl",
            # Task configuration
            task_name="RandomTask",
            task_params={},
            # Reward configuration
            correct_reward=1.0,
            incorrect_reward=-0.5,
            format_bonus=0.2,
            # Training parameters
            require_reasoning=True,
            min_batch_allocation=0.1,
            min_reasoning_length=50,
            temperature=0.7,
            top_p=0.9,
        )

        server_configs = [
            APIServerConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
                base_url="http://localhost:9004/v1",
                api_key="x",
                num_requests_for_eval=64,
            )
        ]

        return env_config, server_configs
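
# Typical invocations via BaseEnv.cli (a sketch; subcommand names follow the
# usual atroposlib convention and may differ by version, and the relative
# import above means the module must be run as part of its package):
#   python -m <this module> serve    # serve rollouts to a trainer
#   python -m <this module> process  # write rollout groups to JSONL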


if __name__ == "__main__":
    InternBootcampEnv.cli()