"""Math RL environment with normal, RLAIF, judge and self-correction rollouts.

Normal rollouts sample ``group_size`` solutions per problem and score them
against the gold answer with ``math_verify``.  Depending on the outcome the
group is trained on directly and/or follow-up work is pushed to the backlog:

* ``rlaif``       - all solutions correct: the model ranks the two most
                    dissimilar ones against each other.
* ``judge``       - mixed results: the model verifies one correct and one
                    incorrect solution.
* ``selfcorrect`` - retries produced while scoring a judge verdict.
"""

import asyncio
import math
import random
from concurrent.futures import ProcessPoolExecutor
from difflib import SequenceMatcher
from typing import Dict, List, Optional, Tuple

import wandb
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from math_verify.errors import TimeoutException
from pydantic import Field
from tqdm.asyncio import tqdm_asyncio

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    EvalHandlingEnum,
    ScoredDataGroup,
)

# Closing tag of the reasoning section requested by ``system_prompt``.
# Everything after the last occurrence is treated as the visible answer.
THINK_END_TAG = "</think>"

system_prompt = (
    "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
    "problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
    "solution prior to answering. You should enclose your thoughts and internal monologue inside "
    "<think> </think> tags, and then provide your solution or response to the problem."
)

problem_format = "{problem}"

judge_format = (
    "Here is a math problem and a proposed solution: "
    "[START PROBLEM] {problem} [END PROBLEM] "
    "[START SOLUTION] {solution} [END SOLUTION] "
    "Please verify if it is correct or not. "
    "If it's correct submit your answer in your response with \\boxed{{True}}. "
    "If it's incorrect, please submit your answer in your response with \\boxed{{False}}. "
    "Please include how to solve the problem correctly in your answer."
)

retry_format = (
    "Here is a math problem, a proposed solution, and a verification of the solution: "
    "[START PROBLEM] {problem} [END PROBLEM] "
    "[START SOLUTION] {solution} [END SOLUTION] "
    "[START VERIFICATION] {verification} [END VERIFICATION] "
    "Please use this verification to help you solve the problem correctly. "
    "Provide your answer in your response with \\boxed{{answer}}."
)

rlaif_format = (
    "Here is a math problem, and two solutions that are correct. "
    "Please choose whichever answer you prefer. "
    "[START PROBLEM] {problem} [END PROBLEM] "
    "[START SOLUTION 1] {solution1} [END SOLUTION 1] "
    "[START SOLUTION 2] {solution2} [END SOLUTION 2] "
    "Here are some metrics for you to use to grade the two solutions: "
    "- Conciseness: How concise is the solution? Is it too long or too short? "
    "- Clarity: How clear is the solution? Is it easy to understand? "
    "- Correctness: Is the reasoning correct? The answer has been prechecked to be correct, "
    "but there may be errors in the reasoning. "
    "Please use these metrics to help you choose the best solution, in order of priority. "
    "Please provide your answer in your response with \\boxed{{1}}, for the first solution, "
    "or \\boxed{{2}} for the second solution."
)


class RSConfig(BaseEnvConfig):
    """Environment options added on top of ``BaseEnvConfig``."""

    run_evaluation: bool = Field(True, description="If this should run evaluation")
    mask_too_long_completions: bool = Field(
        True, description="If this should mask too long completions"
    )
    percent_to_judge: float = Field(0.3, description="The percentage of items to judge")
    percent_length_penalty: float = Field(
        0.1, description="The percentage of items to have length penalty"
    )


def quick_similarity(a, b):
    """Return the ``difflib.SequenceMatcher`` similarity ratio of ``a`` and ``b``."""
    return SequenceMatcher(None, a, b).ratio()


def _strip_thinking(text: str) -> str:
    """Return the part of ``text`` after the last ``</think>`` tag.

    The whole text is returned when the tag is absent.
    """
    return text.split(THINK_END_TAG)[-1]


def _extract_boxed(text: str) -> str:
    """Return the stripped content of the last ``\\boxed{...}`` in the answer part of ``text``.

    Only the text up to the first ``}`` after the last ``\\boxed{`` is kept, which
    is sufficient for the flat verdicts used here (``1``/``2``/``True``/``False``).
    """
    return _strip_thinking(text).split("\\boxed{")[-1].split("}")[0].strip()


def _new_scored_group() -> ScoredDataGroup:
    """Create a ``ScoredDataGroup`` with every per-sample list initialised empty."""
    group = ScoredDataGroup()
    for key in (
        "tokens",
        "masks",
        "scores",
        "overrides",
        "messages",
        "inference_logprobs",
    ):
        group[key] = list()
    return group


def _apply_length_penalty(scores: List[float], token_lists: List[List[int]]) -> bool:
    """Replace ``scores`` in place with a cosine length penalty.

    The shortest sequence gets ``cos(0) = 1.0`` and the longest ``cos(pi) = -1.0``.
    Nothing is changed when the longest sequence is at most 10% longer than
    the shortest one.

    Returns:
        True if the penalty was applied, False if the lengths were too similar.
    """
    lengths = [len(tokens) for tokens in token_lists]
    shortest = min(lengths)
    max_delta = max(length - shortest for length in lengths)
    if not (max_delta > 0.1 * shortest):
        print(
            "Max message delta is less than 0.1 * shortest message, no length penalty"
        )
        return False
    print(
        "Max message delta is greater than 0.1 * shortest message, adding length penalty"
    )
    for i in range(len(scores)):
        fraction = (lengths[i] - shortest) / max_delta
        scores[i] = math.cos(fraction * math.pi)
    return True


def score_answer(gold, resp) -> Optional[bool]:
    """Check ``resp`` against ``gold`` using ``math_verify``.

    Args:
        gold: Ground-truth answer text (LaTeX, expected to contain ``\\boxed``).
        resp: Model answer text with the reasoning section already removed.

    Returns:
        The result of ``verify`` when both sides parse, or ``None`` when the
        gold answer yields no parse or any parsing/verification step raises.
    """
    parse_errors = (
        Exception,
        TimeoutException,
        KeyError,
        TypeError,
        NotImplementedError,
    )
    try:
        gold_parsed = parse(
            gold,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
    except parse_errors:
        return None
    if len(gold_parsed) == 0:
        return None
    try:
        answer_parsed = parse(
            resp,
            extraction_config=[
                LatexExtractionConfig(
                    normalization_config=NormalizationConfig(
                        nits=False,
                        malformed_operators=False,
                        basic_latex=True,
                        boxed="all",
                        units=True,
                    ),
                    # Ensures that boxed is tried first
                    boxed_match_priority=0,
                    try_extract_without_anchor=False,
                )
            ],
            extraction_mode="first_match",
        )
    except parse_errors:
        # Can't parse, so we skip
        return None
    # True if the content is the same as the ground truth, False otherwise
    try:
        return verify(answer_parsed, gold_parsed)
    except parse_errors:
        return None


class MathEnv(BaseEnv):
    """Math environment producing normal, RLAIF, judge and selfcorrect groups."""

    name = "math"
    env_config_cls = RSConfig

    def __init__(
        self,
        config: RSConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        """Initialise metric buffers and the process pool used for answer scoring."""
        print("Initializing MathEnv")
        print(f"Slurm: {slurm}, Testing: {testing}")
        super().__init__(config, server_configs, slurm, testing)
        self.percent_correct_buffer = list()
        self.eval_metrics = list()
        # score_answer runs in worker processes so it does not block the event loop
        self.mp_executor = ProcessPoolExecutor(64)
        self.percent_overanswer = list()
        self.percent_judge_correct = list()
        self.correct_answer_len = list()
        self.incorrect_answer_len = list()
        self.normal_rollouts = list()
        self.rlaif_rollouts = list()
        self.pass_at_groupsize = list()
        self.judge_rollouts = list()
        self.selfcorrect_rollouts = list()
        self.judge_success_rate = list()
        self.iter = 0

    @classmethod
    def config_init(cls) -> Tuple[RSConfig, List[APIServerConfig]]:
        """Return the default environment config and inference server configs."""
        env_config = RSConfig(
            tokenizer_name="NousResearch/Hermes-4-14B",
            group_size=16,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=1024,
            max_num_workers_per_node=12,
            steps_per_eval=25,
            max_token_length=16384,  # 22000 // (2 ** i),
            wandb_name="math",
            eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
            eval_limit_ratio=0.1,
            inference_weight=4,
            min_batch_allocation=0.1,
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/Hermes-4-14B",
                base_url="http://localhost:9004/v1",
                api_key="x",
                num_requests_for_eval=256,  # since evaling only on one...
                server_type="sglang",
            ),
        ]
        return env_config, server_configs

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Add this environment's metrics and rollout tables, then delegate to the base class.

        Averaged buffers are cleared once reported so each log call covers only
        the data collected since the previous one.  ``judge_success_rate`` and
        the rollout tables are instead bounded where they are filled.
        """
        if wandb_metrics is None:
            wandb_metrics = dict()
        if len(self.pass_at_groupsize) > 0:
            wandb_metrics["train/pass_at_groupsize"] = sum(
                self.pass_at_groupsize
            ) / len(self.pass_at_groupsize)
            self.pass_at_groupsize = list()
        if len(self.percent_correct_buffer) > 0:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)
            # Filled together with percent_correct_buffer in score_normal;
            # still guarded so a lone empty buffer cannot raise.
            if len(self.percent_overanswer) > 0:
                wandb_metrics["train/percent_overanswer"] = sum(
                    self.percent_overanswer
                ) / len(self.percent_overanswer)
            self.percent_overanswer = list()
            self.percent_correct_buffer = list()
        if len(self.correct_answer_len) > 0:
            wandb_metrics["train/avg_correct_answer_len"] = sum(
                self.correct_answer_len
            ) / len(self.correct_answer_len)
            self.correct_answer_len = list()
        if len(self.incorrect_answer_len) > 0:
            wandb_metrics["train/avg_incorrect_answer_len"] = sum(
                self.incorrect_answer_len
            ) / len(self.incorrect_answer_len)
            self.incorrect_answer_len = list()
        if len(self.percent_judge_correct) > 0:
            wandb_metrics["judge_train/percent_judge_correct"] = sum(
                self.percent_judge_correct
            ) / len(self.percent_judge_correct)
            self.percent_judge_correct = list()
        if len(self.judge_success_rate) > 0:
            wandb_metrics["judge_train/judge_success_rate"] = sum(
                self.judge_success_rate
            ) / len(self.judge_success_rate)
        # create tables
        if len(self.judge_rollouts) > 0:
            table = wandb.Table(
                columns=["problem", "solution", "answer", "correct", "judge"]
            )
            for group in self.judge_rollouts:
                table.add_data(group[0], group[1], group[2], group[3], group[4])
            wandb_metrics["judge_train/judge_rollouts"] = table
        if len(self.selfcorrect_rollouts) > 0:
            table = wandb.Table(columns=["problem", "solution1", "solution2", "score"])
            for group in self.selfcorrect_rollouts:
                table.add_data(group[0], group[1], group[2], group[3])
            wandb_metrics["judge_train/selfcorrect_rollouts"] = table
        if len(self.normal_rollouts) > 0:
            table = wandb.Table(columns=["problem", "solution", "answer", "score"])
            for group in self.normal_rollouts:
                table.add_data(group[0], group[1], group[2], group[3])
            wandb_metrics["train/normal_rollouts"] = table
        if len(self.rlaif_rollouts) > 0:
            table = wandb.Table(
                columns=["problem", "solution1", "solution2", "score", "rollout"]
            )
            for group in self.rlaif_rollouts:
                table.add_data(group[0], group[1], group[2], group[3], group[4])
            wandb_metrics["train/rlaif_rollouts"] = table
        wandb_metrics["train/iter"] = self.iter
        for metric_name, metric_value in self.eval_metrics:
            wandb_metrics[metric_name] = metric_value
        self.eval_metrics = list()
        await super().wandb_log(wandb_metrics)

    async def setup(self):
        """Load the training set and build ``self.test`` as ``(prompt, answer, subset)`` tuples."""
        self.train = load_dataset("zwhe99/DeepMath-103K", split="train").shuffle(
            seed=42
        )
        aime_test_data = load_dataset("HuggingFaceH4/aime_2024", split="train")
        math500_test_data = load_dataset("HuggingFaceH4/math-500", split="test")
        amc_test_data = load_dataset("math-ai/amc23", split="test")
        self.test = list()
        # (subset name, dataset, column that holds the problem text)
        eval_sources = (
            ("aime24", aime_test_data, "problem"),
            ("math500", math500_test_data, "problem"),
            ("amc23", amc_test_data, "question"),
        )
        for name, t_dataset, problem_key in eval_sources:
            for row in t_dataset:
                self.test.append(
                    (
                        problem_format.format(problem=row[problem_key]),
                        row["answer"],
                        name,
                    )
                )

    async def rollout_and_score_eval(self, question, answer, subset):
        """Generate one greedy completion for ``question`` and score it.

        Returns:
            ``(score, subset)`` where ``score`` is 1 for a verified answer and 0
            for a wrong or unparseable one.
        """
        completion = await self.server.chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": question},
            ],
            n=1,
            max_tokens=32765,
            temperature=0.0,
            split="eval",
        )
        loop = asyncio.get_running_loop()
        gold = "\\boxed{" + answer + "}" if "\\boxed" not in answer else answer
        resp = _strip_thinking(completion.choices[0].message.content)
        reward = await loop.run_in_executor(self.mp_executor, score_answer, gold, resp)
        if reward is None:
            return 0, subset
        return (1 if reward else 0), subset

    async def evaluate(self, *args, **kwargs):
        """Score every test item and queue per-subset and overall accuracy for logging."""
        if not self.config.run_evaluation:
            return
        eval_tasks = [
            self.rollout_and_score_eval(question, answer, subset)
            for question, answer, subset in self.test
        ]
        parsing_data = await tqdm_asyncio.gather(*eval_tasks)
        task_lists = dict()
        for score, subset in parsing_data:
            task_lists.setdefault(subset, list()).append(score)
        # Now get the average
        all_scores = []
        for subset, subset_scores in task_lists.items():
            self.eval_metrics.append(
                (f"eval/{subset}_percent_correct", sum(subset_scores) / len(subset_scores))
            )
            all_scores.extend(subset_scores)
        # overall score (skipped when there was nothing to evaluate)
        if all_scores:
            self.eval_metrics.append(
                ("eval/overall_percent_correct", sum(all_scores) / len(all_scores))
            )

    async def collect_trajectories_normal(
        self, item
    ) -> Tuple[Optional[ScoredDataGroup], List]:
        """Sample and score ``group_size`` solutions for one problem.

        Args:
            item: ``(problem, boxed_gold_answer, "normal")``.

        Returns:
            ``(group, backlog)``.  ``group`` is ``None`` when the prompt leaves
            fewer than 1024 tokens for generation, scoring fails, or the group
            has no usable training signal.  ``backlog`` may contain one
            ``"rlaif"`` item (all solutions correct and sufficiently different)
            or two ``"judge"`` items (mixed results).
        """
        user_prompt = problem_format.format(problem=item[0])
        chat = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        thinking_len = self.config.max_token_length - len(
            self.tokenizer.apply_chat_template(chat, add_generation_prompt=True)
        )
        print(f"thinking_len: {thinking_len}", flush=True)
        if thinking_len < 1024:
            print("thinking_len is less than 1024, skipping", flush=True)
            return None, []

        # The managed server wrapper tracks one sequence node per completion;
        # each node carries the tokens, training mask and logprobs used below.
        # See atroposlib/envs/server_handling/MANAGED_SERVER.md.
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            chat_completions = await managed.chat_completion(
                messages=chat,
                n=self.config.group_size,  # Generate multiple completions for GRPO
                max_tokens=thinking_len,
                temperature=1.0,
                top_p=0.95,
            )
            state = managed.get_state()
            nodes = state["nodes"]  # one node per completion
        print("Finished generation", flush=True)

        to_backlog = list()
        to_score = list()
        for chat_completion, node in zip(chat_completions.choices, nodes):
            messages = (
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
                {"role": "assistant", "content": chat_completion.message.content},
            )
            to_score.append(
                (
                    messages,
                    item[1],  # Ground truth answer
                    chat_completion.finish_reason,
                    node.tokens,
                    node.masked_tokens,
                    node.logprobs,
                )
            )
        print("scoring normal", flush=True)
        to_postprocess = await self.score_normal(to_score)
        print("scoring normal done", flush=True)
        if to_postprocess is None:
            return None, to_backlog

        group_scores = to_postprocess["scores"]
        if all(score == group_scores[0] for score in group_scores):
            if group_scores[0] != 1.0:
                # Everything wrong: nothing to learn from this group
                return None, to_backlog
            # Everything correct: we can do RLAIF on the two most dissimilar answers
            messages = to_postprocess["messages"]
            score_matrix = []
            most_dissimilar = (0, 1)
            most_dissimilar_score = 1.0
            for i in range(len(messages) - 1):
                # Only the upper triangle is computed; pad the rest with 1.0
                row = [1.0] * (i + 1)
                score_matrix.append(row)
                first = _strip_thinking(messages[i][-1]["content"])
                for j in range(i + 1, len(messages)):
                    second = _strip_thinking(messages[j][-1]["content"])
                    similarity = (
                        1.0 if first == second else quick_similarity(first, second)
                    )
                    row.append(similarity)
                    if similarity < most_dissimilar_score:
                        most_dissimilar = (i, j)
                        most_dissimilar_score = similarity
            if most_dissimilar_score < 0.975:
                idx_1, idx_2 = most_dissimilar
                # Messages are frozen so the backlog item stays hashable-friendly;
                # tokens/masks/logprobs travel along so the RLAIF step can train
                # on the original completions.
                to_backlog.append(
                    (
                        item[0],
                        item[1],
                        "rlaif",
                        tuple(frozenset(msg.items()) for msg in messages[idx_1]),
                        tuple(frozenset(msg.items()) for msg in messages[idx_2]),
                        most_dissimilar_score,
                        # tokens/masks/logprobs for solution 1
                        to_postprocess["tokens"][idx_1],
                        to_postprocess["masks"][idx_1],
                        to_postprocess["inference_logprobs"][idx_1],
                        # tokens/masks/logprobs for solution 2
                        to_postprocess["tokens"][idx_2],
                        to_postprocess["masks"][idx_2],
                        to_postprocess["inference_logprobs"][idx_2],
                    )
                )
                print(
                    "\n".join(
                        "[" + ", ".join(str(value) for value in score_matrix_row) + "]"
                        for score_matrix_row in score_matrix
                    )
                )
                print(
                    f"Sending to RLAIF, most dissimilar score: {most_dissimilar_score}"
                )
            else:
                print(
                    f"Unable to RLAIF, most dissimilar score: {most_dissimilar_score}"
                )
            # Occasionally turn an all-correct group into a length-penalty group
            if not (random.random() < self.config.percent_length_penalty):
                return None, to_backlog
            if not _apply_length_penalty(group_scores, to_postprocess["tokens"]):
                return None, to_backlog
        else:
            self.normal_rollouts.append(
                (
                    item[0],
                    to_postprocess["messages"][0],
                    item[1],
                    group_scores[0],
                )
            )
            # Bound the logged table like the other rollout buffers
            if len(self.normal_rollouts) >= self.config.num_rollouts_to_keep:
                self.normal_rollouts.pop(0)
            print("Sending to judge potentially")
            if random.random() < self.config.percent_to_judge:
                # find first pos and neg scored answers.
                # Mixed +-1 scores guarantee at least one positive exists.
                pos_idx = [i for i, score in enumerate(group_scores) if score == 1.0][0]
                unmasked_negatives = [
                    i
                    for i, score in enumerate(group_scores)
                    if (score == -1.0)
                    and (
                        not to_postprocess["overrides"][i].get(
                            "set_advantage_to_zero", False
                        )
                    )
                ]
                if len(unmasked_negatives) == 0:
                    # NOTE(review): this also drops the group from training,
                    # not just from judging - confirm that is intended.
                    return None, to_backlog
                neg_idx = unmasked_negatives[0]
                for judged_idx, expected_verdict in ((pos_idx, "True"), (neg_idx, "False")):
                    to_backlog.append(
                        (
                            item[0],
                            item[1],
                            "judge",
                            _strip_thinking(
                                to_postprocess["messages"][judged_idx][-1]["content"]
                            ),
                            expected_verdict,
                        )
                    )
                print("sending to judge")
        print(f"Collected {len(to_postprocess['scores'])} trajectories")
        if not self.config.mask_too_long_completions:
            # score_normal always marks truncated samples; undo when masking is off
            to_postprocess["overrides"] = [
                {} for _ in range(len(to_postprocess["scores"]))
            ]
        return to_postprocess, to_backlog

    async def collect_trajectories(
        self, item
    ) -> Tuple[Optional[ScoredDataGroup], List]:
        """Dispatch ``item`` on its rollout type stored in ``item[2]``.

        ``"selfcorrect"`` items are handled inline: they already carry the
        retry conversations, scores, finish reasons, tokens, masks and logprobs
        produced by ``collect_trajectories_judge`` (``item[3]`` .. ``item[8]``).

        Raises:
            ValueError: if ``item[2]`` is not a known rollout type.
        """
        rollout_type = item[2]
        if rollout_type == "normal":
            return await self.collect_trajectories_normal(item)
        if rollout_type == "rlaif":
            return await self.collect_trajectories_rlaif(item)
        if rollout_type == "judge":
            return await self.collect_trajectories_judge(item)
        if rollout_type != "selfcorrect":
            raise ValueError(f"Unknown rollout type: {item[2]}")

        # selfcorrect is a special case where we are using the Judge rollout
        print("selfcorrect processing...")
        print("selfcorrect item:", item, flush=True)
        group = item[3]
        scores = item[4]
        finish_reasons = item[5]
        tokens_list = item[6]
        masks_list = item[7]
        logprobs_list = item[8]
        to_postprocess = _new_scored_group()
        for i in range(len(group)):
            # convert from frozen set to dict
            conv = [dict(x) for x in group[i]]
            if i == 0:
                # NOTE(review): these four values do not match the table's
                # column names (problem/solution1/solution2/score) - confirm.
                self.selfcorrect_rollouts.append(
                    (
                        item[0],
                        item[1],
                        conv[0]["content"],
                        conv[1]["content"],
                    )
                )
                if len(self.selfcorrect_rollouts) >= self.config.num_rollouts_to_keep:
                    self.selfcorrect_rollouts.pop(0)
            assert len(logprobs_list[i]) == len(
                masks_list[i]
            ), f"{len(logprobs_list[i])}, {len(masks_list[i])} mismatch"
            override = dict()
            if (finish_reasons[i] == "length") and (
                self.config.mask_too_long_completions
            ):
                override["set_advantage_to_zero"] = True
            to_postprocess["tokens"].append(tokens_list[i])
            to_postprocess["masks"].append(masks_list[i])
            to_postprocess["inference_logprobs"].append(logprobs_list[i])
            to_postprocess["scores"].append(scores[i])
            to_postprocess["overrides"].append(override)
            to_postprocess["messages"].append(conv)
        print("selfcorrect done, sending batch off")
        return to_postprocess, []

    async def score_normal(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
        """Score a group of normal rollouts against the shared gold answer.

        Args:
            rollout_group_data: tuples of ``(messages, gold, finish_reason,
                tokens, masks, logprobs)``.  The list is shuffled in place.

        Returns:
            A ``ScoredDataGroup`` with scores of ``1.0`` / ``-1.0``, or ``None``
            when an answer cannot be scored or fewer than ``group_size`` usable
            samples remain.  Truncated samples (``finish_reason == "length"``)
            score ``-1.0`` and are always flagged ``set_advantage_to_zero`` so
            the judge loop can skip them; ``collect_trajectories_normal`` clears
            the flag again when masking is disabled in the config.
        """
        scores = _new_scored_group()
        gold = rollout_group_data[0][1]
        loop = asyncio.get_running_loop()
        random.shuffle(rollout_group_data)
        for messages, _gold, finish_reason, tokens, masks, logprobs in rollout_group_data:
            resp = _strip_thinking(messages[-1]["content"])
            override = dict()
            if finish_reason == "length":
                reward = False
                override["set_advantage_to_zero"] = True
            else:
                reward = await loop.run_in_executor(
                    self.mp_executor, score_answer, gold, resp
                )
                if reward is None:
                    return None
            assert len(logprobs) == len(
                masks
            ), f"{len(logprobs)}, {len(masks)} mismatch"
            # remove obviously bad examples (fewer than 10 trained tokens)
            if sum(1 for mask in masks if mask != -100) < 10:
                continue
            # All per-sample lists are appended together so their indices stay
            # aligned even when a sample above was skipped.
            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["scores"].append(1.0 if reward else -1.0)
            scores["messages"].append(messages)
            scores["inference_logprobs"].append(logprobs)
            scores["overrides"].append(override)
            if len(scores["tokens"]) >= self.config.group_size:
                break
        if any(score == 1.0 for score in scores["scores"]):
            self.pass_at_groupsize.append(1.0)
        else:
            self.pass_at_groupsize.append(0.0)
        if len(scores["tokens"]) < self.config.group_size:
            # We don't have enough data to score
            return None
        for score in scores["scores"]:
            self.percent_correct_buffer.append(max(score, 0))
        self.percent_overanswer.extend(
            [rollout[2] == "length" for rollout in rollout_group_data]
        )
        # Fill in the correct/incorrect lengths afterwards so we only look at
        # actual training data
        for i, score in enumerate(scores["scores"]):
            if score == 1.0:
                self.correct_answer_len.append(len(scores["tokens"][i]))
            elif (score == -1.0) and (
                not scores["overrides"][i].get("set_advantage_to_zero", False)
            ):
                self.incorrect_answer_len.append(len(scores["tokens"][i]))
        return scores

    async def get_next_item(self):
        """Return the next training item as ``(question, boxed_answer, "normal")``.

        Rows whose ``final_answer`` cannot be concatenated or searched as text
        (``TypeError``) are reported and skipped.
        """
        while True:
            next_item = self.train[self.iter % len(self.train)]
            self.iter += 1
            prompt = next_item["question"]
            try:
                answer = (
                    ("\\boxed{" + next_item["final_answer"] + "}")
                    if "\\boxed" not in next_item["final_answer"]
                    else next_item["final_answer"]
                )
                break
            except TypeError:
                print(
                    f"Error in getting next item, trying again, "
                    f"data: {next_item['question']} -> {next_item['final_answer']}"
                )
        return (prompt, answer, "normal")

    async def collect_trajectories_rlaif(
        self, frozen_item
    ) -> Tuple[Optional[ScoredDataGroup], List]:
        """Let the model rank two correct solutions and score them +1 / -1.

        Args:
            frozen_item: ``"rlaif"`` backlog item built by
                ``collect_trajectories_normal``: problem, answer, type, the two
                frozen conversations, their similarity, then tokens / masks /
                logprobs for solution 1 and for solution 2.

        Returns:
            A two-sample group (``group_overrides`` sets ``group_size`` to 2),
            or ``(None, [])`` when the votes are tied.
        """
        to_backlog = list()
        print("Attempting RLAIF")
        item = list(frozen_item)
        print("Converting to dicts")
        item[3] = [dict(x) for x in item[3]]
        item[4] = [dict(x) for x in item[4]]
        solution_1 = _strip_thinking(item[3][-1]["content"])
        solution_2 = _strip_thinking(item[4][-1]["content"])
        print("Formatting user prompts")
        # Ask in both orders so position bias cancels out
        user_prompt_fwd = rlaif_format.format(
            problem=item[0], solution1=solution_1, solution2=solution_2
        )
        user_prompt_bwd = rlaif_format.format(
            problem=item[0], solution1=solution_2, solution2=solution_1
        )
        print("Sending to server")
        chat_fwd = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt_fwd},
        ]
        max_token_length_fwd = self.config.max_token_length - len(
            self.tokenizer.apply_chat_template(chat_fwd, add_generation_prompt=True)
        )
        print("Sending to server")
        # Should be the same token length as the fwd but tokenizers are cursed
        chat_bwd = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt_bwd},
        ]
        max_token_length_bwd = self.config.max_token_length - len(
            self.tokenizer.apply_chat_template(chat_bwd, add_generation_prompt=True)
        )

        async def get_preference_votes(chat, max_tokens):
            # Each call uses its own managed_server context.  The sequences it
            # tracks are not used for training here: the group is built from
            # the tokens/masks/logprobs carried in the backlog item.
            async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
                return await managed.chat_completion(
                    messages=chat,
                    n=3,
                    max_tokens=max_tokens,
                    temperature=1.0,
                    top_p=0.95,
                )

        print("Gathering completions")
        chat_completions_fwd, chat_completions_bwd = await asyncio.gather(
            get_preference_votes(chat_fwd, max_token_length_fwd),
            get_preference_votes(chat_bwd, max_token_length_bwd),
        )
        print("Grabbed RLAIF completions")
        # Tally the votes; in the backward prompt the labels are swapped
        score_1 = 0
        score_2 = 0
        for chat_completion in chat_completions_fwd.choices:
            vote = _extract_boxed(chat_completion.message.content)
            if vote == "1":
                score_1 += 1
            elif vote == "2":
                score_2 += 1
        for chat_completion in chat_completions_bwd.choices:
            vote = _extract_boxed(chat_completion.message.content)
            if vote == "1":
                score_2 += 1
            elif vote == "2":
                score_1 += 1
        print(f"Score 1: {score_1}, Score 2: {score_2}")
        if score_1 == score_2:
            return None, []
        self.rlaif_rollouts.append(
            (
                item[0],
                solution_1,
                solution_2,
                score_1 - score_2,
                chat_completions_fwd.choices[0].message.content,
            )
        )
        if len(self.rlaif_rollouts) >= self.config.num_rollouts_to_keep:
            self.rlaif_rollouts.pop(0)
        print("RLAIF rollout added")

        # Pre-computed tokens/masks/logprobs from the backlog item
        tokens_1, masks_1, logprobs_1 = item[6], item[7], item[8]
        tokens_2, masks_2, logprobs_2 = item[9], item[10], item[11]
        assert len(logprobs_1) == len(
            masks_1
        ), f"{len(logprobs_1)}, {len(masks_1)} mismatch"
        assert len(logprobs_2) == len(
            masks_2
        ), f"{len(logprobs_2)}, {len(masks_2)} mismatch"

        to_postprocess = _new_scored_group()
        samples = (
            (tokens_1, masks_1, logprobs_1, item[3], score_1 > score_2),
            (tokens_2, masks_2, logprobs_2, item[4], score_2 > score_1),
        )
        for tokens, masks, logprobs, conversation, preferred in samples:
            to_postprocess["tokens"].append(tokens)
            to_postprocess["masks"].append(masks)
            to_postprocess["scores"].append(1.0 if preferred else -1.0)
            to_postprocess["messages"].append(conversation)  # already dicts
            to_postprocess["inference_logprobs"].append(logprobs)
            to_postprocess["overrides"].append(dict())
        to_postprocess["group_overrides"] = {
            "group_size": 2,
        }
        print("RLAIF rollout added")
        return to_postprocess, to_backlog

    async def collect_trajectories_judge(
        self, item
    ) -> Tuple[Optional[ScoredDataGroup], List]:
        """Train the model to verify a proposed solution.

        Args:
            item: ``(problem, gold_answer, "judge", solution, expected)`` where
                ``expected`` is ``"True"`` or ``"False"``.

        Scoring per verdict:
            * wrong or truncated verdict: ``-1.0``
            * correct verdict on a correct solution: ``1.0``
            * correct verdict on an incorrect solution: fraction of
              ``group_size`` retries (prompted with the verdict as
              verification) that reach the gold answer.  Those retries may be
              queued as a ``"selfcorrect"`` backlog item.

        Returns:
            ``(group, backlog)``, or ``(None, [])`` when no verdict is correct
            or the group carries no training signal.
        """
        user_prompt = judge_format.format(
            problem=item[0],
            solution=item[3],
        )
        to_backlog = list()
        chat = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        max_token_length = self.config.max_token_length - len(
            self.tokenizer.apply_chat_template(chat, add_generation_prompt=True)
        )
        # Tokens/masks/logprobs from these nodes are used directly for training
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            chat_completions = await managed.chat_completion(
                messages=chat,
                n=self.config.group_size,
                max_tokens=max_token_length,
                temperature=1.0,
                top_p=0.95,
            )
            state = managed.get_state()
            nodes = state["nodes"]

        is_correct = [
            (_extract_boxed(chat_completion.message.content) == item[4])
            and (chat_completion.finish_reason != "length")
            for chat_completion in chat_completions.choices
        ]
        self.percent_judge_correct.append(
            sum(1.0 if val else 0.0 for val in is_correct) / len(is_correct)
        )
        if not any(is_correct):
            # Can't judge :(
            return None, []

        scores = _new_scored_group()
        for_table = []
        for i, (chat_completion, node) in enumerate(
            zip(chat_completions.choices, nodes)
        ):
            tokens = node.tokens
            masks = node.masked_tokens
            logprobs = node.logprobs
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
                {"role": "assistant", "content": chat_completion.message.content},
            ]
            assert len(logprobs) == len(
                masks
            ), f"{len(logprobs)}, {len(masks)} mismatch"

            if not is_correct[i]:
                override = dict()
                if (chat_completion.finish_reason == "length") and (
                    self.config.mask_too_long_completions
                ):
                    override["set_advantage_to_zero"] = True
                scores["tokens"].append(tokens)
                scores["masks"].append(masks)
                scores["scores"].append(-1.0)
                scores["messages"].append(messages)
                scores["inference_logprobs"].append(logprobs)
                scores["overrides"].append(override)
                continue

            if len(for_table) == 0:
                # populate the table with the first correct verdict
                for_table = [
                    item[0],
                    item[1],
                    item[3],
                    item[4],
                    chat_completion.message.content,
                ]

            if item[4] != "False":
                verdict_score = 1.0
            else:
                # Score based on percentage correct from retry
                print("Scoring retry")
                retry_prompt = retry_format.format(
                    problem=item[0],
                    solution=item[3],
                    verification=_strip_thinking(chat_completion.message.content),
                )
                print("Sending to server")
                retry_messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": retry_prompt},
                ]
                retry_max_tokens = self.config.max_token_length - len(
                    self.tokenizer.apply_chat_template(
                        retry_messages, add_generation_prompt=True
                    )
                )
                # Separate managed_server context: its nodes are stored in the
                # backlog for the "selfcorrect" branch of collect_trajectories.
                async with self.server.managed_server(
                    tokenizer=self.tokenizer
                ) as retry_managed:
                    retry_chat_completions = await retry_managed.chat_completion(
                        messages=retry_messages,
                        n=self.config.group_size,
                        max_tokens=retry_max_tokens,
                        temperature=1.0,
                        top_p=0.95,
                    )
                    retry_state = retry_managed.get_state()
                    retry_nodes = retry_state["nodes"]
                print("Gathering completions")
                scoring_data = []
                backlog_scores = []
                backlog_reasons = []
                backlog_messages = []
                backlog_tokens = []
                backlog_masks = []
                backlog_logprobs = []
                loop = asyncio.get_running_loop()
                for j, (retry_chat_completion, retry_node) in enumerate(
                    zip(retry_chat_completions.choices, retry_nodes)
                ):
                    print(f"Scoring generation {j} for retry...")
                    backlog_messages.append(
                        tuple(
                            [frozenset(msg.items()) for msg in retry_messages]
                            + [
                                frozenset(
                                    {
                                        "role": "assistant",
                                        "content": retry_chat_completion.message.content,
                                    }.items()
                                )
                            ]
                        )
                    )
                    backlog_reasons.append(retry_chat_completion.finish_reason)
                    backlog_tokens.append(retry_node.tokens)
                    backlog_masks.append(retry_node.masked_tokens)
                    backlog_logprobs.append(retry_node.logprobs)
                    if retry_chat_completion.finish_reason == "length":
                        scoring_data.append(0)
                        backlog_scores.append(0)
                    else:
                        reward = await loop.run_in_executor(
                            self.mp_executor,
                            score_answer,
                            item[1],
                            _strip_thinking(retry_chat_completion.message.content),
                        )
                        scoring_data.append(1.0 if reward else 0.0)
                        backlog_scores.append(1.0 if reward else -1.0)
                # NOTE(review): the second clause compares the whole
                # ``backlog_reasons`` list with 1.0, so it is False for any
                # non-empty group and only mixed-score retries are queued.
                # It probably meant "all retry scores are 1.0", but the
                # selfcorrect handler applies no length penalty, so enabling
                # it would queue groups without training signal.  Left
                # unchanged on purpose - confirm the intent before fixing.
                if (
                    not all(
                        backlog_score == backlog_scores[0]
                        for backlog_score in backlog_scores
                    )
                ) or (
                    all(backlog_reasons == 1.0 for backlog_reason in backlog_reasons)
                    and (random.random() < self.config.percent_length_penalty)
                ):
                    to_backlog.append(
                        (
                            item[0],
                            item[1],
                            "selfcorrect",
                            tuple(backlog_messages),
                            tuple(backlog_scores),
                            tuple(backlog_reasons),
                            tuple(backlog_tokens),
                            tuple(backlog_masks),
                            tuple(backlog_logprobs),
                        )
                    )
                    print(f"Sending to selfcorrect, {len(to_backlog)} in backlog")
                verdict_score = sum(scoring_data) / len(scoring_data)
                self.judge_success_rate.append(verdict_score)
                if len(self.judge_success_rate) >= self.config.num_rollouts_to_keep:
                    self.judge_success_rate.pop(0)

            scores["scores"].append(verdict_score)
            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["messages"].append(messages)
            scores["inference_logprobs"].append(logprobs)
            scores["overrides"].append(dict())

        if all(score == 1.0 for score in scores["scores"]) and (
            random.random() < self.config.percent_length_penalty
        ):
            # Do len penalty
            if not _apply_length_penalty(scores["scores"], scores["tokens"]):
                return None, []
        elif all(score == scores["scores"][0] for score in scores["scores"]):
            return None, []
        if len(for_table) > 0:
            self.judge_rollouts.append(for_table)
            if len(self.judge_rollouts) >= self.config.num_rollouts_to_keep:
                self.judge_rollouts.pop(0)
        print(
            f"Collected {len(scores['scores'])} trajectories with {len(to_backlog)} in backlog"
        )
        return scores, to_backlog


if __name__ == "__main__":
    MathEnv.cli()