support managedserver in mcqa thinking

This commit is contained in:
teknium 2025-11-14 00:10:04 +00:00
parent f0fee7fba6
commit e28297b625

View file

@ -14,7 +14,6 @@ from atroposlib.envs.base import (
Item, Item,
ScoredDataGroup, ScoredDataGroup,
) )
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
system_prompt = ( system_prompt = (
"You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the " "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
@ -168,13 +167,17 @@ class MCQAThinkingEnv(BaseEnv):
messages, add_generation_prompt=True, tokenize=False messages, add_generation_prompt=True, tokenize=False
) )
# Get completions from the model using completion() instead of chat_completion() async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
completions = await self.server.completion( # Get completions from the model using completion() instead of chat_completion()
prompt=prompt, completions = await managed.completion(
n=self.config.group_size, prompt=prompt,
max_tokens=1024 * 15, n=self.config.group_size,
temperature=1.0, # Using temperature to get diverse responses max_tokens=1024 * 15,
) temperature=1.0, # Using temperature to get diverse responses
)
state = managed.get_state()
nodes = state["nodes"]
to_score = list() to_score = list()
@ -191,12 +194,15 @@ class MCQAThinkingEnv(BaseEnv):
# Add to scoring queue with expected answer, ground truth text, and stop reason # Add to scoring queue with expected answer, ground truth text, and stop reason
to_score.append( to_score.append(
( {
tuple(trajectory_messages), "messages": tuple(trajectory_messages),
item[1], # Letter (A, B, C, D) "expected_answer": item[1], # Letter (A, B, C, D)
item[2], # Include the answer_string/ground_truth_text "ground_truth_text": item[2], # Include the answer_string/ground_truth_text
completion_choice.finish_reason, # Add the stop reason "finish_reason": completion_choice.finish_reason, # Add the stop reason
) "tokens": nodes[i].tokens,
"masks": nodes[i].masked_tokens,
"logprobs": nodes[i].logprobs,
}
) )
# Call score to get the scored data # Call score to get the scored data
@ -313,18 +319,19 @@ class MCQAThinkingEnv(BaseEnv):
scores["tokens"] = list() scores["tokens"] = list()
scores["masks"] = list() scores["masks"] = list()
scores["scores"] = list() scores["scores"] = list()
scores["inference_logprobs"] = list()
# Get the expected answer letter # Get the expected answer letter
expected_answer = rollout_group_data[0][1] # Letter A, B, C, D expected_answer = rollout_group_data[0]["expected_answer"] # Letter A, B, C, D
ground_truth_text = rollout_group_data[0][2] ground_truth_text = rollout_group_data[0]["ground_truth_text"]
# Shuffle to avoid bias in selection # Shuffle to avoid bias in selection
random.shuffle(rollout_group_data) random.shuffle(rollout_group_data)
for item in rollout_group_data: for item in rollout_group_data:
# Extract the model's response # Extract the model's response
model_response = item[0][-1]["content"] model_response = item["messages"][-1]["content"]
stop_reason = item[3] # Get the stop reason stop_reason = item["finish_reason"] # Get the stop reason
# If the response was cut off due to length, give it a score of 0 # If the response was cut off due to length, give it a score of 0
if stop_reason == "length": if stop_reason == "length":
@ -343,10 +350,9 @@ class MCQAThinkingEnv(BaseEnv):
else: else:
reward = 0 # Wrong answer gets 0 reward reward = 0 # Wrong answer gets 0 reward
# Tokenize the conversation for learning tokens = item["tokens"]
out_dict = tokenize_for_trainer(self.tokenizer, item[0]) masks = item["masks"]
tokens = out_dict["tokens"] logprobs = item["logprobs"]
masks = out_dict["masks"]
# Remove examples with insufficient context # Remove examples with insufficient context
if len([1 for i in masks if i != -100]) < 10: if len([1 for i in masks if i != -100]) < 10:
@ -354,6 +360,7 @@ class MCQAThinkingEnv(BaseEnv):
scores["tokens"].append(tokens) scores["tokens"].append(tokens)
scores["masks"].append(masks) scores["masks"].append(masks)
scores["inference_logprobs"].append(logprobs)
scores["scores"].append(1.0 if reward else -1.0) scores["scores"].append(1.0 if reward else -1.0)
# Break once we have enough examples # Break once we have enough examples
@ -448,8 +455,8 @@ class MCQAThinkingEnv(BaseEnv):
( (
self.tokenizer.decode(scored_data["tokens"][i]), self.tokenizer.decode(scored_data["tokens"][i]),
scored_data["scores"][i], scored_data["scores"][i],
item[1], item[1] if isinstance(item, tuple) else item["expected_answer"],
item[2], item[2] if isinstance(item, tuple) else item["ground_truth_text"],
) )
for i in range(num_keep) for i in range(num_keep)
] ]