"""Multiple-choice question answering (MCQA) environment with explicit reasoning.

The model is asked to reason inside ``<think> ... </think>`` tags and then to
state its final choice as a single letter.  Training rollouts are scored
``1.0`` for a well-formed, correct answer and ``-1.0`` otherwise.
"""

import random
import re
from typing import Dict, List, Optional, Tuple, Union

import wandb
from datasets import load_dataset
from tqdm.asyncio import tqdm_asyncio

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    EvalHandlingEnum,
    Item,
    ScoredDataGroup,
)

system_prompt = (
    "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
    "problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
    "solution prior to answering. You should enclose your thoughts and internal monologue inside <think> "
    "</think> tags, and then provide your solution or response to the problem."
)

# Generation budget passed as ``max_tokens`` for both training and eval rollouts.
_MAX_GENERATION_TOKENS = 1024 * 15

# Instruction appended to every question so the answer can be parsed reliably.
_ANSWER_INSTRUCTION = (
    '\n\nProvide your answer by saying "The best answer is: {Answer}"'
)

# Reasoning delimiters; matched case-insensitively.
_THINK_OPEN_RE = re.compile(r"<think>", re.IGNORECASE)
_THINK_CLOSE_RE = re.compile(r"</think>", re.IGNORECASE)

# Shared regex fragments for the answer patterns below.
_LETTER = r"(A|B|C|D)"
_TRAILING = r"(?:[^A-Da-d]*.*?)?(?=$|\n|\.)"
_LETTER_TAIL = r"(?:\)|\.|:)?" + _TRAILING

# Patterns that capture the chosen letter; compiled once at import time.
_LETTER_ANSWER_PATTERNS = [
    re.compile(source, re.IGNORECASE)
    for source in (
        r"The correct answer is:?\s*(?:\*\*)?" + _LETTER + r"(?:\*\*)?" + _LETTER_TAIL,
        r"The best answer is:?\s*(?:\*\*)?" + _LETTER + r"(?:\*\*)?" + _LETTER_TAIL,
        r"The answer is:?\s*(?:\*\*)?" + _LETTER + r"(?:\*\*)?" + _LETTER_TAIL,
        r"\*\*The best answer is\s*" + _LETTER + r"\*\*" + _LETTER_TAIL,
        r"\*\*The best answer is:\s*" + _LETTER + r"\*\*" + _LETTER_TAIL,
        r"Thus, final answer:\s*" + _LETTER + r"\)" + _LETTER_TAIL,
        r"\\boxed{" + _LETTER + r"}" + _LETTER_TAIL,
    )
]

# Sentence openers that may be followed by the full ground-truth option text.
_TEXT_ANSWER_PREFIXES = (
    "The correct answer is",
    "The best answer is",
    "The answer is",
)


class MCQAThinkingEnv(BaseEnv):
    """Environment that trains and evaluates a model on multiple-choice questions."""

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        """
        Initialize the MCQA (Multiple Choice Question Answering) environment.

        Args:
            config: Configuration for the base environment
            server_configs: List of server configurations for OpenAI API
            slurm: Whether to use Slurm for distributed training
            testing: Whether in testing mode
        """
        super().__init__(config, server_configs, slurm, testing)
        # Per-rollout correctness (1.0 / 0) accumulated between wandb logs.
        self.percent_correct_buffer = []
        # (metric_name, value) pairs produced by evaluate(), flushed in wandb_log().
        self.eval_metrics = []

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        """Return the default environment and inference-server configuration."""
        env_config = BaseEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
            group_size=32,
            use_wandb=True,
            max_num_workers=128,
            rollout_server_url="http://localhost:8000",
            total_steps=2000,
            batch_size=1024,
            steps_per_eval=20,
            max_token_length=1024 * 15,
            inference_weight=1.0,
            wandb_name="mcqa_deep_thinking",
            data_path_to_save_groups=None,
            eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
            eval_limit_ratio=0.1,
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
                base_url="http://localhost:9004/v1",
                api_key="x",
                num_max_requests_at_once=32,
                num_requests_for_eval=256,
            )
        ]
        return env_config, server_configs

    async def setup(self):
        """
        Set up the environment by loading and preparing the dataset.
        """
        # Load the full dataset
        full_dataset = load_dataset(
            "NousResearch/AcademicMCQA", "default", split="train"
        )
        full_dataset = full_dataset.shuffle(seed=42)

        # Create train/test split on the fly (98% train, 2% test)
        split_dataset = full_dataset.train_test_split(test_size=0.02, seed=42)

        # Keep the splits as is - no need to reformat
        self.train = split_dataset["train"]
        self.test = split_dataset["test"]

        # Print some dataset statistics
        print(
            f"Loaded dataset with {len(self.train)} training examples and {len(self.test)} test examples"
        )
        print(f"Example item format: {self.train[0]}")

        # Initialize iteration counter
        self.iter = 0

    def save_checkpoint(self, step, data=None):
        """Persist the dataset cursor alongside whatever the base class saves."""
        if data is None:
            data = {}
        data["iter"] = self.iter
        super().save_checkpoint(step, data)

    async def get_next_item(self):
        """
        Get the next training item from the dataset.

        Returns:
            A 3-tuple ``(prompt, answer_letter, answer_text)`` where ``prompt``
            is a tuple of frozensets (one per chat message), ``answer_letter``
            is the dataset's ``ground_truth`` field and ``answer_text`` is the
            option selected by the dataset's ``answer`` index.
        """
        # Cycle through the training split indefinitely.
        next_item = self.train[self.iter % len(self.train)]
        self.iter += 1

        question = f"{next_item['prompt']}{_ANSWER_INSTRUCTION}"

        # Messages are stored as frozensets of (key, value) pairs so the
        # prompt tuple is hashable; they are turned back into dicts later.
        prompt = (
            frozenset({"role": "system", "content": system_prompt}.items()),
            frozenset({"role": "user", "content": question}.items()),
        )

        answer = next_item["ground_truth"]
        answer_string = next_item["options"][next_item["answer"]]
        return (prompt, answer, answer_string)

    async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, List]:
        """
        Generate and collect model responses for scoring.

        Args:
            item: Input item ``(prompt, answer_letter, answer_text)`` as
                produced by ``get_next_item``

        Returns:
            Tuple of the scored data group (``None`` when ``score`` finds no
            usable signal) and an empty backlog list
        """
        # Rebuild plain message dicts from the hashable prompt representation.
        messages = [dict(role_dict) for role_dict in item[0]]

        # Apply chat template to convert messages to a single string
        prompt = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )

        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            # Uses completion() on the templated string rather than chat_completion()
            completions = await managed.completion(
                prompt=prompt,
                n=self.config.group_size,
                max_tokens=_MAX_GENERATION_TOKENS,
                temperature=1.0,  # Using temperature to get diverse responses
            )
            state = managed.get_state()
            nodes = state["nodes"]

        to_score = []
        for i, completion_choice in enumerate(completions.choices):
            # Each trajectory gets its own copy of the prompt messages.
            trajectory_messages = [dict(role_dict) for role_dict in item[0]]
            trajectory_messages.append(
                {"role": "assistant", "content": completion_choice.text}
            )
            node = nodes[i]
            to_score.append(
                {
                    "messages": tuple(trajectory_messages),
                    "expected_answer": item[1],  # Letter (A, B, C, D)
                    "ground_truth_text": item[2],  # Full text of the right option
                    "finish_reason": completion_choice.finish_reason,
                    "tokens": node.tokens,
                    "masks": node.masked_tokens,
                    "logprobs": node.logprobs,
                }
            )

        scored_data = await self.score(to_score)
        to_backlog = []
        return scored_data, to_backlog

    def _extract_mcqa_answer(self, text, ground_truth_text, ground_truth_letter):
        """
        Extract the multiple choice answer (A, B, C, or D) from model response.

        The response must contain exactly one ``<think>`` and exactly one
        ``</think>`` tag, with the opening tag before the closing one.  Only
        the text after ``</think>`` is searched for an answer, and exactly one
        pattern match must be found there.

        Args:
            text: Text containing the model's response
            ground_truth_text: The full text of the correct answer
            ground_truth_letter: The letter (A, B, C, D) of the correct answer

        Returns:
            Extracted answer letter (upper-cased), or None if the response is
            malformed or does not contain exactly one recognised answer
        """
        if not text:
            return None

        # Exactly one opening and one closing tag are required.
        if len(_THINK_OPEN_RE.findall(text)) != 1:
            return None
        if len(_THINK_CLOSE_RE.findall(text)) != 1:
            return None

        # One closing tag guarantees exactly two parts here.
        thinking_section, answer_section = _THINK_CLOSE_RE.split(text, maxsplit=1)

        # The single opening tag must precede the closing tag; this also
        # guarantees no <think> tag leaks into the answer section.
        if not _THINK_OPEN_RE.search(thinking_section):
            return None

        # NOTE(review): one answer occurrence that satisfies several patterns
        # (e.g. bold "**The best answer is: A**") is counted once per pattern
        # and therefore rejected by the exactly-one check below - confirm this
        # is intended.
        found_answers = []
        for pattern in _LETTER_ANSWER_PATTERNS:
            found_answers.extend(
                letter.upper() for letter in pattern.findall(answer_section)
            )

        # The model may also quote the full text of the correct option; each
        # such match counts as having answered with the ground-truth letter.
        escaped_truth = re.escape(ground_truth_text)
        for prefix in _TEXT_ANSWER_PREFIXES:
            pattern = (
                prefix
                + r":?\s(?:\*\*)?"
                + escaped_truth
                + r"(?:\*\*)?"
                + _TRAILING
            )
            matches = re.findall(pattern, answer_section, re.IGNORECASE)
            found_answers.extend(ground_truth_letter for _ in matches)

        # Zero answers or more than one answer is treated as invalid.
        if len(found_answers) != 1:
            return None
        return found_answers[0]

    async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
        """
        Score the generated model responses against expected MCQA answers.

        Args:
            rollout_group_data: List of generated responses with expected
                answers (shuffled in place)

        Returns:
            ScoredDataGroup with tokens, masks, logprobs and scores
            (1.0 correct / -1.0 otherwise), or None if the group is empty or
            every kept rollout received the same score
        """
        if not rollout_group_data:
            return None

        scores = ScoredDataGroup()
        scores["tokens"] = []
        scores["masks"] = []
        scores["scores"] = []
        scores["inference_logprobs"] = []

        # All rollouts in a group share the same question, hence same answer.
        expected_answer = rollout_group_data[0]["expected_answer"]
        ground_truth_text = rollout_group_data[0]["ground_truth_text"]

        # Shuffle to avoid bias in selection
        random.shuffle(rollout_group_data)

        for item in rollout_group_data:
            masks = item["masks"]

            # Remove examples with insufficient context
            if sum(1 for m in masks if m != -100) < 10:
                continue

            if item["finish_reason"] == "length":
                # Response was cut off before finishing: counts as incorrect.
                is_correct = False
            else:
                model_answer = self._extract_mcqa_answer(
                    item["messages"][-1]["content"],
                    ground_truth_text,
                    expected_answer,
                )
                # Invalid format (None) and wrong letter are both incorrect.
                is_correct = (
                    model_answer is not None and model_answer == expected_answer
                )

            scores["tokens"].append(item["tokens"])
            scores["masks"].append(masks)
            scores["inference_logprobs"].append(item["logprobs"])
            scores["scores"].append(1.0 if is_correct else -1.0)

            # Break once we have enough examples
            if len(scores["tokens"]) >= self.config.group_size:
                break

        # Record success rate metrics for wandb logging
        self.percent_correct_buffer.extend(max(s, 0) for s in scores["scores"])

        # Return None if nothing was kept or all scores are the same
        # (no learning signal)
        if len(set(scores["scores"])) <= 1:
            return None

        return scores

    async def rollout_and_score_eval(self, test_item):
        """
        Generate and score model responses for a single test item.

        Args:
            test_item: Test item from dataset

        Returns:
            Score (1 for correct, 0 for incorrect)
        """
        expected_answer_letter = test_item["ground_truth"]
        ground_truth_text = test_item["options"][test_item["answer"]]

        messages = [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": f"{test_item['prompt']}{_ANSWER_INSTRUCTION}",
            },
        ]

        # Apply chat template to convert messages to a single string
        prompt = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )

        completion = await self.server.completion(
            prompt=prompt,
            n=1,
            max_tokens=_MAX_GENERATION_TOKENS,
            temperature=0.5,  # Lower for eval
            split="eval",
        )

        model_answer = self._extract_mcqa_answer(
            completion.choices[0].text, ground_truth_text, expected_answer_letter
        )

        # Score 1 if the answers match, 0 otherwise
        return 1 if model_answer and model_answer == expected_answer_letter else 0

    async def evaluate(self, *args, **kwargs):
        """
        Evaluate the model on test data.

        Appends ``("eval/percent_correct", fraction)`` to ``self.eval_metrics``;
        nothing is recorded when the test split is empty.
        """
        eval_tasks = [
            self.rollout_and_score_eval(test_item) for test_item in self.test
        ]

        # Run evaluation
        scores = await tqdm_asyncio.gather(*eval_tasks)
        if not scores:
            # Avoid dividing by zero on an empty test split.
            return
        self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))

    async def add_rollouts_for_wandb(
        self,
        scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
        item: Item = None,
    ):
        """
        Keep a sample of decoded rollouts from this group for wandb logging.

        Args:
            scored_data: Scored group whose "tokens" and "scores" are sampled
            item: Source item, either the ``(prompt, letter, text)`` tuple or a
                mapping with "expected_answer" / "ground_truth_text"; when
                omitted the answer columns are logged as None
        """
        num_keep = self.config.num_rollouts_per_group_for_logging
        if num_keep == -1:
            num_keep = self.config.group_size
        # score() may drop rollouts, so never index past what was kept.
        num_keep = min(num_keep, len(scored_data["tokens"]))

        if isinstance(item, tuple):
            expected_answer, ground_truth_text = item[1], item[2]
        elif item is None:
            expected_answer = ground_truth_text = None
        else:
            expected_answer = item["expected_answer"]
            ground_truth_text = item["ground_truth_text"]

        self.rollouts_for_wandb.append(
            [
                (
                    self.tokenizer.decode(scored_data["tokens"][i]),
                    scored_data["scores"][i],
                    expected_answer,
                    ground_truth_text,
                )
                for i in range(num_keep)
            ]
        )
        # Bound the number of groups retained between logs.
        if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
            self.rollouts_for_wandb.pop(0)

    async def create_rollout_table(self, wandb_metrics):
        """Move buffered rollouts into a wandb table under "train/rollouts"."""
        if len(self.rollouts_for_wandb) > 0:
            table = wandb.Table(columns=["text", "score", "answer", "string_answer"])
            for group in self.rollouts_for_wandb:
                for text, score, answer, string_answer in group:
                    table.add_data(text, score, answer, string_answer)
            wandb_metrics["train/rollouts"] = table
        self.rollouts_for_wandb = []
        return wandb_metrics

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Add train/eval accuracy to the metrics, reset buffers, and log."""
        if wandb_metrics is None:
            wandb_metrics = {}

        # Skip the training accuracy when no rollouts were scored.
        if self.percent_correct_buffer:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)
        self.percent_correct_buffer = []

        for metric_name, metric_value in self.eval_metrics:
            wandb_metrics[metric_name] = metric_value
        self.eval_metrics = []

        await super().wandb_log(wandb_metrics)


if __name__ == "__main__":
    MCQAThinkingEnv.cli()