# atropos/environments/math_server.py
import asyncio
import math
import random
from concurrent.futures import ProcessPoolExecutor
from difflib import SequenceMatcher
from typing import Dict, List, Optional, Tuple
import wandb
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from math_verify.errors import TimeoutException
from pydantic import Field
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
ScoredDataGroup,
)
system_prompt = (
"You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
"problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
"solution prior to answering. You should enclose your thoughts and internal monologue inside <think> "
"</think> tags, and then provide your solution or response to the problem."
)
problem_format = "{problem}"
judge_format = """Here is a math problem and a proposed solution:
[START PROBLEM]
{problem}
[END PROBLEM]
[START SOLUTION]
{solution}
[END SOLUTION]
Please verify if it is correct or not.
If it's correct, submit your answer in your response with \\boxed{{True}}.
If it's incorrect, please submit your answer in your response with \\boxed{{False}}.
Please include how to solve the problem correctly in your answer."""
retry_format = """Here is a math problem, a proposed solution, and a verification of the solution:
[START PROBLEM]
{problem}
[END PROBLEM]
[START SOLUTION]
{solution}
[END SOLUTION]
[START VERIFICATION]
{verification}
[END VERIFICATION]
Please use this verification to help you solve the problem correctly.
Provide your answer in your response with \\boxed{{answer}}.""" # noqa: E501
rlaif_format = """Here is a math problem, and two solutions that are correct. Please choose whichever answer you prefer.
[START PROBLEM]
{problem}
[END PROBLEM]
[START SOLUTION 1]
{solution1}
[END SOLUTION 1]
[START SOLUTION 2]
{solution2}
[END SOLUTION 2]
Here are some metrics for you to use to grade the two solutions:
- Conciseness: How concise is the solution? Is it too long or too short?
- Clarity: How clear is the solution? Is it easy to understand?
- Correctness: Is the reasoning correct? The answer has been prechecked to be correct, but there may be errors in the reasoning.
Please use these metrics to help you choose the best solution, in order of priority.
Please provide your answer in your response with \\boxed{{1}}, for the first solution, or \\boxed{{2}} for the second solution.""" # noqa: E501
class RSConfig(BaseEnvConfig):
run_evaluation: bool = Field(True, description="If this should run evaluation")
mask_too_long_completions: bool = Field(
True, description="If this should mask too long completions"
)
percent_to_judge: float = Field(0.3, description="The percentage of items to judge")
percent_length_penalty: float = Field(
0.1, description="The percentage of items to have length penalty"
)
def quick_similarity(a, b):
return SequenceMatcher(None, a, b).ratio()
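# Example: quick_similarity("abcd", "abce") == 0.75, i.e. SequenceMatcher's
# ratio of 2 * matching_chars / total_chars; 1.0 means identical strings.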
def score_answer(gold, resp) -> Optional[bool]:
try:
gold_parsed = parse(
gold,
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
    except (Exception, TimeoutException):
return None
if len(gold_parsed) != 0:
try:
answer_parsed = parse(
resp,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
boxed="all",
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
        except (Exception, TimeoutException):
# Can't parse, so we skip
return None
        # True if the parsed answer matches the ground truth, False otherwise
try:
return verify(answer_parsed, gold_parsed)
        except (Exception, TimeoutException):
return None
return None
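# Usage sketch (illustrative strings): compare a model response against a gold
# answer, both as LaTeX. Returns True/False on a successful comparison and
# None whenever extraction or verification raises.
#   score_answer("\\boxed{42}", "... so the answer is \\boxed{42}")  # -> True
#   score_answer("\\boxed{42}", "... so the answer is \\boxed{41}")  # -> False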
class MathEnv(BaseEnv):
name = "math"
env_config_cls = RSConfig
def __init__(
self,
config: RSConfig,
server_configs: List[APIServerConfig],
slurm=True,
testing=False,
):
print("Initializing MathEnv")
print(f"Slurm: {slurm}, Testing: {testing}")
super().__init__(config, server_configs, slurm, testing)
self.percent_correct_buffer = list()
self.eval_metrics = list()
self.mp_executor = ProcessPoolExecutor(64)
self.percent_overanswer = list()
self.percent_judge_correct = list()
self.correct_answer_len = list()
self.incorrect_answer_len = list()
self.normal_rollouts = list()
self.rlaif_rollouts = list()
self.pass_at_groupsize = list()
self.judge_rollouts = list()
self.selfcorrect_rollouts = list()
self.judge_success_rate = list()
self.iter = 0
@classmethod
    def config_init(cls) -> Tuple[RSConfig, List[APIServerConfig]]:
env_config = RSConfig(
tokenizer_name="NousResearch/Hermes-4-14B",
group_size=16,
use_wandb=True,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1024,
max_num_workers_per_node=12,
steps_per_eval=25,
            max_token_length=16384,
wandb_name="math",
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
eval_limit_ratio=0.1,
inference_weight=4,
min_batch_allocation=0.1,
)
server_configs = [
APIServerConfig(
model_name="NousResearch/Hermes-4-14B",
base_url="http://localhost:9004/v1",
api_key="x",
num_requests_for_eval=256, # since evaling only on one...
server_type="sglang",
),
]
return env_config, server_configs
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
if wandb_metrics is None:
wandb_metrics = dict()
if len(self.pass_at_groupsize) > 0:
wandb_metrics["train/pass_at_groupsize"] = sum(
self.pass_at_groupsize
) / len(self.pass_at_groupsize)
            self.pass_at_groupsize = list()
if len(self.percent_correct_buffer) > 0:
wandb_metrics["train/percent_correct"] = sum(
self.percent_correct_buffer
) / len(self.percent_correct_buffer)
wandb_metrics["train/percent_overanswer"] = sum(
self.percent_overanswer
) / len(self.percent_overanswer)
            self.percent_overanswer = list()
            self.percent_correct_buffer = list()
if len(self.correct_answer_len) > 0:
wandb_metrics["train/avg_correct_answer_len"] = sum(
self.correct_answer_len
) / len(self.correct_answer_len)
self.correct_answer_len = list()
if len(self.incorrect_answer_len) > 0:
wandb_metrics["train/avg_incorrect_answer_len"] = sum(
self.incorrect_answer_len
) / len(self.incorrect_answer_len)
self.incorrect_answer_len = list()
if len(self.percent_judge_correct) > 0:
wandb_metrics["judge_train/percent_judge_correct"] = sum(
self.percent_judge_correct
) / len(self.percent_judge_correct)
self.percent_judge_correct = list()
if len(self.judge_success_rate) > 0:
wandb_metrics["judge_train/judge_success_rate"] = sum(
self.judge_success_rate
) / len(self.judge_success_rate)
# create tables
if len(self.judge_rollouts) > 0:
table = wandb.Table(
columns=["problem", "solution", "answer", "correct", "judge"]
)
for group in self.judge_rollouts:
table.add_data(group[0], group[1], group[2], group[3], group[4])
wandb_metrics["judge_train/judge_rollouts"] = table
if len(self.selfcorrect_rollouts) > 0:
table = wandb.Table(columns=["problem", "solution1", "solution2", "score"])
for group in self.selfcorrect_rollouts:
table.add_data(group[0], group[1], group[2], group[3])
wandb_metrics["judge_train/selfcorrect_rollouts"] = table
if len(self.normal_rollouts) > 0:
table = wandb.Table(columns=["problem", "solution", "answer", "score"])
for group in self.normal_rollouts:
table.add_data(group[0], group[1], group[2], group[3])
wandb_metrics["train/normal_rollouts"] = table
if len(self.rlaif_rollouts) > 0:
table = wandb.Table(
columns=["problem", "solution1", "solution2", "score", "rollout"]
)
for group in self.rlaif_rollouts:
table.add_data(group[0], group[1], group[2], group[3], group[4])
wandb_metrics["train/rlaif_rollouts"] = table
wandb_metrics["train/iter"] = self.iter
for item in self.eval_metrics:
wandb_metrics[item[0]] = item[1]
self.eval_metrics = list()
await super().wandb_log(wandb_metrics)
async def setup(self):
self.train = load_dataset("zwhe99/DeepMath-103K", split="train").shuffle(
seed=42
)
aime_test_data = load_dataset("HuggingFaceH4/aime_2024", split="train")
math500_test_data = load_dataset("HuggingFaceH4/math-500", split="test")
amc_test_data = load_dataset("math-ai/amc23", split="test")
self.test = list()
for name, t_dataset in zip(
["aime24", "math500"], [aime_test_data, math500_test_data]
):
for item in t_dataset:
self.test.append(
(
problem_format.format(problem=item["problem"]),
item["answer"],
name,
)
)
for name, t_dataset in zip(
["amc23"],
[amc_test_data],
):
for item in t_dataset:
self.test.append(
(
problem_format.format(problem=item["question"]),
item["answer"],
name,
)
)
async def rollout_and_score_eval(self, question, answer, subset):
completion = await self.server.chat_completion(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": question},
],
n=1,
max_tokens=32765,
temperature=0.0,
split="eval",
)
loop = asyncio.get_event_loop()
gold = "\\boxed{" + answer + "}" if "\\boxed" not in answer else answer
resp = completion.choices[0].message.content.split("</think>")[-1]
task = loop.run_in_executor(self.mp_executor, score_answer, gold, resp)
reward = await task
if reward is None:
return 0, subset
score = 1 if reward else 0
return score, subset
async def evaluate(self, *args, **kwargs):
if not self.config.run_evaluation:
return
eval_tasks = []
for item in self.test:
eval_tasks.append(self.rollout_and_score_eval(item[0], item[1], item[2]))
parsing_data = await tqdm_asyncio.gather(*eval_tasks)
task_lists = dict()
for score, subset in parsing_data:
if subset not in task_lists:
task_lists[subset] = list()
task_lists[subset].append(score)
# Now get the average
for subset, scores in task_lists.items():
self.eval_metrics.append(
(f"eval/{subset}_percent_correct", sum(scores) / len(scores))
)
# overall score
scores = []
for subset, score in task_lists.items():
scores.extend(score)
self.eval_metrics.append(
("eval/overall_percent_correct", sum(scores) / len(scores))
)
    async def collect_trajectories_normal(
        self, item
    ) -> Tuple[Optional[ScoredDataGroup], List]:
thinking_len = self.config.max_token_length
user_prompt = problem_format.format(problem=item[0])
chat = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
thinking_len = thinking_len - len(
self.tokenizer.apply_chat_template(chat, add_generation_prompt=True)
)
print(f"thinking_len: {thinking_len}", flush=True)
if thinking_len < 1024:
print("thinking_len is less than 1024, skipping", flush=True)
return None, []
# ============================================================================
# MANAGED SERVER USAGE - Chat Completion API
# ============================================================================
# This demonstrates using ManagedServer with the chat_completion() API.
# The process is identical to the completion() API (see math_server_zero.py),
# but uses OpenAI chat message format instead of raw text prompts.
#
# ManagedServer automatically:
# 1. Applies the tokenizer's chat template to convert messages to text
# 2. Tokenizes both prompt and completion
# 3. Applies proper masking (-100 for prompt tokens, actual IDs for completion)
# 4. Applies proper logprob masking (1.0 for prompt, actual values for completion)
# 5. Ensures perfect alignment between tokens and logprobs
#
# See: atroposlib/envs/server_handling/MANAGED_SERVER.md for full documentation
# ============================================================================
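        # Illustrative sketch (hypothetical values) of the alignment guarantee
        # above, for a 3-token prompt followed by a 2-token completion:
        #   node.tokens        = [p1,   p2,   p3,   c1,  c2]
        #   node.masked_tokens = [-100, -100, -100, c1,  c2]
        #   node.logprobs      = [1.0,  1.0,  1.0,  lp1, lp2]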
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
# Call chat_completion through the managed server wrapper
# Returns standard OpenAI-compatible ChatCompletion object
chat_completions = await managed.chat_completion(
messages=chat,
n=self.config.group_size, # Generate multiple completions for GRPO
max_tokens=thinking_len,
temperature=1.0,
top_p=0.95,
)
# Get tracked sequences with aligned tokens and logprobs
state = managed.get_state()
nodes = state["nodes"] # List of SequenceNode objects, one per completion
print("Finished generation", flush=True)
to_score = list()
to_backlog = list()
        for chat_completion, node in zip(chat_completions.choices, nodes):
messages = (
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
{"role": "assistant", "content": chat_completion.message.content},
)
# Extract pre-computed data from SequenceNode
# node.tokens: Full unmasked tokens [prompt + completion]
# node.masked_tokens: [-100, ..., -100, tok1, tok2, ...] for training
# node.logprobs: [1.0, ..., 1.0, logp1, logp2, ...] for training
to_score.append(
(
messages,
item[1], # Ground truth answer
chat_completion.finish_reason,
node.tokens, # Pre-computed by ManagedServer
node.masked_tokens, # Pre-computed by ManagedServer
node.logprobs, # Pre-computed by ManagedServer
)
)
print("scoring normal", flush=True)
to_postprocess = await self.score_normal(to_score)
print("scoring normal done", flush=True)
if to_postprocess is None:
return None, to_backlog
if all(
[to_postprocess["scores"][0] == score for score in to_postprocess["scores"]]
):
if to_postprocess["scores"][0] == 1.0:
# we can do RLAIF
# find the two most dissimilar messages
messages = to_postprocess["messages"]
score_matrix = []
most_dissimilar = (0, 1)
most_dissimilar_score = 1.0
for i in range(len(messages) - 1):
score_matrix.append([])
for j in range(i + 1):
# Only need to compute half of the matrix
score_matrix[i].append(1.0)
for j in range(i + 1, len(messages)):
m1 = messages[i][-1]["content"].split("</think>")[-1]
m2 = messages[j][-1]["content"].split("</think>")[-1]
if m1 == m2:
score_matrix[i].append(1.0)
else:
score_matrix[i].append(quick_similarity(m1, m2))
if score_matrix[i][j] < most_dissimilar_score:
most_dissimilar = (i, j)
most_dissimilar_score = score_matrix[i][j]
if most_dissimilar_score < 0.975:
# send over to RLAIF
to_backlog.append(
(
item[0],
item[1],
"rlaif",
                            tuple(
                                frozenset(msg.items())
                                for msg in messages[most_dissimilar[0]]
                            ),
                            tuple(
                                frozenset(msg.items())
                                for msg in messages[most_dissimilar[1]]
                            ),
most_dissimilar_score,
# Pass tokens/masks/logprobs for solution 1
to_postprocess["tokens"][most_dissimilar[0]],
to_postprocess["masks"][most_dissimilar[0]],
to_postprocess["inference_logprobs"][most_dissimilar[0]],
# Pass tokens/masks/logprobs for solution 2
to_postprocess["tokens"][most_dissimilar[1]],
to_postprocess["masks"][most_dissimilar[1]],
to_postprocess["inference_logprobs"][most_dissimilar[1]],
)
)
print(
"\n".join(
[
"["
+ ", ".join([str(item) for item in score_matrix_row])
+ "]"
for score_matrix_row in score_matrix
]
)
)
print(
f"Sending to RLAIF, most dissimilar score: {most_dissimilar_score}"
)
else:
print(
f"Unable to RLAIF, most dissimilar score: {most_dissimilar_score}"
)
if random.random() < self.config.percent_length_penalty:
                    # Check whether the message-length deltas are large enough
                    # to be worth applying a length penalty
message_lengths = [
len(tokens) for tokens in to_postprocess["tokens"]
]
min_message_length = min(message_lengths)
max_message_delta = max(
[msg_len - min_message_length for msg_len in message_lengths]
)
if max_message_delta > 0.1 * min_message_length:
print(
"Max message delta is greater than 0.1 * shortest message, adding length penalty"
)
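                        # Cosine shaping of the normalized length delta:
                        #   delta 0.0 (shortest) -> cos(0)    = +1.0
                        #   delta 0.5            -> cos(pi/2) =  0.0
                        #   delta 1.0 (longest)  -> cos(pi)   = -1.0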
for i in range(len(to_postprocess["scores"])):
len_penalty = (
message_lengths[i] - min_message_length
) / max_message_delta
len_penalty = math.cos(len_penalty * math.pi)
to_postprocess["scores"][i] = len_penalty
else:
print(
"Max message delta is less than 0.1 * shortest message, no length penalty"
)
return None, to_backlog
else:
return None, to_backlog
else:
return None, to_backlog
else:
            self.normal_rollouts.append(
                (
                    item[0],
                    to_postprocess["messages"][0],
                    item[1],
                    to_postprocess["scores"][0],
                )
            )
            if len(self.normal_rollouts) >= self.config.num_rollouts_to_keep:
                self.normal_rollouts.pop(0)
print("Sending to judge potentially")
if random.random() < self.config.percent_to_judge:
# find first pos and neg scored answers.
pos_idx = [
i
for i, score in enumerate(to_postprocess["scores"])
if score == 1.0
]
if len(pos_idx) == 0:
return None, to_backlog
pos_idx = pos_idx[0]
neg_idx = [
i
for i, score in enumerate(to_postprocess["scores"])
if (score == -1.0)
and (
not to_postprocess["overrides"][i].get(
"set_advantage_to_zero", False
)
)
]
if len(neg_idx) == 0:
return None, to_backlog
neg_idx = neg_idx[0]
if pos_idx is not None and neg_idx is not None:
to_backlog.append(
(
item[0],
item[1],
"judge",
to_postprocess["messages"][pos_idx][-1]["content"].split(
"</think>"
)[-1],
"True",
)
)
to_backlog.append(
(
item[0],
item[1],
"judge",
to_postprocess["messages"][neg_idx][-1]["content"].split(
"</think>"
)[-1],
"False",
)
)
print("sending to judge")
else:
return None, to_backlog
print(f"Collected {len(to_postprocess['scores'])} trajectories")
if not self.config.mask_too_long_completions:
to_postprocess["overrides"] = [
{} for _ in range(len(to_postprocess["scores"]))
]
return to_postprocess, to_backlog
    async def collect_trajectories(
        self, item
    ) -> Tuple[Optional[ScoredDataGroup], List]:
if item[2] == "normal":
return await self.collect_trajectories_normal(item)
elif item[2] == "rlaif":
return await self.collect_trajectories_rlaif(item)
elif item[2] == "judge":
return await self.collect_trajectories_judge(item)
elif item[2] == "selfcorrect":
# selfcorrect is a special case where we are using the Judge rollout
print("selfcorrect processing...")
print("selfcorrect item:", item, flush=True)
group = item[3]
scores = item[4]
finish_reasons = item[5]
tokens_list = item[6]
masks_list = item[7]
logprobs_list = item[8]
to_postprocess = ScoredDataGroup()
to_postprocess["tokens"] = list()
to_postprocess["masks"] = list()
to_postprocess["scores"] = list()
to_postprocess["overrides"] = list()
to_postprocess["messages"] = list()
to_postprocess["inference_logprobs"] = list()
for i in range(len(group)):
# convert from frozen set to dict
conv = [dict(x) for x in group[i]]
if i == 0:
self.selfcorrect_rollouts.append(
(
item[0],
item[1],
conv[0]["content"],
conv[1]["content"],
)
)
if (
len(self.selfcorrect_rollouts)
>= self.config.num_rollouts_to_keep
):
self.selfcorrect_rollouts.pop(0)
# Use pre-computed tokens/masks/logprobs from managed_server
assert len(logprobs_list[i]) == len(
masks_list[i]
), f"{len(logprobs_list[i])}, {len(masks_list[i])} mismatch"
to_postprocess["tokens"].append(tokens_list[i])
to_postprocess["masks"].append(masks_list[i])
to_postprocess["inference_logprobs"].append(logprobs_list[i])
to_postprocess["scores"].append(scores[i])
to_postprocess["overrides"].append(dict())
if (finish_reasons[i] == "length") and (
self.config.mask_too_long_completions
):
to_postprocess["overrides"][-1]["set_advantage_to_zero"] = True
# Convert back to messages format for consistency
to_postprocess["messages"].append(conv)
print("selfcorrect done, sending batch off")
return to_postprocess, []
else:
raise ValueError(f"Unknown rollout type: {item[2]}")
async def score_normal(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
scores = ScoredDataGroup()
scores["tokens"] = list()
scores["masks"] = list()
scores["scores"] = list()
scores["overrides"] = list()
scores["messages"] = list()
scores["inference_logprobs"] = list()
gold = rollout_group_data[0][1]
loop = asyncio.get_event_loop()
random.shuffle(rollout_group_data)
for item in rollout_group_data:
resp = item[0][-1]["content"].split("</think>")[-1]
scores["overrides"].append(dict())
# Extract pre-computed data from managed_server
tokens = item[3]
masks = item[4]
logprobs = item[5]
finish_reason = item[2]
if finish_reason == "length":
reward = False
if self.config.mask_too_long_completions:
scores["overrides"][-1]["set_advantage_to_zero"] = True
else:
task = loop.run_in_executor(self.mp_executor, score_answer, gold, resp)
reward = await task
if reward is None:
return None
assert len(logprobs) == len(
masks
), f"{len(logprobs)}, {len(masks)} mismatch"
            # Use messages from item[0]
            messages = item[0]
            # remove obviously bad examples (fewer than 10 unmasked tokens)
            if len([1 for tok in masks if tok != -100]) < 10:
                # drop the override we appended above so the lists stay aligned
                scores["overrides"].pop()
                continue
if finish_reason == "length":
                # Note we set this here so we can filter out examples that are
                # too long for the judge loop. If you set the config to not mask
                # them, we fix the overrides up in collect_trajectories_normal.
scores["overrides"][-1]["set_advantage_to_zero"] = True
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(1.0 if reward else -1.0)
scores["messages"].append(messages)
scores["inference_logprobs"].append(logprobs)
if len(scores["tokens"]) >= self.config.group_size:
break
if any([score == 1.0 for score in scores["scores"]]):
self.pass_at_groupsize.append(1.0)
else:
self.pass_at_groupsize.append(0.0)
if len(scores["tokens"]) < self.config.group_size:
# We don't have enough data to score
return None
for score in scores["scores"]:
self.percent_correct_buffer.append(max(score, 0))
self.percent_overanswer.extend(
[item[2] == "length" for item in rollout_group_data]
)
        # Record correct/incorrect answer lengths afterwards so we're only
        # looking at actual training data
self.correct_answer_len.extend(
[
len(scores["tokens"][i])
for i in range(len(scores["scores"]))
if scores["scores"][i] == 1.0
]
)
self.incorrect_answer_len.extend(
[
len(scores["tokens"][i])
for i in range(len(scores["scores"]))
if (scores["scores"][i] == -1.0)
and (not scores["overrides"][i].get("set_advantage_to_zero", False))
]
)
return scores
async def get_next_item(self):
while True:
next_item = self.train[self.iter % len(self.train)]
self.iter += 1
prompt = next_item["question"]
try:
answer = (
("\\boxed{" + next_item["final_answer"] + "}")
if "\\boxed" not in next_item["final_answer"]
else next_item["final_answer"]
)
break
except TypeError:
print(
f"Error in getting next item, trying again, "
f"data: {next_item['question']} -> {next_item['final_answer']}"
)
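        # Items are (prompt, gold_answer, rollout_type); "normal" items may later
        # spawn "rlaif", "judge", and "selfcorrect" backlog items.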
return (prompt, answer, "normal")
    async def collect_trajectories_rlaif(
        self, frozen_item
    ) -> Tuple[Optional[ScoredDataGroup], List]:
to_backlog = list()
print("Attempting RLAIF")
item = list(frozen_item)
print("Converting to dicts")
item[3] = [dict(x) for x in item[3]]
item[4] = [dict(x) for x in item[4]]
print("Formatting user prompts")
user_prompt_fwd = rlaif_format.format(
problem=item[0],
solution1=item[3][-1]["content"].split("</think>")[-1],
solution2=item[4][-1]["content"].split("</think>")[-1],
)
user_prompt_bwd = rlaif_format.format(
problem=item[0],
solution1=item[4][-1]["content"].split("</think>")[-1],
solution2=item[3][-1]["content"].split("</think>")[-1],
)
print("Sending to server")
chat_fwd = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt_fwd},
]
max_token_length_fwd = self.config.max_token_length - len(
self.tokenizer.apply_chat_template(chat_fwd, add_generation_prompt=True)
)
print("Sending to server")
# Should be the same token length as the fwd but tokenizers are cursed
chat_bwd = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt_bwd},
]
max_token_length_bwd = self.config.max_token_length - len(
self.tokenizer.apply_chat_template(chat_bwd, add_generation_prompt=True)
)
# ============================================================================
# MULTIPLE MANAGED SERVER CONTEXTS - RLAIF Pattern
# ============================================================================
# This demonstrates using SEPARATE managed_server contexts for independent
# completions. Each context tracks its own set of sequences independently.
#
# Pattern: Create separate async functions that each use their own context,
# then gather them in parallel. This is useful for:
# - RLAIF (forward/backward preference judgments)
# - Multi-step workflows where completions don't extend each other
# - Comparing different prompts or conditions
#
        # Note: The tokens/masks/logprobs from these contexts are NOT used directly
        # in this RLAIF workflow. Instead, we stored them earlier from the original
        # completions in collect_trajectories_normal, where they were appended to
        # the RLAIF backlog item.
# ============================================================================
async def get_fwd_completion():
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
return await managed.chat_completion(
messages=chat_fwd,
n=3,
max_tokens=max_token_length_fwd,
temperature=1.0,
top_p=0.95,
)
async def get_bwd_completion():
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
return await managed.chat_completion(
messages=chat_bwd,
n=3,
max_tokens=max_token_length_bwd,
temperature=1.0,
top_p=0.95,
)
print("Gathering completions")
chat_completions_fwd, chat_completions_bwd = await asyncio.gather(
get_fwd_completion(), get_bwd_completion()
)
print("Grabbed RLAIF completions")
        # Tally preference votes over both orderings to cancel out position bias
score_1 = 0
score_2 = 0
for chat_completion in chat_completions_fwd.choices:
score = (
chat_completion.message.content.split("</think>")[-1]
.split("\\boxed{")[-1]
.split("}")[0]
.strip()
)
if score == "1":
score_1 += 1
elif score == "2":
score_2 += 1
for chat_completion in chat_completions_bwd.choices:
score = (
chat_completion.message.content.split("</think>")[-1]
.split("\\boxed{")[-1]
.split("}")[0]
.strip()
)
if score == "1":
score_2 += 1
elif score == "2":
score_1 += 1
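        # Worked example (hypothetical votes): fwd choices return \boxed{1},
        # \boxed{1}, \boxed{2} -> score_1 = 2, score_2 = 1; bwd choices (order
        # swapped) return \boxed{1}, \boxed{2}, \boxed{2} -> score_2 += 1 and
        # score_1 += 2, for a final tally of score_1 = 4 vs score_2 = 2.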
print(f"Score 1: {score_1}, Score 2: {score_2}")
if score_1 == score_2:
return None, []
self.rlaif_rollouts.append(
(
item[0],
item[3][-1]["content"].split("</think>")[-1],
item[4][-1]["content"].split("</think>")[-1],
score_1 - score_2,
chat_completions_fwd.choices[0].message.content,
)
)
if len(self.rlaif_rollouts) >= self.config.num_rollouts_to_keep:
self.rlaif_rollouts.pop(0)
print("RLAIF rollout added")
to_postprocess = ScoredDataGroup()
to_postprocess["tokens"] = list()
to_postprocess["masks"] = list()
to_postprocess["scores"] = list()
to_postprocess["overrides"] = list()
to_postprocess["messages"] = list()
to_postprocess["inference_logprobs"] = list()
# Extract pre-computed tokens/masks/logprobs from backlog
tokens_1 = item[6]
masks_1 = item[7]
logprobs_1 = item[8]
tokens_2 = item[9]
masks_2 = item[10]
logprobs_2 = item[11]
# Add assertions to verify data integrity
assert len(logprobs_1) == len(
masks_1
), f"{len(logprobs_1)}, {len(masks_1)} mismatch"
assert len(logprobs_2) == len(
masks_2
), f"{len(logprobs_2)}, {len(masks_2)} mismatch"
# add the first message in
to_postprocess["tokens"].append(tokens_1)
to_postprocess["masks"].append(masks_1)
to_postprocess["scores"].append(1.0 if score_1 > score_2 else -1.0)
to_postprocess["messages"].append(item[3]) # Already converted to dicts
to_postprocess["inference_logprobs"].append(logprobs_1)
to_postprocess["overrides"].append(dict())
# add the second message in
to_postprocess["tokens"].append(tokens_2)
to_postprocess["masks"].append(masks_2)
to_postprocess["scores"].append(1.0 if score_2 > score_1 else -1.0)
to_postprocess["messages"].append(item[4]) # Already converted to dicts
to_postprocess["inference_logprobs"].append(logprobs_2)
to_postprocess["overrides"].append(dict())
to_postprocess["group_overrides"] = {
"group_size": 2,
}
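        # The pair trains as its own group of 2 (overriding config.group_size),
        # since only these two solutions were directly compared.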
print("RLAIF rollout added")
return to_postprocess, to_backlog
    async def collect_trajectories_judge(
        self, item
    ) -> Tuple[Optional[ScoredDataGroup], List]:
user_prompt = judge_format.format(
problem=item[0],
solution=item[3],
)
to_backlog = list()
chat = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
max_token_length = self.config.max_token_length - len(
self.tokenizer.apply_chat_template(chat, add_generation_prompt=True)
)
# Judge completions: Standard managed_server usage
# Tokens/masks/logprobs from nodes will be used directly for training
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
chat_completions = await managed.chat_completion(
messages=chat,
n=self.config.group_size,
max_tokens=max_token_length,
temperature=1.0,
top_p=0.95,
)
# Get tracked sequences with aligned tokens and logprobs
state = managed.get_state()
nodes = state["nodes"]
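        # A verdict is the content of the last \boxed{...} after </think>, e.g.
        # "...</think> The solution checks out: \boxed{True}" -> "True".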
is_correct = [
(
chat_completion.message.content.split("</think>")[-1]
.split("\\boxed{")[-1]
.split("}")[0]
.strip()
== item[4]
)
and (chat_completion.finish_reason != "length")
for chat_completion in chat_completions.choices
]
self.percent_judge_correct.append(
sum([1.0 if val else 0.0 for val in is_correct]) / len(is_correct)
)
if all([not val for val in is_correct]):
# Can't judge :(
return None, []
scores = ScoredDataGroup()
scores["tokens"] = []
scores["masks"] = []
scores["scores"] = []
scores["overrides"] = []
scores["messages"] = []
scores["inference_logprobs"] = []
for_table = []
for i, (chat_completion, node) in enumerate(
zip(chat_completions.choices, nodes)
):
# Extract pre-computed data from managed_server
tokens = node.tokens
masks = node.masked_tokens
logprobs = node.logprobs
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
{"role": "assistant", "content": chat_completion.message.content},
]
assert len(logprobs) == len(
masks
), f"{len(logprobs)}, {len(masks)} mismatch"
if not is_correct[i]:
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(-1.0)
scores["messages"].append(messages)
scores["inference_logprobs"].append(logprobs)
scores["overrides"].append(dict())
if (chat_completion.finish_reason == "length") and (
self.config.mask_too_long_completions
):
scores["overrides"][-1]["set_advantage_to_zero"] = True
else:
if len(for_table) == 0:
# populate the table
for_table = [
item[0],
item[1],
item[3],
item[4],
chat_completion.message.content,
]
if item[4] == "False":
# Score based on percentage correct from retry
print("Scoring retry")
retry_prompt = retry_format.format(
problem=item[0],
solution=item[3],
verification=chat_completion.message.content.split("</think>")[
-1
],
)
print("Sending to server")
retry_messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": retry_prompt},
]
max_token_length = self.config.max_token_length - len(
self.tokenizer.apply_chat_template(
retry_messages, add_generation_prompt=True
)
)
                    # Retry/self-correction completions: nested managed_server usage.
                    # This demonstrates using managed_server INSIDE another workflow.
                    # Tokens/masks/logprobs from retry_nodes are stored in the backlog
                    # for potential use in the "selfcorrect" trajectory type
                    # (see collect_trajectories).
async with self.server.managed_server(
tokenizer=self.tokenizer
) as managed:
retry_chat_completions = await managed.chat_completion(
messages=retry_messages,
n=self.config.group_size,
max_tokens=max_token_length,
temperature=1.0,
top_p=0.95,
)
# Get tracked sequences with aligned tokens and logprobs
retry_state = managed.get_state()
retry_nodes = retry_state["nodes"]
print("Gathering completions")
scoring_data = []
backlog_scores = []
backlog_reasons = []
backlog_messages = []
backlog_tokens = []
backlog_masks = []
backlog_logprobs = []
for j, (retry_chat_completion, retry_node) in enumerate(
zip(retry_chat_completions.choices, retry_nodes)
):
print(f"Scoring generation {j} for retry...")
backlog_messages.append(
tuple(
[frozenset(msg.items()) for msg in retry_messages]
+ [
frozenset(
{
"role": "assistant",
"content": retry_chat_completion.message.content,
}.items()
)
]
)
)
backlog_reasons.append(retry_chat_completion.finish_reason)
                        # Store pre-computed tokens/masks/logprobs from ManagedServer.
                        # These pass through the "selfcorrect" backlog item below and
                        # are consumed in collect_trajectories' "selfcorrect" branch.
backlog_tokens.append(retry_node.tokens)
backlog_masks.append(retry_node.masked_tokens)
backlog_logprobs.append(retry_node.logprobs)
if retry_chat_completion.finish_reason == "length":
scoring_data.append(0)
backlog_scores.append(0)
else:
loop = asyncio.get_event_loop()
task = loop.run_in_executor(
self.mp_executor,
score_answer,
item[1],
retry_chat_completion.message.content.split("</think>")[
-1
],
)
reward = await task
scoring_data.append(1.0 if reward else 0.0)
backlog_scores.append(1.0 if reward else -1.0)
                    # Send to selfcorrect when the retry results are mixed, or when
                    # all retries are correct and we win the length-penalty lottery.
                    if (
                        not all(
                            backlog_score == backlog_scores[0]
                            for backlog_score in backlog_scores
                        )
                    ) or (
                        all(
                            backlog_score == 1.0
                            for backlog_score in backlog_scores
                        )
                        and (random.random() < self.config.percent_length_penalty)
                    ):
to_backlog.append(
(
item[0],
item[1],
"selfcorrect",
tuple(backlog_messages),
tuple(backlog_scores),
tuple(backlog_reasons),
tuple(backlog_tokens),
tuple(backlog_masks),
tuple(backlog_logprobs),
)
)
print(f"Sending to selfcorrect, {len(to_backlog)} in backlog")
scores["scores"].append(sum(scoring_data) / len(scoring_data))
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["messages"].append(messages)
scores["inference_logprobs"].append(logprobs)
scores["overrides"].append(dict())
self.judge_success_rate.append(
sum(scoring_data) / len(scoring_data)
)
if len(self.judge_success_rate) >= self.config.num_rollouts_to_keep:
self.judge_success_rate.pop(0)
else:
scores["scores"].append(1.0)
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["messages"].append(messages)
scores["inference_logprobs"].append(logprobs)
scores["overrides"].append(dict())
if all([score == 1.0 for score in scores["scores"]]) and (
random.random() < self.config.percent_length_penalty
):
# Do len penalty
message_lengths = [len(tokens) for tokens in scores["tokens"]]
min_message_length = min(message_lengths)
max_message_delta = max(
[msg_len - min_message_length for msg_len in message_lengths]
)
if max_message_delta > 0.1 * min_message_length:
print(
"Max message delta is greater than 0.1 * shortest message, adding length penalty"
)
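                # Same cosine mapping as the normal-path length penalty:
                # shortest answer -> +1.0, longest -> -1.0.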
for i in range(len(scores["scores"])):
len_penalty = (
message_lengths[i] - min_message_length
) / max_message_delta
len_penalty = math.cos(len_penalty * math.pi)
scores["scores"][i] = len_penalty
else:
print(
"Max message delta is less than 0.1 * shortest message, no length penalty"
)
return None, []
elif all([score == scores["scores"][0] for score in scores["scores"]]):
return None, []
if len(for_table) > 0:
self.judge_rollouts.append(for_table)
if len(self.judge_rollouts) >= self.config.num_rollouts_to_keep:
self.judge_rollouts.pop(0)
print(
f"Collected {len(scores['scores'])} trajectories with {len(to_backlog)} in backlog"
)
return scores, to_backlog
if __name__ == "__main__":
MathEnv.cli()