"""
MT-Bench Evaluation Environment for Atropos (Generative Mode with LLM Judge)
This environment evaluates models on MT-Bench - a multi-turn conversational
benchmark for evaluating language models.
Dataset: lighteval/mt-bench
MT-Bench consists of 80 high-quality multi-turn questions across 8 categories:
- Writing, Roleplay, Reasoning, Math, Coding, Extraction, STEM, Humanities
Model responses are evaluated by an LLM judge on a 1-5 scale.
The evaluation follows the refusalbench pattern for LLM judge configuration:
- Separate judge model with configurable endpoint/API key
- Rate limiting and retry logic
- Concurrent call limits with semaphore
- Fallback scoring when judge fails
Supports thinking mode with tags for extended reasoning.
"""
import asyncio
import os
import random
import re
from typing import Dict, List, Optional, Tuple
import wandb
from datasets import load_dataset
from eval_helpers import (
create_system_content,
extract_thinking_content,
get_default_thinking_prompt,
save_eval_results,
validate_thinking_format,
)
from pydantic import Field
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
)
# MT-Bench categories
MT_BENCH_CATEGORIES = [
"writing",
"roleplay",
"reasoning",
"math",
"coding",
"extraction",
"stem",
"humanities",
]
# Judge prompt templates (from lighteval)
def judge_prompt_with_reference(
question: str, answer: str, reference: Optional[str] = None
) -> List[Dict]:
"""Create judge prompt with optional reference answer."""
reference_text = (
f"""the reference answer is:
{reference}"""
if reference
else ""
)
return [
{
"role": "user",
"content": f"""# GOAL
Your job is to evaluate a task carried out by an AI system powered by a large language model.
You will be provided with the inputs and output of the task, as well as the evaluation criteria and scoring rubric. Your task is to evaluate the output of the AI system based on the evaluation criteria and scoring rubric provided.
# INPUT
Below are the inputs required for performing the task: # noqa: E501
{question}
# OUTPUT
Below is the output of the task:
# EVALUATION CRITERIA AND SCORING RUBRIC
Here are the evaluation criteria and the rubric that you need to use for evaluating the task:
How well the response answers the question?{' ' + reference_text if reference_text else ''}
- Score 1: The response completely fails to answer the question.
- Score 2: The response barely answers the question.
- Score 3: The response partially answers the question.
- Score 4: The response mostly answers the question.
- Score 5: The response completely answers the question.
# INSTRUCTIONS FOR THE EVALUATION
1. Understand the task and criteria: Familiarize yourself with the task to be evaluated. Review the evaluation criteria and scoring rubric to understand the different levels of performance and the descriptions for each score.
2. Review the inputs and output: Look at the inputs provided for the task. Examine the output generated from completing the task.
3. Compare output to score descriptions: Compare the output against the criteria and score descriptions in the scoring rubric. For each criterion, decide which description best matches the output.
4. After comparing the output to the score descriptions, pay attention to the small details that might impact the final score that you assign. Sometimes a small difference can dictate the final score. # noqa: E501
5. Write verbal feedback justifying your evaluation that includes a detailed rationale, referring to specific aspects of the output and comparing them to the rubric. # noqa: E501
6. Assign a final score based on the scoring rubric. # noqa: E501
# noqa: E501
## FORMAT FOR THE EVALUATION # noqa: E501
- Write the verbal feedback inside tags without any additional surrounding text.
- Write the numeric score inside tags, without any additional surrounding text and always after the feedback.
Please accurately evaluate the task. Strictly adhere to the evaluation criteria and rubric.""",
}
]
class MTBenchEvalConfig(BaseEnvConfig):
"""Configuration for MT-Bench evaluation environment with LLM judge."""
# Dataset configuration
dataset_name: str = Field(
default="lighteval/mt-bench", description="HuggingFace dataset name"
)
subset: str = Field(default="default", description="Dataset subset")
eval_split: str = Field(default="train", description="Split to evaluate on")
categories: Optional[List[str]] = Field(
default=None,
description="List of categories to evaluate (None = all categories)",
)
shuffle_seed: int = Field(default=42, description="Random seed for shuffling")
# Model generation parameters
eval_temperature: float = Field(
default=0.6, description="Temperature for model evaluation completions"
)
eval_max_tokens: int = Field(
default=0, description="Max tokens for model evaluation (0 = use model default)"
)
# System prompt configuration
custom_system_prompt: Optional[str] = Field(
default=None, description="Optional custom system prompt"
)
# Thinking mode configuration
thinking_mode: bool = Field(
default=True,
description="Whether to use thinking mode with tags",
)
custom_thinking_prompt: Optional[str] = Field(
default=None, description="Optional custom thinking prompt"
)
# Judge model configuration (following refusalbench pattern)
judge_model_name: str = Field(
default="gpt-4o", description="Model name for the judge"
)
judge_base_url: str = Field(
default="https://api.openai.com/v1",
description="Base URL for the judge model API",
)
judge_api_key_env: str = Field(
default="OPENAI_API_KEY",
description="Environment variable name containing the API key for the judge model",
)
judge_temperature: float = Field(
default=0.2, description="Temperature for judge completions"
)
judge_max_tokens: int = Field(
default=0,
description="Maximum tokens for judge completions (0 = use model default)",
)
# Judge retry configuration
judge_max_retries: int = Field(
default=3,
ge=1,
description="Maximum number of retries for failed judge API calls",
)
judge_retry_multiplier: float = Field(
default=1.0,
ge=0.0,
description="Exponential backoff multiplier for judge retries",
)
judge_retry_max_wait: int = Field(
default=10, ge=1, description="Maximum wait time in seconds for judge retries"
)
# Rate limiting configuration
judge_max_concurrent_calls: int = Field(
default=5, ge=1, description="Maximum number of concurrent judge API calls"
)
judge_rate_limit_delay: float = Field(
default=0.5,
ge=0.0,
description="Minimum delay in seconds between judge API calls",
)
# Fallback configuration
use_fallback_scoring: bool = Field(
default=True, description="Use fallback scoring (score=0) when judge API fails"
)
# Retry and debug configuration
max_retries: int = Field(
default=3, description="Maximum retries for failed model API calls"
)
retry_delay: float = Field(
default=1.0, description="Delay between retries in seconds"
)
min_response_length: int = Field(
default=1, description="Minimum response length to consider valid"
)
full_debug: bool = Field(default=False, description="Enable full debug output")
# Override defaults
group_size: int = 1
max_num_workers: int = 256
max_eval_workers: int = 64
max_num_workers_per_node: int = 32
use_wandb: bool = True
rollout_server_url: str = "http://localhost:8000"
total_steps: int = 1
wandb_name: str = "mtbench_eval"
steps_per_eval: int = 1
class MTBenchEvalEnv(BaseEnv):
"""
MT-Bench Evaluation Environment with LLM Judge.
Evaluates multi-turn conversational ability using MT-Bench dataset.
Uses an LLM judge to score responses on a 1-5 scale.
"""
name = "mtbench_eval"
def __init__(
self,
config: MTBenchEvalConfig,
server_configs: List[APIServerConfig],
slurm_job_id: Optional[str] = None,
testing: bool = False,
):
super().__init__(config, server_configs, slurm_job_id, testing)
self.config: MTBenchEvalConfig = config
self.eval_items: List[Dict] = []
self._dataset_loaded = False
# Setup judge client
self.judge_client = None
self._setup_judge_client()
# Rate limiting for judge calls
self.judge_semaphore = asyncio.Semaphore(self.config.judge_max_concurrent_calls)
# Pre-compile regex for score extraction
self._score_pattern = re.compile(r"\s*(\d)\s*", re.IGNORECASE)
# Thread-safe metrics tracking
self._metrics_lock = asyncio.Lock()
self.judge_error_count = 0
self.fallback_count = 0
@classmethod
def config_cls(cls) -> type:
return MTBenchEvalConfig
def _setup_judge_client(self):
"""Setup the judge API client (following refusalbench pattern)."""
try:
import openai
api_key = os.getenv(self.config.judge_api_key_env)
if not api_key:
raise ValueError(
f"API key not found in environment variable: {self.config.judge_api_key_env}"
)
self.judge_client = openai.AsyncOpenAI(
api_key=api_key,
base_url=self.config.judge_base_url,
)
except ImportError:
raise ImportError(
"OpenAI package is required for judge functionality. Install with: pip install openai"
)
async def setup(self) -> None:
"""Initialize the environment and load the dataset."""
await super().setup()
if not self._dataset_loaded:
await self._load_dataset()
print("\nMT-Bench Evaluation Setup (Multi-Turn with LLM Judge):")
print(f" Dataset: {self.config.dataset_name}")
print(f" Categories: {self.config.categories or 'all'}")
print(f" Evaluation split: {self.config.eval_split}")
print(f" Thinking mode: {self.config.thinking_mode}")
print(f" Judge model: {self.config.judge_model_name}")
print(f" Judge endpoint: {self.config.judge_base_url}")
if self.config.thinking_mode:
thinking_prompt = get_default_thinking_prompt(
self.config.custom_thinking_prompt
)
print(f" Thinking prompt: {thinking_prompt[:80]}...")
print(f" Loaded {len(self.eval_items)} evaluation items")
async def _load_dataset(self) -> None:
"""Load and process the MT-Bench dataset."""
print(f"Loading MT-Bench dataset: {self.config.dataset_name}...")
try:
dataset = load_dataset(
self.config.dataset_name,
self.config.subset if self.config.subset != "default" else None,
trust_remote_code=True,
)
except Exception as e:
print(f"Error loading dataset: {e}")
raise
if self.config.eval_split not in dataset:
available_splits = list(dataset.keys())
raise ValueError(
f"Split '{self.config.eval_split}' not found. Available: {available_splits}"
)
split_data = dataset[self.config.eval_split]
# Process items
self.eval_items = []
for idx, item in enumerate(split_data):
category = item.get("category", "unknown")
# Filter by categories if specified
if self.config.categories and category.lower() not in [
c.lower() for c in self.config.categories
]:
continue
turns = item.get("turns", [])
if len(turns) < 2:
print(f" Warning: Item {idx} has fewer than 2 turns, skipping")
continue
reference = item.get("reference", [])
question_id = item.get("question_id", idx)
self.eval_items.append(
{
"id": question_id,
"category": category,
"turns": turns, # List of 2 turn prompts
"reference": reference, # Optional reference answers
}
)
# Shuffle with seed
random.seed(self.config.shuffle_seed)
random.shuffle(self.eval_items)
self._dataset_loaded = True
print(f"Loaded {len(self.eval_items)} evaluation items")
# Show category distribution
category_counts = {}
for item in self.eval_items:
cat = item["category"]
category_counts[cat] = category_counts.get(cat, 0) + 1
print(" Category distribution:")
for cat, count in sorted(category_counts.items()):
print(f" {cat}: {count}")
def _create_system_content(self) -> str:
"""Create system message content based on thinking mode."""
return (
create_system_content(
self.config.thinking_mode,
self.config.custom_thinking_prompt,
self.config.custom_system_prompt,
)
or "You are a helpful assistant."
)
async def _rate_limited_judge_call(self, messages: List[Dict]) -> Optional[str]:
"""Make a rate-limited API call to the judge model with retry logic."""
retry_decorator = retry(
stop=stop_after_attempt(self.config.judge_max_retries),
wait=wait_random_exponential(
multiplier=self.config.judge_retry_multiplier,
max=self.config.judge_retry_max_wait,
),
retry=retry_if_exception_type((Exception,)),
)
async def _inner_judge_call():
async with self.judge_semaphore:
await asyncio.sleep(self.config.judge_rate_limit_delay)
return await self._judge_api_call_raw(messages)
retrying_call = retry_decorator(_inner_judge_call)
return await retrying_call()
async def _judge_api_call_raw(self, messages: List[Dict]) -> Optional[str]:
"""Make a raw API call to the judge model."""
try:
kwargs = {
"model": self.config.judge_model_name,
"messages": messages,
"temperature": self.config.judge_temperature,
}
if self.config.judge_max_tokens > 0:
kwargs["max_tokens"] = self.config.judge_max_tokens
result = await self.judge_client.chat.completions.create(**kwargs)
if result.choices and result.choices[0].message.content:
return result.choices[0].message.content
return None
except Exception as e:
if self.config.full_debug:
print(f" Judge API error: {e}")
raise
def _parse_judge_score(self, judgment: str) -> int:
"""Parse the judge's score from the response."""
match = self._score_pattern.search(judgment)
if match:
return int(match.group(1))
return 0 # Fallback score
async def _judge_response(
self, question: str, answer: str, reference: Optional[str] = None
) -> Tuple[int, str]:
"""
Judge a model response using the LLM judge.
Returns:
Tuple of (score: 1-5, judgment: str)
"""
messages = judge_prompt_with_reference(question, answer, reference)
try:
judgment = await self._rate_limited_judge_call(messages)
if not judgment:
async with self._metrics_lock:
self.judge_error_count += 1
if self.config.use_fallback_scoring:
async with self._metrics_lock:
self.fallback_count += 1
return 0, "JUDGE_ERROR_EMPTY_RESPONSE"
return 0, "JUDGE_ERROR_EMPTY_RESPONSE"
score = self._parse_judge_score(judgment)
return score, judgment
except Exception as e:
async with self._metrics_lock:
self.judge_error_count += 1
if self.config.use_fallback_scoring:
async with self._metrics_lock:
self.fallback_count += 1
return 0, f"JUDGE_ERROR: {str(e)}"
return 0, f"JUDGE_ERROR: {str(e)}"
async def rollout_and_score_eval(
self,
item: Dict,
server: APIServerConfig,
) -> Optional[Dict]:
"""Run multi-turn evaluation on a single item and return the result."""
turns = item["turns"]
references = item.get("reference", [])
system_content = self._create_system_content()
# Initialize conversation
messages = []
if system_content:
messages.append({"role": "system", "content": system_content})
# Store responses and scores for each turn
turn_responses = []
turn_scores = []
turn_judgments = []
turn_valid_formats = []
# Build API call parameters
kwargs = {
"model": server.model_name,
"temperature": self.config.eval_temperature,
}
if self.config.eval_max_tokens > 0:
kwargs["max_tokens"] = self.config.eval_max_tokens
# Process each turn
for turn_idx, turn_prompt in enumerate(turns):
# Add user message
messages.append({"role": "user", "content": turn_prompt})
# Get model response
response_text = ""
for attempt in range(self.config.max_retries):
try:
response = await self.server.chat_completion(
messages=messages, **kwargs
)
response_text = response.choices[0].message.content or ""
if len(response_text) >= self.config.min_response_length:
break
except Exception as e:
if self.config.full_debug:
print(
f" API error turn {turn_idx + 1} (attempt {attempt + 1}): {e}"
)
if attempt < self.config.max_retries - 1:
await asyncio.sleep(self.config.retry_delay)
continue
if not response_text:
return None
# Validate thinking format and extract actual response
is_valid_format, extracted_response = validate_thinking_format(
response_text, self.config.thinking_mode
)
turn_valid_formats.append(is_valid_format)
# Add assistant response to conversation
messages.append({"role": "assistant", "content": response_text})
turn_responses.append(response_text)
# Build context for judging (include conversation history for turn 2)
if turn_idx == 0:
judge_question = turn_prompt
else:
# For turn 2, include context from turn 1
judge_question = f"Context from previous turn:\nUser: {turns[0]}\nAssistant: {turn_responses[0]}\n\nCurrent turn:\nUser: {turn_prompt}" # noqa: E501
# Get reference for this turn if available
turn_reference = (
references[turn_idx] if turn_idx < len(references) else None
)
# Judge the response (use extracted response without think tags)
score, judgment = await self._judge_response(
judge_question, extracted_response, turn_reference
)
turn_scores.append(score)
turn_judgments.append(judgment)
if self.config.full_debug:
print(f" Turn {turn_idx + 1} score: {score}")
# Extract thinking content if applicable
thinking_contents = []
for resp in turn_responses:
thinking = (
extract_thinking_content(resp) if self.config.thinking_mode else None
)
thinking_contents.append(thinking)
return {
"item_id": item["id"],
"category": item["category"],
"turns": turns,
"responses": turn_responses,
"extracted_responses": [
validate_thinking_format(r, self.config.thinking_mode)[1]
for r in turn_responses
],
"scores": turn_scores,
"judgments": turn_judgments,
"references": references,
"format_valid": turn_valid_formats,
"thinking_contents": thinking_contents,
"avg_score": sum(turn_scores) / len(turn_scores) if turn_scores else 0,
"score_turn_1": turn_scores[0] if len(turn_scores) > 0 else 0,
"score_turn_2": turn_scores[1] if len(turn_scores) > 1 else 0,
}
async def evaluate(self, *args, **kwargs) -> Dict:
"""Run the full MT-Bench evaluation."""
print(f"\n{'='*60}")
print("Starting MT-Bench Evaluation (Multi-Turn with LLM Judge)")
print(f"{'='*60}")
print(f" Total questions: {len(self.eval_items)}")
print(f" Judge model: {self.config.judge_model_name}")
print(f" Thinking mode: {self.config.thinking_mode}")
print(f"{'='*60}\n")
# Reset metrics
self.judge_error_count = 0
self.fallback_count = 0
# Create evaluation tasks
async def eval_task(item):
return await self.rollout_and_score_eval(item, self.server_configs[0])
tasks = [eval_task(item) for item in self.eval_items]
# Run with progress bar
results = await tqdm_asyncio.gather(*tasks, desc="Evaluating MT-Bench")
# Filter out failed results
valid_results = [r for r in results if r is not None]
if not valid_results:
print("Warning: No valid evaluation results obtained")
return {"error": "No valid results", "avg_score": 0.0}
# Calculate overall metrics
total = len(valid_results)
avg_score = sum(r["avg_score"] for r in valid_results) / total
avg_turn_1 = sum(r["score_turn_1"] for r in valid_results) / total
avg_turn_2 = sum(r["score_turn_2"] for r in valid_results) / total
# Calculate per-category metrics
category_metrics = {}
for r in valid_results:
cat = r["category"]
if cat not in category_metrics:
category_metrics[cat] = {"scores": [], "turn_1": [], "turn_2": []}
category_metrics[cat]["scores"].append(r["avg_score"])
category_metrics[cat]["turn_1"].append(r["score_turn_1"])
category_metrics[cat]["turn_2"].append(r["score_turn_2"])
for cat in category_metrics:
scores = category_metrics[cat]["scores"]
t1 = category_metrics[cat]["turn_1"]
t2 = category_metrics[cat]["turn_2"]
category_metrics[cat]["avg_score"] = (
sum(scores) / len(scores) if scores else 0
)
category_metrics[cat]["avg_turn_1"] = sum(t1) / len(t1) if t1 else 0
category_metrics[cat]["avg_turn_2"] = sum(t2) / len(t2) if t2 else 0
category_metrics[cat]["count"] = len(scores)
# Format compliance
format_valid_t1 = sum(1 for r in valid_results if r["format_valid"][0])
format_valid_t2 = sum(
1
for r in valid_results
if len(r["format_valid"]) > 1 and r["format_valid"][1]
)
metrics = {
"avg_score": avg_score,
"avg_score_turn_1": avg_turn_1,
"avg_score_turn_2": avg_turn_2,
"total_evaluated": total,
"format_compliance_turn_1": format_valid_t1 / total if total > 0 else 0.0,
"format_compliance_turn_2": format_valid_t2 / total if total > 0 else 0.0,
"judge_error_rate": (
self.judge_error_count / (total * 2) if total > 0 else 0.0
),
"fallback_rate": self.fallback_count / (total * 2) if total > 0 else 0.0,
"category_metrics": category_metrics,
}
print(f"\n{'='*60}")
print("MT-Bench Evaluation Results")
print(f"{'='*60}")
print(f" Overall Average Score: {avg_score:.2f}/5.0")
print(f" Turn 1 Average: {avg_turn_1:.2f}/5.0")
print(f" Turn 2 Average: {avg_turn_2:.2f}/5.0")
print(f" Total Evaluated: {total}")
if self.config.thinking_mode:
print(f" Format Compliance (T1): {format_valid_t1 / total:.2%}")
print(f" Format Compliance (T2): {format_valid_t2 / total:.2%}")
print("\n Per-Category Breakdown:")
for cat, data in sorted(
category_metrics.items(), key=lambda x: -x[1]["avg_score"]
):
print(
f" {cat}: {data['avg_score']:.2f} (T1: {data['avg_turn_1']:.2f}, T2: {data['avg_turn_2']:.2f}) [{data['count']} items]" # noqa: E501
)
print(f"{'='*60}\n")
# Save results
if self.config.data_dir_to_save_evals:
self._save_results(metrics, valid_results)
return metrics
def _save_results(self, metrics: Dict, results: List[Dict]) -> None:
"""Save evaluation results to disk."""
save_eval_results(self.config.data_dir_to_save_evals, metrics, results)
async def wandb_log(self, metrics: Dict, step: int = 0) -> None:
"""Log metrics to Weights & Biases."""
if not self.config.use_wandb:
return
log_dict = {
"mtbench/avg_score": metrics.get("avg_score", 0),
"mtbench/avg_score_turn_1": metrics.get("avg_score_turn_1", 0),
"mtbench/avg_score_turn_2": metrics.get("avg_score_turn_2", 0),
"mtbench/total_evaluated": metrics.get("total_evaluated", 0),
"mtbench/format_compliance_turn_1": metrics.get(
"format_compliance_turn_1", 0
),
"mtbench/format_compliance_turn_2": metrics.get(
"format_compliance_turn_2", 0
),
"mtbench/judge_error_rate": metrics.get("judge_error_rate", 0),
}
# Log per-category scores
for cat, data in metrics.get("category_metrics", {}).items():
safe_name = cat.replace(" ", "_")[:20]
log_dict[f"mtbench/score_{safe_name}"] = data.get("avg_score", 0)
wandb.log(log_dict, step=step)
# Required abstract method implementations
async def get_next_item(self) -> Optional[Dict]:
"""Not used in evaluation mode."""
return None
async def collect_trajectories(self, item) -> List:
"""Not used in evaluation mode."""
return []
async def score(self, rollout_group_data) -> Optional[List]:
"""Not used in evaluation mode."""
return None
if __name__ == "__main__":
MTBenchEvalEnv.cli()