import random
import re
from typing import Dict, List, Optional, Tuple, Union
import wandb
from datasets import load_dataset
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
Item,
ScoredDataGroup,
)
system_prompt = (
"You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
"problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
"solution prior to answering. You should enclose your thoughts and internal monologue inside "
" tags, and then provide your solution or response to the problem."
)
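# Illustrative response shape this prompt and the extraction logic below expect (hypothetical):
#   "<think> step-by-step reasoning ... </think> The best answer is: B"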
class MCQAThinkingEnv(BaseEnv):
def __init__(
self,
config: BaseEnvConfig,
server_configs: List[APIServerConfig],
slurm=True,
testing=False,
):
"""
Initialize the MCQA (Multiple Choice Question Answering) environment.
Args:
config: Configuration for the base environment
server_configs: List of server configurations for OpenAI API
slurm: Whether to use Slurm for distributed training
testing: Whether in testing mode
"""
super().__init__(config, server_configs, slurm, testing)
self.percent_correct_buffer = list()
self.eval_metrics = list()
@classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
env_config = BaseEnvConfig(
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
group_size=32,
use_wandb=True,
max_num_workers=128,
rollout_server_url="http://localhost:8000",
total_steps=2000,
batch_size=1024,
steps_per_eval=20,
max_token_length=1024 * 15,
inference_weight=1.0,
wandb_name="mcqa_deep_thinking",
data_path_to_save_groups=None,
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
eval_limit_ratio=0.1,
)
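        # base_url below should point at a local OpenAI-compatible inference endpoint
        # (e.g. vLLM or SGLang); adjust host/port and model_name to match your deployment.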
server_configs = [
APIServerConfig(
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
base_url="http://localhost:9004/v1",
api_key="x",
num_max_requests_at_once=32,
num_requests_for_eval=256,
)
]
return env_config, server_configs
async def setup(self):
"""
Set up the environment by loading and preparing the dataset.
"""
# Load the full dataset
full_dataset = load_dataset(
"NousResearch/AcademicMCQA", "default", split="train"
)
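        # Each item is expected to provide the fields used below and in get_next_item:
        # "prompt" (question text), "options" (list of answer choices), "answer" (index of
        # the correct option), and "ground_truth" (its letter, e.g. "B").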
full_dataset = full_dataset.shuffle(seed=42)
        # Create train/test split on the fly (98% train, 2% test)
split_dataset = full_dataset.train_test_split(test_size=0.02, seed=42)
# Keep the splits as is - no need to reformat
self.train = split_dataset["train"]
self.test = split_dataset["test"]
# Print some dataset statistics
print(
f"Loaded dataset with {len(self.train)} training examples and {len(self.test)} test examples"
)
print(f"Example item format: {self.train[0]}")
# Initialize iteration counter
self.iter = 0
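        # self.iter walks sequentially through the training split in get_next_item and is
        # persisted via save_checkpoint below.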
def save_checkpoint(self, step, data=None):
if data is None:
data = {}
data["iter"] = self.iter
super().save_checkpoint(step, data)
async def get_next_item(self):
"""
Get the next training item from the dataset.
Returns:
            A tuple of (prompt messages, expected answer letter, expected answer text)
"""
next_item = self.train[self.iter % len(self.train)]
self.iter += 1
# Extract question and options from the multiple choice item
question_text = next_item["prompt"]
correct_answer_index = next_item["answer"]
ground_truth_letter = next_item["ground_truth"]
options = next_item["options"]
# Append the answer format instruction to the prompt
question_text_with_instruction = f'{question_text}\n\nProvide your answer by saying "The best answer is: {{Answer}}"' # noqa E501
# Create prompt tuple using frozensets as required
prompt = []
# Add system prompt as defined at the top of the script
prompt.append(frozenset({"role": "system", "content": system_prompt}.items()))
# Add user message with the question and instruction
prompt.append(
frozenset(
{"role": "user", "content": question_text_with_instruction}.items()
)
)
# Prepare the expected answer
# We'll use the ground_truth_letter (A, B, C, D) as the expected answer
# The scoring function will need to check if the model response contains this letter
answer = ground_truth_letter
answer_string = options[correct_answer_index]
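        # Illustrative return value (hypothetical data):
        #   ((<system frozenset>, <user frozenset>), "B", "Mitochondria")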
return (tuple(prompt), answer, answer_string)
async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]:
"""
Generate and collect model responses for scoring.
Args:
item: Input item containing prompt and expected answer
Returns:
            Tuple of the scored data group and a backlog list
"""
# Extract messages from the item
messages = []
for role_dict in item[0]:
messages.append(dict(role_dict))
# Apply chat template to convert messages to a single string
prompt = self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False
)
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
# Get completions from the model using completion() instead of chat_completion()
completions = await managed.completion(
prompt=prompt,
n=self.config.group_size,
max_tokens=1024 * 15,
temperature=1.0, # Using temperature to get diverse responses
)
state = managed.get_state()
nodes = state["nodes"]
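            # Each node is expected to carry the tokenized prompt+completion ("tokens"),
            # the loss mask ("masked_tokens"), and per-token logprobs consumed in score().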
to_score = list()
for i, completion_choice in enumerate(completions.choices):
# Create a copy of the prompt messages
trajectory_messages = []
for role_dict in item[0]:
trajectory_messages.append(dict(role_dict))
# Add the model's response
trajectory_messages.append(
{"role": "assistant", "content": completion_choice.text}
)
# Add to scoring queue with expected answer, ground truth text, and stop reason
to_score.append(
{
"messages": tuple(trajectory_messages),
"expected_answer": item[1], # Letter (A, B, C, D)
"ground_truth_text": item[
2
], # Include the answer_string/ground_truth_text
"finish_reason": completion_choice.finish_reason, # Add the stop reason
"tokens": nodes[i].tokens,
"masks": nodes[i].masked_tokens,
"logprobs": nodes[i].logprobs,
}
)
# Call score to get the scored data
scored_data = await self.score(to_score)
to_backlog = []
return scored_data, to_backlog
def _extract_mcqa_answer(self, text, ground_truth_text, ground_truth_letter):
"""
Extract the multiple choice answer (A, B, C, or D) from model response.
Only allows one valid answer format - multiple answer formats result in a score of 0.
Args:
text: Text containing the model's response
ground_truth_text: The full text of the correct answer
ground_truth_letter: The letter (A, B, C, D) of the correct answer
Returns:
Extracted answer letter or None if invalid response pattern is found
"""
        # Require exactly one opening <think> tag
        think_tags = re.findall(r"<think>", text, re.IGNORECASE)
        if len(think_tags) != 1:
            return None
        # Require exactly one closing </think> tag
        think_close_tags = re.findall(r"</think>", text, re.IGNORECASE)
        if len(think_close_tags) != 1:
            return None
        # Split the text into thinking and answer sections at the closing tag
        parts = re.split(r"</think>", text, flags=re.IGNORECASE, maxsplit=1)
        # If the split did not yield exactly a thinking and an answer section, return None
        if len(parts) != 2:
            return None
        thinking_section, answer_section = parts
        # Validate the thinking section: it must contain the opening <think> tag
        if "<think>" not in thinking_section.lower():
            return None  # Malformed thinking section
        # Reject responses that open another <think> tag after the closing </think>
        if "<think>" in answer_section.lower():
            return None
# More flexible answer patterns that handle parentheses and additional text
answer_patterns = [
r"The correct answer is:?\s*(?:\*\*)?(A|B|C|D)(?:\*\*)?(?:\)|\.|:)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", # noqa W605
r"The best answer is:?\s*(?:\*\*)?(A|B|C|D)(?:\*\*)?(?:\)|\.|:)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", # noqa W605
r"The answer is:?\s*(?:\*\*)?(A|B|C|D)(?:\*\*)?(?:\)|\.|:)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", # noqa W605
r"\*\*The best answer is\s*(A|B|C|D)\*\*(?:\)|\.|:)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", # noqa W605
r"\*\*The best answer is:\s*(A|B|C|D)\*\*(?:\)|\.|:)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", # noqa W605
r"Thus, final answer:\s*(A|B|C|D)\)(?:\)|\.|:)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", # noqa W605
r"\\boxed{(A|B|C|D)}(?:\)|\.|:)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)", # noqa W605
]
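        # Illustrative (hypothetical) strings the patterns above accept:
        #   "The best answer is: **C**", "Thus, final answer: D)", "\boxed{A}"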
string_patterns = [
# Patterns to match exact ground truth text, with optional markdown bold formatting
r"The correct answer is:?\s(?:\*\*)?"
+ re.escape(ground_truth_text)
+ r"(?:\*\*)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)",
r"The best answer is:?\s(?:\*\*)?"
+ re.escape(ground_truth_text)
+ r"(?:\*\*)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)",
r"The answer is:?\s(?:\*\*)?"
+ re.escape(ground_truth_text)
+ r"(?:\*\*)?(?:[^A-Da-d]*.*?)?(?=$|\n|\.)",
]
# Track all found answers
found_answers = []
# Check each pattern
for pattern in answer_patterns:
matches = re.findall(pattern, answer_section, re.IGNORECASE)
if matches:
for match in matches:
# Extract just the letter
found_answers.append(match.upper())
for pattern in string_patterns:
matches = re.findall(pattern, answer_section, re.IGNORECASE)
if matches:
# For each match found, append the ground truth letter instead of the full match
for _ in matches:
found_answers.append(ground_truth_letter)
# If no answers found or multiple answers found, return None
if len(found_answers) != 1:
return None
# Return the single found answer
return found_answers[0]
async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
"""
Score the generated model responses against expected MCQA answers.
Args:
rollout_group_data: List of generated responses with expected answers
Returns:
ScoredDataGroup with tokenized inputs and scores, or None if no valid scores
"""
scores = ScoredDataGroup()
scores["tokens"] = list()
scores["masks"] = list()
scores["scores"] = list()
scores["inference_logprobs"] = list()
# Get the expected answer letter
expected_answer = rollout_group_data[0]["expected_answer"] # Letter A, B, C, D
ground_truth_text = rollout_group_data[0]["ground_truth_text"]
# Shuffle to avoid bias in selection
random.shuffle(rollout_group_data)
for item in rollout_group_data:
# Extract the model's response
model_response = item["messages"][-1]["content"]
stop_reason = item["finish_reason"] # Get the stop reason
# If the response was cut off due to length, give it a score of 0
if stop_reason == "length":
reward = 0
else:
# Extract the answer from the model's response
model_answer = self._extract_mcqa_answer(
model_response, ground_truth_text, expected_answer
)
                # Assign reward based on the extracted answer
if model_answer is None:
reward = 0 # Invalid format gets 0 reward
elif model_answer == expected_answer:
reward = 1 # Correct answer gets 1 reward
else:
reward = 0 # Wrong answer gets 0 reward
tokens = item["tokens"]
masks = item["masks"]
logprobs = item["logprobs"]
            # Skip examples with fewer than 10 unmasked (trainable) tokens
if len([1 for i in masks if i != -100]) < 10:
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["inference_logprobs"].append(logprobs)
scores["scores"].append(1.0 if reward else -1.0)
# Break once we have enough examples
if len(scores["tokens"]) >= self.config.group_size:
break
# Record success rate metrics for wandb logging
for score in scores["scores"]:
self.percent_correct_buffer.append(max(score, 0))
# Return None if all scores are the same (no learning signal)
if all(scores["scores"][0] == score for score in scores["scores"]):
return None
return scores
async def rollout_and_score_eval(self, test_item):
"""
Generate and score model responses for a single test item.
Args:
test_item: Test item from dataset
Returns:
Score (1 for correct, 0 for incorrect)
"""
# Extract question and options from the test item
question_text = test_item["prompt"]
correct_answer_index = test_item["answer"]
expected_answer_letter = test_item["ground_truth"]
options = test_item["options"]
# Append the answer format instruction to the prompt
question_text_with_instruction = f'{question_text}\n\nProvide your answer by saying "The best answer is: {{Answer}}"' # noqa E501
# Create messages for model
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": question_text_with_instruction},
]
# Apply chat template to convert messages to a single string
prompt = self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False
)
# Get model completion
completion = await self.server.completion(
prompt=prompt,
n=1,
max_tokens=1024 * 15,
temperature=0.5, # Lower for eval
split="eval",
)
# Extract the model's response from the completion
model_response = completion.choices[0].text
# Extract the answer from the model's response
model_answer = self._extract_mcqa_answer(
model_response, options[correct_answer_index], expected_answer_letter
)
# Score 1 if the answers match, 0 otherwise
score = 1 if model_answer and model_answer == expected_answer_letter else 0
return score
async def evaluate(self, *args, **kwargs):
"""
Evaluate the model on test data.
"""
eval_tasks = []
for test_item in self.test:
eval_tasks.append(self.rollout_and_score_eval(test_item))
# Run evaluation
scores = await tqdm_asyncio.gather(*eval_tasks)
self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))
async def add_rollouts_for_wandb(
self,
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
item: Item = None,
):
        # Save a subset of this group's rollouts for the wandb rollout table
num_keep = self.config.num_rollouts_per_group_for_logging
if num_keep == -1:
num_keep = self.config.group_size
self.rollouts_for_wandb.append(
[
(
self.tokenizer.decode(scored_data["tokens"][i]),
scored_data["scores"][i],
item[1] if isinstance(item, tuple) else item["expected_answer"],
item[2] if isinstance(item, tuple) else item["ground_truth_text"],
)
for i in range(num_keep)
]
)
if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
self.rollouts_for_wandb.pop(0)
async def create_rollout_table(self, wandb_metrics):
if len(self.rollouts_for_wandb) > 0:
table = wandb.Table(columns=["text", "score", "answer", "string_answer"])
for group in self.rollouts_for_wandb:
for item in group:
table.add_data(item[0], item[1], item[2], item[3])
wandb_metrics["train/rollouts"] = table
self.rollouts_for_wandb = []
return wandb_metrics
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
if wandb_metrics is None:
wandb_metrics = {}
# Try to calculate percent_correct, pass if there's a division by zero
try:
wandb_metrics["train/percent_correct"] = sum(
self.percent_correct_buffer
) / len(self.percent_correct_buffer)
except ZeroDivisionError:
# Skip if buffer is empty
pass
self.percent_correct_buffer = list()
for item in self.eval_metrics:
wandb_metrics[item[0]] = item[1]
self.eval_metrics = list()
await super().wandb_log(wandb_metrics)
if __name__ == "__main__":
MCQAThinkingEnv.cli()