Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-22 16:48:57 +00:00)
add chat example and fix bug in managed_server
Parent: 7bf4cfbf80
Commit: 5d662bf1aa
2 changed files with 188 additions and 106 deletions
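The diff below migrates the math environment's generation calls from raw `self.server.chat_completion(...)` to the `managed_server` context manager, which tracks the tokens, masks, and logprobs of each completion so the environment no longer re-tokenizes with `tokenize_for_trainer`. A minimal sketch of the pattern the diff adopts is shown here; `managed_server`, `get_state()`, and the node fields (`tokens`, `masked_tokens`, `logprobs`) are taken from the hunks below, while the method name and its arguments are assumptions for illustration.

# Sketch only: mirrors the managed_server usage introduced in this commit;
# the helper name and signature are hypothetical.
async def collect_group(self, chat, max_tokens):
    async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
        chat_completions = await managed.chat_completion(
            messages=chat,
            n=self.config.group_size,
            max_tokens=max_tokens,
            temperature=1.0,
            top_p=0.95,
        )
        # One node per sampled choice, with tokens/masks/logprobs aligned.
        nodes = managed.get_state()["nodes"]
    results = []
    for choice, node in zip(chat_completions.choices, nodes):
        assert len(node.logprobs) == len(node.masked_tokens)
        results.append((choice, node.tokens, node.masked_tokens, node.logprobs))
    return results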
@@ -20,7 +20,6 @@ from atroposlib.envs.base import (
 EvalHandlingEnum,
 ScoredDataGroup,
 )
-from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

 system_prompt = (
 "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
@@ -186,15 +185,15 @@ class MathEnv(BaseEnv):
 @classmethod
 def config_init(self) -> Tuple[RSConfig, List[APIServerConfig]]:
 env_config = RSConfig(
-tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
+tokenizer_name="NousResearch/Hermes-4-14B",
 group_size=16,
 use_wandb=True,
 rollout_server_url="http://localhost:8000",
 total_steps=1000,
 batch_size=1024,
-max_num_workers_per_node=24,
+max_num_workers_per_node=12,
 steps_per_eval=25,
-max_token_length=8192, # 22000 // (2 ** i),
+max_token_length=16384, # 22000 // (2 ** i),
 wandb_name="math",
 eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
 eval_limit_ratio=0.1,
@@ -203,10 +202,11 @@ class MathEnv(BaseEnv):
 )
 server_configs = [
 APIServerConfig(
-model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
+model_name="NousResearch/Hermes-4-14B",
 base_url="http://localhost:9004/v1",
 api_key="x",
 num_requests_for_eval=256, # since evaling only on one...
+server_type="sglang",
 ),
 ]

@@ -373,17 +373,25 @@ class MathEnv(BaseEnv):
 if thinking_len < 1024:
 print("thinking_len is less than 1024, skipping", flush=True)
 return None, []
-chat_completions = await self.server.chat_completion(
-messages=chat,
-n=self.config.group_size,
-max_tokens=thinking_len,
-temperature=1.0,
-top_p=0.95,
-)
+# Use managed server for automatic token/logprob tracking
+async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+chat_completions = await managed.chat_completion(
+messages=chat,
+n=self.config.group_size,
+max_tokens=thinking_len,
+temperature=1.0,
+top_p=0.95,
+)
+# Get tracked sequences with aligned tokens and logprobs
+state = managed.get_state()
+nodes = state["nodes"]
+
 print("Finished generation", flush=True)
 to_score = list()
 to_backlog = list()
-for i, chat_completion in enumerate(chat_completions.choices):
+for i, (chat_completion, node) in enumerate(
+zip(chat_completions.choices, nodes)
+):
 messages = (
 {"role": "system", "content": system_prompt},
 {"role": "user", "content": user_prompt},
@@ -394,6 +402,9 @@ class MathEnv(BaseEnv):
 messages,
 item[1],
 chat_completion.finish_reason,
+node.tokens,
+node.masked_tokens,
+node.logprobs,
 )
 )
 print("scoring normal", flush=True)
@@ -447,6 +458,14 @@ class MathEnv(BaseEnv):
 ]
 ),
 most_dissimilar_score,
+# Pass tokens/masks/logprobs for solution 1
+to_postprocess["tokens"][most_dissimilar[0]],
+to_postprocess["masks"][most_dissimilar[0]],
+to_postprocess["inference_logprobs"][most_dissimilar[0]],
+# Pass tokens/masks/logprobs for solution 2
+to_postprocess["tokens"][most_dissimilar[1]],
+to_postprocess["masks"][most_dissimilar[1]],
+to_postprocess["inference_logprobs"][most_dissimilar[1]],
 )
 )
 print(
@@ -571,12 +590,16 @@ class MathEnv(BaseEnv):
 group = item[3]
 scores = item[4]
 finish_reasons = item[5]
+tokens_list = item[6]
+masks_list = item[7]
+logprobs_list = item[8]
 to_postprocess = ScoredDataGroup()
 to_postprocess["tokens"] = list()
 to_postprocess["masks"] = list()
 to_postprocess["scores"] = list()
 to_postprocess["overrides"] = list()
 to_postprocess["messages"] = list()
+to_postprocess["inference_logprobs"] = list()
 for i in range(len(group)):
 # convert from frozen set to dict
 conv = [dict(x) for x in group[i]]
@@ -594,21 +617,21 @@ class MathEnv(BaseEnv):
 >= self.config.num_rollouts_to_keep
 ):
 self.selfcorrect_rollouts.pop(0)
-out_dict = tokenize_for_trainer(
-tokenizer=self.tokenizer,
-chat=conv,
-finish_reason=finish_reasons[i],
-include_messages=True,
-)
-to_postprocess["tokens"].append(out_dict["tokens"])
-to_postprocess["masks"].append(out_dict["masks"])
+# Use pre-computed tokens/masks/logprobs from managed_server
+assert len(logprobs_list[i]) == len(
+masks_list[i]
+), f"{len(logprobs_list[i])}, {len(masks_list[i])} mismatch"
+to_postprocess["tokens"].append(tokens_list[i])
+to_postprocess["masks"].append(masks_list[i])
+to_postprocess["inference_logprobs"].append(logprobs_list[i])
 to_postprocess["scores"].append(scores[i])
 to_postprocess["overrides"].append(dict())
 if (finish_reasons[i] == "length") and (
 self.config.mask_too_long_completions
 ):
 to_postprocess["overrides"][-1]["set_advantage_to_zero"] = True
-to_postprocess["messages"].append(out_dict["messages"])
+# Convert back to messages format for consistency
+to_postprocess["messages"].append(conv)
 print("selfcorrect done, sending batch off")
 return to_postprocess, []
 else:
@@ -621,13 +644,20 @@ class MathEnv(BaseEnv):
 scores["scores"] = list()
 scores["overrides"] = list()
 scores["messages"] = list()
+scores["inference_logprobs"] = list()
 gold = rollout_group_data[0][1]
 loop = asyncio.get_event_loop()
 random.shuffle(rollout_group_data)
 for item in rollout_group_data:
 resp = item[0][-1]["content"].split("</think>")[-1]
 scores["overrides"].append(dict())
-if item[2] == "length":
+# Extract pre-computed data from managed_server
+tokens = item[3]
+masks = item[4]
+logprobs = item[5]
+finish_reason = item[2]
+
+if finish_reason == "length":
 reward = False
 if self.config.mask_too_long_completions:
 scores["overrides"][-1]["set_advantage_to_zero"] = True
@@ -636,19 +666,17 @@ class MathEnv(BaseEnv):
 reward = await task
 if reward is None:
 return None
-out_dict = tokenize_for_trainer(
-tokenizer=self.tokenizer,
-chat=item[0],
-finish_reason=item[2],
-include_messages=True,
-)
-tokens = out_dict["tokens"]
-masks = out_dict["masks"]
-messages = out_dict["messages"]
+
+assert len(logprobs) == len(
+masks
+), f"{len(logprobs)}, {len(masks)} mismatch"
+# Use messages from item[0]
+messages = item[0]
+
 # remove obviously bad examples
 if len([1 for i in masks if i != -100]) < 10:
 continue
-if item[2] == "length":
+if finish_reason == "length":
 # Note we set it here so we can filter out the examples that are too long
 # for the Judge loop. IF you set the config to not do this we fix it
 # in the collect_trajectories_normal function.
@@ -657,6 +685,7 @@ class MathEnv(BaseEnv):
 scores["masks"].append(masks)
 scores["scores"].append(1.0 if reward else -1.0)
 scores["messages"].append(messages)
+scores["inference_logprobs"].append(logprobs)
 if len(scores["tokens"]) >= self.config.group_size:
 break
 if any([score == 1.0 for score in scores["scores"]]):
@@ -729,39 +758,48 @@ class MathEnv(BaseEnv):
 solution2=item[3][-1]["content"].split("</think>")[-1],
 )
 print("Sending to server")
-chat = [
+chat_fwd = [
 {"role": "system", "content": system_prompt},
 {"role": "user", "content": user_prompt_fwd},
 ]
-max_token_length = self.config.max_token_length - len(
-self.tokenizer.apply_chat_template(chat, add_generation_prompt=True)
-)
-chat_completions_fwd = self.server.chat_completion(
-messages=chat,
-n=3,
-max_tokens=max_token_length,
-temperature=1.0,
-top_p=0.95,
+max_token_length_fwd = self.config.max_token_length - len(
+self.tokenizer.apply_chat_template(chat_fwd, add_generation_prompt=True)
 )
+
 print("Sending to server")
 # Should be the same token length as the fwd but tokenizers are cursed
-chat = [
+chat_bwd = [
 {"role": "system", "content": system_prompt},
 {"role": "user", "content": user_prompt_bwd},
 ]
-max_token_length = self.config.max_token_length - len(
-self.tokenizer.apply_chat_template(chat, add_generation_prompt=True)
-)
-chat_completions_bwd = self.server.chat_completion(
-messages=chat,
-n=3,
-max_tokens=self.config.max_token_length,
-temperature=1.0,
-top_p=0.95,
+max_token_length_bwd = self.config.max_token_length - len(
+self.tokenizer.apply_chat_template(chat_bwd, add_generation_prompt=True)
 )
+
+# Use managed server for both forward and backward completions
+async def get_fwd_completion():
+async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+return await managed.chat_completion(
+messages=chat_fwd,
+n=3,
+max_tokens=max_token_length_fwd,
+temperature=1.0,
+top_p=0.95,
+)
+
+async def get_bwd_completion():
+async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+return await managed.chat_completion(
+messages=chat_bwd,
+n=3,
+max_tokens=max_token_length_bwd,
+temperature=1.0,
+top_p=0.95,
+)
+
 print("Gathering completions")
 chat_completions_fwd, chat_completions_bwd = await asyncio.gather(
-chat_completions_fwd, chat_completions_bwd
+get_fwd_completion(), get_bwd_completion()
 )
 print("Grabbed RLAIF completions")
 # Check for correct answers
@@ -810,25 +848,35 @@ class MathEnv(BaseEnv):
 to_postprocess["scores"] = list()
 to_postprocess["overrides"] = list()
 to_postprocess["messages"] = list()
+to_postprocess["inference_logprobs"] = list()
+# Extract pre-computed tokens/masks/logprobs from backlog
+tokens_1 = item[6]
+masks_1 = item[7]
+logprobs_1 = item[8]
+tokens_2 = item[9]
+masks_2 = item[10]
+logprobs_2 = item[11]
+# Add assertions to verify data integrity
+assert len(logprobs_1) == len(
+masks_1
+), f"{len(logprobs_1)}, {len(masks_1)} mismatch"
+assert len(logprobs_2) == len(
+masks_2
+), f"{len(logprobs_2)}, {len(masks_2)} mismatch"
 # add the first message in
-out_dict = tokenize_for_trainer(
-tokenizer=self.tokenizer, chat=item[3], include_messages=True
-)
-tokens = out_dict["tokens"]
-masks = out_dict["masks"]
-to_postprocess["tokens"].append(tokens)
-to_postprocess["masks"].append(masks)
+to_postprocess["tokens"].append(tokens_1)
+to_postprocess["masks"].append(masks_1)
 to_postprocess["scores"].append(1.0 if score_1 > score_2 else -1.0)
-to_postprocess["messages"].append(out_dict["messages"])
-out_dict = tokenize_for_trainer(
-tokenizer=self.tokenizer, chat=item[4], include_messages=True
-)
-tokens = out_dict["tokens"]
-masks = out_dict["masks"]
-to_postprocess["tokens"].append(tokens)
-to_postprocess["masks"].append(masks)
+to_postprocess["messages"].append(item[3]) # Already converted to dicts
+to_postprocess["inference_logprobs"].append(logprobs_1)
+to_postprocess["overrides"].append(dict())
+# add the second message in
+to_postprocess["tokens"].append(tokens_2)
+to_postprocess["masks"].append(masks_2)
 to_postprocess["scores"].append(1.0 if score_2 > score_1 else -1.0)
-to_postprocess["messages"].append(out_dict["messages"])
+to_postprocess["messages"].append(item[4]) # Already converted to dicts
+to_postprocess["inference_logprobs"].append(logprobs_2)
+to_postprocess["overrides"].append(dict())
 to_postprocess["group_overrides"] = {
 "group_size": 2,
 }
@@ -848,13 +896,19 @@ class MathEnv(BaseEnv):
 max_token_length = self.config.max_token_length - len(
 self.tokenizer.apply_chat_template(chat, add_generation_prompt=True)
 )
-chat_completions = await self.server.chat_completion(
-messages=chat,
-n=self.config.group_size,
-max_tokens=max_token_length,
-temperature=1.0,
-top_p=0.95,
-)
+# Use managed server for judge completions
+async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+chat_completions = await managed.chat_completion(
+messages=chat,
+n=self.config.group_size,
+max_tokens=max_token_length,
+temperature=1.0,
+top_p=0.95,
+)
+# Get tracked sequences with aligned tokens and logprobs
+state = managed.get_state()
+nodes = state["nodes"]
+
 is_correct = [
 (
 chat_completion.message.content.split("</think>")[-1]
@@ -878,25 +932,29 @@ class MathEnv(BaseEnv):
 scores["scores"] = []
 scores["overrides"] = []
 scores["messages"] = []
+scores["inference_logprobs"] = []
 for_table = []
-for i, chat_completion in enumerate(chat_completions.choices):
-out_dict = tokenize_for_trainer(
-tokenizer=self.tokenizer,
-chat=[
-{"role": "system", "content": system_prompt},
-{"role": "user", "content": user_prompt},
-{"role": "assistant", "content": chat_completion.message.content},
-],
-include_messages=True,
-)
-tokens = out_dict["tokens"]
-masks = out_dict["masks"]
-messages = out_dict["messages"]
+for i, (chat_completion, node) in enumerate(
+zip(chat_completions.choices, nodes)
+):
+# Extract pre-computed data from managed_server
+tokens = node.tokens
+masks = node.masked_tokens
+logprobs = node.logprobs
+messages = [
+{"role": "system", "content": system_prompt},
+{"role": "user", "content": user_prompt},
+{"role": "assistant", "content": chat_completion.message.content},
+]
+assert len(logprobs) == len(
+masks
+), f"{len(logprobs)}, {len(masks)} mismatch"
 if not is_correct[i]:
 scores["tokens"].append(tokens)
 scores["masks"].append(masks)
 scores["scores"].append(-1.0)
 scores["messages"].append(messages)
+scores["inference_logprobs"].append(logprobs)
 scores["overrides"].append(dict())
 if (chat_completion.finish_reason == "length") and (
 self.config.mask_too_long_completions
@@ -932,20 +990,31 @@ class MathEnv(BaseEnv):
 retry_messages, add_generation_prompt=True
 )
 )
-retry_chat_completions = await self.server.chat_completion(
-messages=retry_messages,
-n=self.config.group_size,
-max_tokens=max_token_length,
-temperature=1.0,
-top_p=0.95,
-)
+# Use managed server for retry completions
+async with self.server.managed_server(
+tokenizer=self.tokenizer
+) as managed:
+retry_chat_completions = await managed.chat_completion(
+messages=retry_messages,
+n=self.config.group_size,
+max_tokens=max_token_length,
+temperature=1.0,
+top_p=0.95,
+)
+# Get tracked sequences with aligned tokens and logprobs
+retry_state = managed.get_state()
+retry_nodes = retry_state["nodes"]
+
 print("Gathering completions")
 scoring_data = []
 backlog_scores = []
 backlog_reasons = []
 backlog_messages = []
-for j, retry_chat_completion in enumerate(
-retry_chat_completions.choices
+backlog_tokens = []
+backlog_masks = []
+backlog_logprobs = []
+for j, (retry_chat_completion, retry_node) in enumerate(
+zip(retry_chat_completions.choices, retry_nodes)
 ):
 print(f"Scoring generation {j} for retry...")
 backlog_messages.append(
@@ -962,6 +1031,10 @@ class MathEnv(BaseEnv):
 )
 )
 backlog_reasons.append(retry_chat_completion.finish_reason)
+# Store tokens, masks, and logprobs from managed_server
+backlog_tokens.append(retry_node.tokens)
+backlog_masks.append(retry_node.masked_tokens)
+backlog_logprobs.append(retry_node.logprobs)
 if retry_chat_completion.finish_reason == "length":
 scoring_data.append(0)
 backlog_scores.append(0)
@@ -998,10 +1071,18 @@ class MathEnv(BaseEnv):
 tuple(backlog_messages),
 tuple(backlog_scores),
 tuple(backlog_reasons),
+tuple(backlog_tokens),
+tuple(backlog_masks),
+tuple(backlog_logprobs),
 )
 )
 print(f"Sending to selfcorrect, {len(to_backlog)} in backlog")
 scores["scores"].append(sum(scoring_data) / len(scoring_data))
+scores["tokens"].append(tokens)
+scores["masks"].append(masks)
+scores["messages"].append(messages)
+scores["inference_logprobs"].append(logprobs)
+scores["overrides"].append(dict())
 self.judge_success_rate.append(
 sum(scoring_data) / len(scoring_data)
 )
@@ -1012,6 +1093,7 @@ class MathEnv(BaseEnv):
 scores["tokens"].append(tokens)
 scores["masks"].append(masks)
 scores["messages"].append(messages)
+scores["inference_logprobs"].append(logprobs)
 scores["overrides"].append(dict())
 if all([score == 1.0 for score in scores["scores"]]) and (
 random.random() < self.config.percent_length_penalty
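Throughout the new code, every ScoredDataGroup keeps tokens, masks, scores, overrides, messages, and now inference_logprobs as parallel lists, and the added assertions check that each logprob sequence matches the length of its mask sequence. A small hypothetical helper that re-checks those invariants on a finished group (the dict keys match the ScoredDataGroup usage in the diff; the function itself is not part of the commit):

# Sketch only: validates the parallel-list invariants shown in the diff above.
def check_scored_group(group: dict) -> None:
    keys = ["tokens", "masks", "scores", "overrides", "messages", "inference_logprobs"]
    lengths = {k: len(group[k]) for k in keys if k in group}
    assert len(set(lengths.values())) == 1, f"parallel lists differ in length: {lengths}"
    for masks, logprobs in zip(group.get("masks", []), group.get("inference_logprobs", [])):
        assert len(logprobs) == len(masks), f"{len(logprobs)}, {len(masks)} mismatch"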