add chat example and fix bug in managed_server

This commit is contained in:
dmahan93 2025-10-24 23:15:56 -07:00
parent 7bf4cfbf80
commit 5d662bf1aa
2 changed files with 188 additions and 106 deletions

View file

@@ -20,7 +20,6 @@ from atroposlib.envs.base import (
EvalHandlingEnum,
ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
system_prompt = (
"You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
@@ -186,15 +185,15 @@ class MathEnv(BaseEnv):
@classmethod
def config_init(self) -> Tuple[RSConfig, List[APIServerConfig]]:
env_config = RSConfig(
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
tokenizer_name="NousResearch/Hermes-4-14B",
group_size=16,
use_wandb=True,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1024,
max_num_workers_per_node=24,
max_num_workers_per_node=12,
steps_per_eval=25,
max_token_length=8192, # 22000 // (2 ** i),
max_token_length=16384, # 22000 // (2 ** i),
wandb_name="math",
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
eval_limit_ratio=0.1,
@@ -203,10 +202,11 @@ class MathEnv(BaseEnv):
)
server_configs = [
APIServerConfig(
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
model_name="NousResearch/Hermes-4-14B",
base_url="http://localhost:9004/v1",
api_key="x",
num_requests_for_eval=256, # since evaling only on one...
server_type="sglang",
),
]
@@ -373,17 +373,25 @@ class MathEnv(BaseEnv):
if thinking_len < 1024:
print("thinking_len is less than 1024, skipping", flush=True)
return None, []
chat_completions = await self.server.chat_completion(
messages=chat,
n=self.config.group_size,
max_tokens=thinking_len,
temperature=1.0,
top_p=0.95,
)
# Use managed server for automatic token/logprob tracking
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
chat_completions = await managed.chat_completion(
messages=chat,
n=self.config.group_size,
max_tokens=thinking_len,
temperature=1.0,
top_p=0.95,
)
# Get tracked sequences with aligned tokens and logprobs
state = managed.get_state()
nodes = state["nodes"]
print("Finished generation", flush=True)
to_score = list()
to_backlog = list()
for i, chat_completion in enumerate(chat_completions.choices):
for i, (chat_completion, node) in enumerate(
zip(chat_completions.choices, nodes)
):
messages = (
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
@@ -394,6 +402,9 @@ class MathEnv(BaseEnv):
messages,
item[1],
chat_completion.finish_reason,
node.tokens,
node.masked_tokens,
node.logprobs,
)
)
print("scoring normal", flush=True)
@@ -447,6 +458,14 @@ class MathEnv(BaseEnv):
]
),
most_dissimilar_score,
# Pass tokens/masks/logprobs for solution 1
to_postprocess["tokens"][most_dissimilar[0]],
to_postprocess["masks"][most_dissimilar[0]],
to_postprocess["inference_logprobs"][most_dissimilar[0]],
# Pass tokens/masks/logprobs for solution 2
to_postprocess["tokens"][most_dissimilar[1]],
to_postprocess["masks"][most_dissimilar[1]],
to_postprocess["inference_logprobs"][most_dissimilar[1]],
)
)
print(
@@ -571,12 +590,16 @@ class MathEnv(BaseEnv):
group = item[3]
scores = item[4]
finish_reasons = item[5]
tokens_list = item[6]
masks_list = item[7]
logprobs_list = item[8]
to_postprocess = ScoredDataGroup()
to_postprocess["tokens"] = list()
to_postprocess["masks"] = list()
to_postprocess["scores"] = list()
to_postprocess["overrides"] = list()
to_postprocess["messages"] = list()
to_postprocess["inference_logprobs"] = list()
for i in range(len(group)):
# convert from frozen set to dict
conv = [dict(x) for x in group[i]]
@@ -594,21 +617,21 @@ class MathEnv(BaseEnv):
>= self.config.num_rollouts_to_keep
):
self.selfcorrect_rollouts.pop(0)
out_dict = tokenize_for_trainer(
tokenizer=self.tokenizer,
chat=conv,
finish_reason=finish_reasons[i],
include_messages=True,
)
to_postprocess["tokens"].append(out_dict["tokens"])
to_postprocess["masks"].append(out_dict["masks"])
# Use pre-computed tokens/masks/logprobs from managed_server
assert len(logprobs_list[i]) == len(
masks_list[i]
), f"{len(logprobs_list[i])}, {len(masks_list[i])} mismatch"
to_postprocess["tokens"].append(tokens_list[i])
to_postprocess["masks"].append(masks_list[i])
to_postprocess["inference_logprobs"].append(logprobs_list[i])
to_postprocess["scores"].append(scores[i])
to_postprocess["overrides"].append(dict())
if (finish_reasons[i] == "length") and (
self.config.mask_too_long_completions
):
to_postprocess["overrides"][-1]["set_advantage_to_zero"] = True
to_postprocess["messages"].append(out_dict["messages"])
# Convert back to messages format for consistency
to_postprocess["messages"].append(conv)
print("selfcorrect done, sending batch off")
return to_postprocess, []
else:
@@ -621,13 +644,20 @@ class MathEnv(BaseEnv):
scores["scores"] = list()
scores["overrides"] = list()
scores["messages"] = list()
scores["inference_logprobs"] = list()
gold = rollout_group_data[0][1]
loop = asyncio.get_event_loop()
random.shuffle(rollout_group_data)
for item in rollout_group_data:
resp = item[0][-1]["content"].split("</think>")[-1]
scores["overrides"].append(dict())
if item[2] == "length":
# Extract pre-computed data from managed_server
tokens = item[3]
masks = item[4]
logprobs = item[5]
finish_reason = item[2]
if finish_reason == "length":
reward = False
if self.config.mask_too_long_completions:
scores["overrides"][-1]["set_advantage_to_zero"] = True
@@ -636,19 +666,17 @@ class MathEnv(BaseEnv):
reward = await task
if reward is None:
return None
out_dict = tokenize_for_trainer(
tokenizer=self.tokenizer,
chat=item[0],
finish_reason=item[2],
include_messages=True,
)
tokens = out_dict["tokens"]
masks = out_dict["masks"]
messages = out_dict["messages"]
assert len(logprobs) == len(
masks
), f"{len(logprobs)}, {len(masks)} mismatch"
# Use messages from item[0]
messages = item[0]
# remove obviously bad examples
if len([1 for i in masks if i != -100]) < 10:
continue
if item[2] == "length":
if finish_reason == "length":
# Note we set it here so we can filter out the examples that are too long
# for the Judge loop. IF you set the config to not do this we fix it
# in the collect_trajectories_normal function.
@@ -657,6 +685,7 @@ class MathEnv(BaseEnv):
scores["masks"].append(masks)
scores["scores"].append(1.0 if reward else -1.0)
scores["messages"].append(messages)
scores["inference_logprobs"].append(logprobs)
if len(scores["tokens"]) >= self.config.group_size:
break
if any([score == 1.0 for score in scores["scores"]]):
@@ -729,39 +758,48 @@ class MathEnv(BaseEnv):
solution2=item[3][-1]["content"].split("</think>")[-1],
)
print("Sending to server")
chat = [
chat_fwd = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt_fwd},
]
max_token_length = self.config.max_token_length - len(
self.tokenizer.apply_chat_template(chat, add_generation_prompt=True)
)
chat_completions_fwd = self.server.chat_completion(
messages=chat,
n=3,
max_tokens=max_token_length,
temperature=1.0,
top_p=0.95,
max_token_length_fwd = self.config.max_token_length - len(
self.tokenizer.apply_chat_template(chat_fwd, add_generation_prompt=True)
)
print("Sending to server")
# Should be the same token length as the fwd but tokenizers are cursed
chat = [
chat_bwd = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt_bwd},
]
max_token_length = self.config.max_token_length - len(
self.tokenizer.apply_chat_template(chat, add_generation_prompt=True)
)
chat_completions_bwd = self.server.chat_completion(
messages=chat,
n=3,
max_tokens=self.config.max_token_length,
temperature=1.0,
top_p=0.95,
max_token_length_bwd = self.config.max_token_length - len(
self.tokenizer.apply_chat_template(chat_bwd, add_generation_prompt=True)
)
# Use managed server for both forward and backward completions
async def get_fwd_completion():
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
return await managed.chat_completion(
messages=chat_fwd,
n=3,
max_tokens=max_token_length_fwd,
temperature=1.0,
top_p=0.95,
)
async def get_bwd_completion():
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
return await managed.chat_completion(
messages=chat_bwd,
n=3,
max_tokens=max_token_length_bwd,
temperature=1.0,
top_p=0.95,
)
print("Gathering completions")
chat_completions_fwd, chat_completions_bwd = await asyncio.gather(
chat_completions_fwd, chat_completions_bwd
get_fwd_completion(), get_bwd_completion()
)
print("Grabbed RLAIF completions")
# Check for correct answers
@@ -810,25 +848,35 @@ class MathEnv(BaseEnv):
to_postprocess["scores"] = list()
to_postprocess["overrides"] = list()
to_postprocess["messages"] = list()
to_postprocess["inference_logprobs"] = list()
# Extract pre-computed tokens/masks/logprobs from backlog
tokens_1 = item[6]
masks_1 = item[7]
logprobs_1 = item[8]
tokens_2 = item[9]
masks_2 = item[10]
logprobs_2 = item[11]
# Add assertions to verify data integrity
assert len(logprobs_1) == len(
masks_1
), f"{len(logprobs_1)}, {len(masks_1)} mismatch"
assert len(logprobs_2) == len(
masks_2
), f"{len(logprobs_2)}, {len(masks_2)} mismatch"
# add the first message in
out_dict = tokenize_for_trainer(
tokenizer=self.tokenizer, chat=item[3], include_messages=True
)
tokens = out_dict["tokens"]
masks = out_dict["masks"]
to_postprocess["tokens"].append(tokens)
to_postprocess["masks"].append(masks)
to_postprocess["tokens"].append(tokens_1)
to_postprocess["masks"].append(masks_1)
to_postprocess["scores"].append(1.0 if score_1 > score_2 else -1.0)
to_postprocess["messages"].append(out_dict["messages"])
out_dict = tokenize_for_trainer(
tokenizer=self.tokenizer, chat=item[4], include_messages=True
)
tokens = out_dict["tokens"]
masks = out_dict["masks"]
to_postprocess["tokens"].append(tokens)
to_postprocess["masks"].append(masks)
to_postprocess["messages"].append(item[3]) # Already converted to dicts
to_postprocess["inference_logprobs"].append(logprobs_1)
to_postprocess["overrides"].append(dict())
# add the second message in
to_postprocess["tokens"].append(tokens_2)
to_postprocess["masks"].append(masks_2)
to_postprocess["scores"].append(1.0 if score_2 > score_1 else -1.0)
to_postprocess["messages"].append(out_dict["messages"])
to_postprocess["messages"].append(item[4]) # Already converted to dicts
to_postprocess["inference_logprobs"].append(logprobs_2)
to_postprocess["overrides"].append(dict())
to_postprocess["group_overrides"] = {
"group_size": 2,
}
@@ -848,13 +896,19 @@ class MathEnv(BaseEnv):
max_token_length = self.config.max_token_length - len(
self.tokenizer.apply_chat_template(chat, add_generation_prompt=True)
)
chat_completions = await self.server.chat_completion(
messages=chat,
n=self.config.group_size,
max_tokens=max_token_length,
temperature=1.0,
top_p=0.95,
)
# Use managed server for judge completions
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
chat_completions = await managed.chat_completion(
messages=chat,
n=self.config.group_size,
max_tokens=max_token_length,
temperature=1.0,
top_p=0.95,
)
# Get tracked sequences with aligned tokens and logprobs
state = managed.get_state()
nodes = state["nodes"]
is_correct = [
(
chat_completion.message.content.split("</think>")[-1]
@@ -878,25 +932,29 @@ class MathEnv(BaseEnv):
scores["scores"] = []
scores["overrides"] = []
scores["messages"] = []
scores["inference_logprobs"] = []
for_table = []
for i, chat_completion in enumerate(chat_completions.choices):
out_dict = tokenize_for_trainer(
tokenizer=self.tokenizer,
chat=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
{"role": "assistant", "content": chat_completion.message.content},
],
include_messages=True,
)
tokens = out_dict["tokens"]
masks = out_dict["masks"]
messages = out_dict["messages"]
for i, (chat_completion, node) in enumerate(
zip(chat_completions.choices, nodes)
):
# Extract pre-computed data from managed_server
tokens = node.tokens
masks = node.masked_tokens
logprobs = node.logprobs
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
{"role": "assistant", "content": chat_completion.message.content},
]
assert len(logprobs) == len(
masks
), f"{len(logprobs)}, {len(masks)} mismatch"
if not is_correct[i]:
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(-1.0)
scores["messages"].append(messages)
scores["inference_logprobs"].append(logprobs)
scores["overrides"].append(dict())
if (chat_completion.finish_reason == "length") and (
self.config.mask_too_long_completions
@@ -932,20 +990,31 @@ class MathEnv(BaseEnv):
retry_messages, add_generation_prompt=True
)
)
retry_chat_completions = await self.server.chat_completion(
messages=retry_messages,
n=self.config.group_size,
max_tokens=max_token_length,
temperature=1.0,
top_p=0.95,
)
# Use managed server for retry completions
async with self.server.managed_server(
tokenizer=self.tokenizer
) as managed:
retry_chat_completions = await managed.chat_completion(
messages=retry_messages,
n=self.config.group_size,
max_tokens=max_token_length,
temperature=1.0,
top_p=0.95,
)
# Get tracked sequences with aligned tokens and logprobs
retry_state = managed.get_state()
retry_nodes = retry_state["nodes"]
print("Gathering completions")
scoring_data = []
backlog_scores = []
backlog_reasons = []
backlog_messages = []
for j, retry_chat_completion in enumerate(
retry_chat_completions.choices
backlog_tokens = []
backlog_masks = []
backlog_logprobs = []
for j, (retry_chat_completion, retry_node) in enumerate(
zip(retry_chat_completions.choices, retry_nodes)
):
print(f"Scoring generation {j} for retry...")
backlog_messages.append(
@@ -962,6 +1031,10 @@ class MathEnv(BaseEnv):
)
)
backlog_reasons.append(retry_chat_completion.finish_reason)
# Store tokens, masks, and logprobs from managed_server
backlog_tokens.append(retry_node.tokens)
backlog_masks.append(retry_node.masked_tokens)
backlog_logprobs.append(retry_node.logprobs)
if retry_chat_completion.finish_reason == "length":
scoring_data.append(0)
backlog_scores.append(0)
@@ -998,10 +1071,18 @@ class MathEnv(BaseEnv):
tuple(backlog_messages),
tuple(backlog_scores),
tuple(backlog_reasons),
tuple(backlog_tokens),
tuple(backlog_masks),
tuple(backlog_logprobs),
)
)
print(f"Sending to selfcorrect, {len(to_backlog)} in backlog")
scores["scores"].append(sum(scoring_data) / len(scoring_data))
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["messages"].append(messages)
scores["inference_logprobs"].append(logprobs)
scores["overrides"].append(dict())
self.judge_success_rate.append(
sum(scoring_data) / len(scoring_data)
)
@@ -1012,6 +1093,7 @@ class MathEnv(BaseEnv):
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["messages"].append(messages)
scores["inference_logprobs"].append(logprobs)
scores["overrides"].append(dict())
if all([score == 1.0 for score in scores["scores"]]) and (
random.random() < self.config.percent_length_penalty