convert answer format env to use managedserver

This commit is contained in:
teknium 2025-11-14 10:21:24 +00:00
parent ae101ea8e4
commit 9034d4c78e

View file

@ -84,7 +84,6 @@ from atroposlib.envs.base import (
Item,
ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
class AnswerFormat(Enum):
@ -3507,7 +3506,7 @@ class AnswerFormatEnv(BaseEnv):
async def score(
self,
rollout_group_data: List[Tuple[Tuple[Dict[str, str], ...], Dict[str, Any]]],
rollout_group_data: List[Dict[str, Any]],
) -> Optional[ScoredDataGroup]:
"""Score rollouts based on format adherence."""
if self.debug_logging:
@ -3517,11 +3516,12 @@ class AnswerFormatEnv(BaseEnv):
scores_obj["tokens"] = list()
scores_obj["masks"] = list()
scores_obj["scores"] = list()
scores_obj["inference_logprobs"] = list()
if not rollout_group_data:
return None
dataset_item = rollout_group_data[0][1]
dataset_item = rollout_group_data[0]["dataset_item"]
selected_format = dataset_item["selected_format"]
format_name = selected_format.value
@ -3531,7 +3531,12 @@ class AnswerFormatEnv(BaseEnv):
random.shuffle(rollout_group_data)
failed_rollouts_this_group = []
for item_messages, _ in rollout_group_data:
for rollout_item in rollout_group_data:
item_messages = rollout_item["messages"]
tokens = rollout_item["tokens"]
masks = rollout_item["masks"]
logprobs = rollout_item["logprobs"]
messages_as_dicts = [dict(fs_message) for fs_message in item_messages]
model_response_text = messages_as_dicts[-1]["content"]
@ -3553,43 +3558,13 @@ class AnswerFormatEnv(BaseEnv):
self.format_success_counts.get(format_name, 0) + 1
)
try:
# Validate message format for tokenization
for msg_idx, msg in enumerate(messages_as_dicts):
if not isinstance(msg, dict):
if self.debug_logging:
self.logger.error(
f"Message {msg_idx} is not a dict: {type(msg)}"
)
continue
if "role" not in msg or "content" not in msg:
if self.debug_logging:
self.logger.error(
f"Message {msg_idx} missing required keys: {msg.keys()}"
)
continue
if not isinstance(msg["content"], str):
msg["content"] = str(msg["content"])
out_dict = tokenize_for_trainer(
self.tokenizer,
messages_as_dicts,
include_messages=self.config.include_messages,
)
tokens = out_dict["tokens"]
masks = out_dict["masks"]
except Exception as e:
if self.debug_logging:
self.logger.error(f"Tokenization failed: {e}")
continue
# Remove examples with insufficient context
if len([1 for m_val in masks if m_val != -100]) < 10:
continue
scores_obj["tokens"].append(tokens)
scores_obj["masks"].append(masks)
scores_obj["inference_logprobs"].append(logprobs)
scores_obj["scores"].append(reward)
# Track failed rollouts for debugging
@ -3742,21 +3717,33 @@ class AnswerFormatEnv(BaseEnv):
messages_for_api, add_generation_prompt=True, tokenize=False
)
completions = await self.server.completion(
prompt=prompt_str,
n=self.config.group_size,
max_tokens=self.config.max_token_length,
temperature=0.9,
)
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
completions = await managed.completion(
prompt=prompt_str,
n=self.config.group_size,
max_tokens=self.config.max_token_length,
temperature=0.9,
)
state = managed.get_state()
nodes = state["nodes"]
to_score_list = []
for choice in completions.choices:
for i, choice in enumerate(completions.choices):
# Create a full message list for this choice
current_trajectory_messages = list(prompt_messages_tuple)
current_trajectory_messages.append(
frozenset({"role": "assistant", "content": choice.text}.items())
)
to_score_list.append((tuple(current_trajectory_messages), dataset_item))
to_score_list.append(
{
"messages": tuple(current_trajectory_messages),
"dataset_item": dataset_item,
"tokens": nodes[i].tokens,
"masks": nodes[i].masked_tokens,
"logprobs": nodes[i].logprobs,
}
)
scored_data = await self.score(to_score_list)
@ -3776,7 +3763,7 @@ class AnswerFormatEnv(BaseEnv):
for i in range(num_scored_rollouts):
# Get conversation messages directly from to_score_list like reasoning gym does
conversation_messages = [
dict(fs_msg) for fs_msg in to_score_list[i][0]
dict(fs_msg) for fs_msg in to_score_list[i]["messages"]
]
score_for_rollout = group_scores[i]
@ -3953,13 +3940,14 @@ class AnswerFormatEnv(BaseEnv):
messages, add_generation_prompt=True, tokenize=False
)
completion = await self.server.completion(
prompt=prompt,
n=1,
max_tokens=self.config.max_token_length,
temperature=0.1,
split="eval",
)
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
completion = await managed.completion(
prompt=prompt,
n=1,
max_tokens=self.config.max_token_length,
temperature=0.1,
split="eval",
)
model_response = completion.choices[0].text
extracted_content = self._extract_answer_content(