convert answer format env to use managedserver

Repository: https://github.com/NousResearch/atropos.git
parent ae101ea8e4
commit 9034d4c78e
1 changed file with 39 additions and 51 deletions

@@ -84,7 +84,6 @@ from atroposlib.envs.base import (
     Item,
     ScoredDataGroup,
 )
-from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer


 class AnswerFormat(Enum):
@@ -3507,7 +3506,7 @@ class AnswerFormatEnv(BaseEnv):

     async def score(
         self,
-        rollout_group_data: List[Tuple[Tuple[Dict[str, str], ...], Dict[str, Any]]],
+        rollout_group_data: List[Dict[str, Any]],
     ) -> Optional[ScoredDataGroup]:
         """Score rollouts based on format adherence."""
         if self.debug_logging:
@@ -3517,11 +3516,12 @@ class AnswerFormatEnv(BaseEnv):
         scores_obj["tokens"] = list()
         scores_obj["masks"] = list()
         scores_obj["scores"] = list()
+        scores_obj["inference_logprobs"] = list()

         if not rollout_group_data:
             return None

-        dataset_item = rollout_group_data[0][1]
+        dataset_item = rollout_group_data[0]["dataset_item"]
         selected_format = dataset_item["selected_format"]
         format_name = selected_format.value

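
For reference, a minimal sketch of the per-rollout dict that score now consumes. Only the key names are taken from this commit; the concrete values (and the placeholder where the real env carries an AnswerFormat member) are invented for illustration:

# Hypothetical entry in the new rollout_group_data (values are made up;
# key names come from the hunks in this commit).
example_rollout_item = {
    "messages": (
        frozenset({"role": "user", "content": "Answer in JSON."}.items()),
        frozenset({"role": "assistant", "content": '{"answer": 42}'}.items()),
    ),
    "dataset_item": {"selected_format": None},  # an AnswerFormat member in the real env
    "tokens": [1, 887, 1234, 5678, 2],          # made-up token ids
    "masks": [-100, -100, 1234, 5678, 2],       # prompt positions masked with -100
    "logprobs": [0.0, 0.0, -0.41, -0.07, -0.02],
}

# score() now reads tokens/masks/logprobs directly instead of re-tokenizing:
messages_as_dicts = [dict(fs) for fs in example_rollout_item["messages"]]
assert messages_as_dicts[-1]["content"] == '{"answer": 42}'
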
@@ -3531,7 +3531,12 @@ class AnswerFormatEnv(BaseEnv):
         random.shuffle(rollout_group_data)
         failed_rollouts_this_group = []

-        for item_messages, _ in rollout_group_data:
+        for rollout_item in rollout_group_data:
+            item_messages = rollout_item["messages"]
+            tokens = rollout_item["tokens"]
+            masks = rollout_item["masks"]
+            logprobs = rollout_item["logprobs"]
+
             messages_as_dicts = [dict(fs_message) for fs_message in item_messages]
             model_response_text = messages_as_dicts[-1]["content"]
@@ -3553,43 +3558,13 @@ class AnswerFormatEnv(BaseEnv):
                     self.format_success_counts.get(format_name, 0) + 1
                 )

-            try:
-                # Validate message format for tokenization
-                for msg_idx, msg in enumerate(messages_as_dicts):
-                    if not isinstance(msg, dict):
-                        if self.debug_logging:
-                            self.logger.error(
-                                f"Message {msg_idx} is not a dict: {type(msg)}"
-                            )
-                        continue
-                    if "role" not in msg or "content" not in msg:
-                        if self.debug_logging:
-                            self.logger.error(
-                                f"Message {msg_idx} missing required keys: {msg.keys()}"
-                            )
-                        continue
-                    if not isinstance(msg["content"], str):
-                        msg["content"] = str(msg["content"])
-
-                out_dict = tokenize_for_trainer(
-                    self.tokenizer,
-                    messages_as_dicts,
-                    include_messages=self.config.include_messages,
-                )
-                tokens = out_dict["tokens"]
-                masks = out_dict["masks"]
-
-            except Exception as e:
-                if self.debug_logging:
-                    self.logger.error(f"Tokenization failed: {e}")
-                continue
-
             # Remove examples with insufficient context
             if len([1 for m_val in masks if m_val != -100]) < 10:
                 continue

             scores_obj["tokens"].append(tokens)
             scores_obj["masks"].append(masks)
+            scores_obj["inference_logprobs"].append(logprobs)
             scores_obj["scores"].append(reward)

             # Track failed rollouts for debugging
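
The "!= -100" check kept above follows the usual ignore-index convention: masked (prompt) positions are set to -100 and excluded from the loss, so a rollout needs at least 10 trainable mask entries to survive the filter. A tiny illustration with invented values:

# Counting trainable positions under the -100 ignore-index convention.
masks = [-100, -100, -100, 512, 768, 314]
trainable = len([1 for m_val in masks if m_val != -100])
print(trainable)  # 3 -> fewer than 10, so this rollout would be dropped
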
@@ -3742,21 +3717,33 @@ class AnswerFormatEnv(BaseEnv):
             messages_for_api, add_generation_prompt=True, tokenize=False
         )

-        completions = await self.server.completion(
-            prompt=prompt_str,
-            n=self.config.group_size,
-            max_tokens=self.config.max_token_length,
-            temperature=0.9,
-        )
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            completions = await managed.completion(
+                prompt=prompt_str,
+                n=self.config.group_size,
+                max_tokens=self.config.max_token_length,
+                temperature=0.9,
+            )
+
+            state = managed.get_state()
+            nodes = state["nodes"]

         to_score_list = []
-        for choice in completions.choices:
+        for i, choice in enumerate(completions.choices):
             # Create a full message list for this choice
             current_trajectory_messages = list(prompt_messages_tuple)
             current_trajectory_messages.append(
                 frozenset({"role": "assistant", "content": choice.text}.items())
             )
-            to_score_list.append((tuple(current_trajectory_messages), dataset_item))
+            to_score_list.append(
+                {
+                    "messages": tuple(current_trajectory_messages),
+                    "dataset_item": dataset_item,
+                    "tokens": nodes[i].tokens,
+                    "masks": nodes[i].masked_tokens,
+                    "logprobs": nodes[i].logprobs,
+                }
+            )

         scored_data = await self.score(to_score_list)
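
Taken together, the collect-side flow after this hunk looks roughly like the condensed sketch below. It assumes, as the indexing above implies, that managed.get_state()["nodes"][i] carries the token state for completions.choices[i]; the free-function signature is invented for illustration, while the ManagedServer calls mirror the ones in this commit.

# Condensed sketch of the new rollout path (illustrative signature).
async def collect_group(server, tokenizer, config, prompt_str, prompt_messages_tuple, dataset_item):
    async with server.managed_server(tokenizer=tokenizer) as managed:
        completions = await managed.completion(
            prompt=prompt_str,
            n=config.group_size,
            max_tokens=config.max_token_length,
            temperature=0.9,
        )
        # Capture per-choice token state while the managed context is open.
        nodes = managed.get_state()["nodes"]

    to_score_list = []
    for i, choice in enumerate(completions.choices):
        trajectory = list(prompt_messages_tuple)
        trajectory.append(
            frozenset({"role": "assistant", "content": choice.text}.items())
        )
        to_score_list.append(
            {
                "messages": tuple(trajectory),
                "dataset_item": dataset_item,
                "tokens": nodes[i].tokens,        # ids as returned by the managed server
                "masks": nodes[i].masked_tokens,  # aligned with tokens
                "logprobs": nodes[i].logprobs,    # inference-time logprobs
            }
        )
    return to_score_list
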
@@ -3776,7 +3763,7 @@ class AnswerFormatEnv(BaseEnv):
         for i in range(num_scored_rollouts):
             # Get conversation messages directly from to_score_list like reasoning gym does
             conversation_messages = [
-                dict(fs_msg) for fs_msg in to_score_list[i][0]
+                dict(fs_msg) for fs_msg in to_score_list[i]["messages"]
             ]
             score_for_rollout = group_scores[i]
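
The dict(fs_msg) conversion in this hunk works because each message is stored as a frozenset of (key, value) pairs, a hashable stand-in for a dict, and dict() restores the original mapping losslessly:

# Round-tripping a message through the frozenset representation.
msg = frozenset({"role": "assistant", "content": "hi"}.items())
assert dict(msg) == {"role": "assistant", "content": "hi"}
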
@@ -3953,13 +3940,14 @@ class AnswerFormatEnv(BaseEnv):
             messages, add_generation_prompt=True, tokenize=False
         )

-        completion = await self.server.completion(
-            prompt=prompt,
-            n=1,
-            max_tokens=self.config.max_token_length,
-            temperature=0.1,
-            split="eval",
-        )
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            completion = await managed.completion(
+                prompt=prompt,
+                n=1,
+                max_tokens=self.config.max_token_length,
+                temperature=0.1,
+                split="eval",
+            )

         model_response = completion.choices[0].text
         extracted_content = self._extract_answer_content(