Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-19 12:57:58 +00:00
support managedserver in mcqa thinking
This commit is contained in: parent f0fee7fba6, commit e28297b625
1 changed file with 31 additions and 24 deletions
@@ -14,7 +14,6 @@ from atroposlib.envs.base import (
     Item,
     ScoredDataGroup,
 )
-from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
 
 system_prompt = (
     "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
@@ -168,13 +167,17 @@ class MCQAThinkingEnv(BaseEnv):
             messages, add_generation_prompt=True, tokenize=False
         )
 
-        # Get completions from the model using completion() instead of chat_completion()
-        completions = await self.server.completion(
-            prompt=prompt,
-            n=self.config.group_size,
-            max_tokens=1024 * 15,
-            temperature=1.0,  # Using temperature to get diverse responses
-        )
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            # Get completions from the model using completion() instead of chat_completion()
+            completions = await managed.completion(
+                prompt=prompt,
+                n=self.config.group_size,
+                max_tokens=1024 * 15,
+                temperature=1.0,  # Using temperature to get diverse responses
+            )
+
+            state = managed.get_state()
+            nodes = state["nodes"]
 
         to_score = list()
 
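
Note: the diff does not show atroposlib's managed_server implementation. Below is a rough sketch of the pattern the new code relies on; the Node and ManagedHandle names are invented for illustration, and only the async-context-manager shape and the get_state() -> {"nodes": [...]} contract are taken from the diff itself.

from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class Node:
    # Per-choice generation state captured while serving a completion.
    tokens: List[int] = field(default_factory=list)
    masked_tokens: List[int] = field(default_factory=list)
    logprobs: List[float] = field(default_factory=list)

class ManagedHandle:
    def __init__(self, server: Any, tokenizer: Any) -> None:
        self._server = server
        self._tokenizer = tokenizer
        self._nodes: List[Node] = []

    async def completion(self, **kwargs: Any) -> Any:
        # Delegate to the wrapped server, then record one Node per returned
        # choice so tokens/masks/logprobs outlive the request itself.
        result = await self._server.completion(**kwargs)
        for _choice in result.choices:
            self._nodes.append(Node())  # ... fill from _choice + tokenizer ...
        return result

    def get_state(self) -> Dict[str, List[Node]]:
        # Matches the state["nodes"] access in the hunk above.
        return {"nodes": self._nodes}

@asynccontextmanager
async def managed_server(server: Any, tokenizer: Any):
    # Minimal shape of `async with self.server.managed_server(...) as managed:`.
    yield ManagedHandle(server, tokenizer)
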
@@ -191,12 +194,15 @@ class MCQAThinkingEnv(BaseEnv):
 
             # Add to scoring queue with expected answer, ground truth text, and stop reason
             to_score.append(
-                (
-                    tuple(trajectory_messages),
-                    item[1],  # Letter (A, B, C, D)
-                    item[2],  # Include the answer_string/ground_truth_text
-                    completion_choice.finish_reason,  # Add the stop reason
-                )
+                {
+                    "messages": tuple(trajectory_messages),
+                    "expected_answer": item[1],  # Letter (A, B, C, D)
+                    "ground_truth_text": item[2],  # Include the answer_string/ground_truth_text
+                    "finish_reason": completion_choice.finish_reason,  # Add the stop reason
+                    "tokens": nodes[i].tokens,
+                    "masks": nodes[i].masked_tokens,
+                    "logprobs": nodes[i].logprobs,
+                }
             )
 
         # Call score to get the scored data
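
This hunk switches the to_score entries from positional tuples to keyed dicts so the managed server's per-rollout state travels with each item. Spelled out as a TypedDict (illustrative only; the commit itself uses plain dicts), the new item shape is:

from typing import List, Tuple, TypedDict

class RolloutItem(TypedDict):
    messages: Tuple[dict, ...]   # full trajectory; last entry is the model reply
    expected_answer: str         # letter: A, B, C, or D
    ground_truth_text: str       # the answer_string / ground-truth text
    finish_reason: str           # e.g. "stop" or "length"
    tokens: List[int]            # token ids recorded by the managed server
    masks: List[int]             # token mask; -100 marks untrained positions
    logprobs: List[float]        # inference-time logprobs for those tokens
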
@@ -313,18 +319,19 @@ class MCQAThinkingEnv(BaseEnv):
         scores["tokens"] = list()
         scores["masks"] = list()
         scores["scores"] = list()
+        scores["inference_logprobs"] = list()
 
         # Get the expected answer letter
-        expected_answer = rollout_group_data[0][1]  # Letter A, B, C, D
-        ground_truth_text = rollout_group_data[0][2]
+        expected_answer = rollout_group_data[0]["expected_answer"]  # Letter A, B, C, D
+        ground_truth_text = rollout_group_data[0]["ground_truth_text"]
 
         # Shuffle to avoid bias in selection
         random.shuffle(rollout_group_data)
 
         for item in rollout_group_data:
             # Extract the model's response
-            model_response = item[0][-1]["content"]
-            stop_reason = item[3]  # Get the stop reason
+            model_response = item["messages"][-1]["content"]
+            stop_reason = item["finish_reason"]  # Get the stop reason
 
             # If the response was cut off due to length, give it a score of 0
             if stop_reason == "length":
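
One detail worth making explicit: every rollout in a group answers the same question, so the group-level metadata is identical across items and index 0 is as safe as any other. Reading it before the shuffle is fine, since the shuffle only reorders the rollouts without touching each item's fields:

# Group-level metadata is the same for every rollout in the group.
expected_answer = rollout_group_data[0]["expected_answer"]
ground_truth_text = rollout_group_data[0]["ground_truth_text"]
random.shuffle(rollout_group_data)  # reorders items; their contents are untouched
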
@@ -343,10 +350,9 @@ class MCQAThinkingEnv(BaseEnv):
             else:
                 reward = 0  # Wrong answer gets 0 reward
 
-            # Tokenize the conversation for learning
-            out_dict = tokenize_for_trainer(self.tokenizer, item[0])
-            tokens = out_dict["tokens"]
-            masks = out_dict["masks"]
+            tokens = item["tokens"]
+            masks = item["masks"]
+            logprobs = item["logprobs"]
 
             # Remove examples with insufficient context
             if len([1 for i in masks if i != -100]) < 10:
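
The context filter above keeps an example only when at least 10 mask positions differ from -100, the conventional ignore index for loss masking. An equivalent form that avoids building a throwaway list (a suggestion, not part of the commit, and assuming the original body skips the example):

# Count trainable positions directly instead of materializing a list.
if sum(1 for m in masks if m != -100) < 10:
    continue  # too little trainable context; drop this example
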
@@ -354,6 +360,7 @@ class MCQAThinkingEnv(BaseEnv):
 
             scores["tokens"].append(tokens)
             scores["masks"].append(masks)
+            scores["inference_logprobs"].append(logprobs)
             scores["scores"].append(1.0 if reward else -1.0)
 
             # Break once we have enough examples
@@ -448,8 +455,8 @@ class MCQAThinkingEnv(BaseEnv):
                 (
                     self.tokenizer.decode(scored_data["tokens"][i]),
                     scored_data["scores"][i],
-                    item[1],
-                    item[2],
+                    item[1] if isinstance(item, tuple) else item["expected_answer"],
+                    item[2] if isinstance(item, tuple) else item["ground_truth_text"],
                 )
                 for i in range(num_keep)
             ]
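
The isinstance guard keeps the rollout logging table working for both the old tuple items and the new dict items. If more fields were ever needed, a small normalizer (hypothetical, not in the commit) would centralize the check instead of repeating it per field:

# Hypothetical helper: normalize both item layouts to the fields being logged.
def item_answer_fields(item):
    if isinstance(item, tuple):
        # Legacy layout: (messages, expected_answer, ground_truth_text, finish_reason)
        return item[1], item[2]
    return item["expected_answer"], item["ground_truth_text"]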