mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
convert pydantic schema env to use managed server
This commit is contained in:
parent
8cc83db6ee
commit
c4ecc42139
1 changed files with 41 additions and 66 deletions
|
|
@ -69,7 +69,6 @@ from atroposlib.envs.base import (
|
|||
Item,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
# Import editing functionality
|
||||
try:
|
||||
|
|
@ -718,7 +717,7 @@ class PydanticSchemaFollowingEnv(BaseEnv):
|
|||
|
||||
async def score(
|
||||
self,
|
||||
rollout_group_data: List[Tuple[Tuple[Dict[str, str], ...], Dict[str, Any]]],
|
||||
rollout_group_data: List[Dict[str, Any]],
|
||||
) -> Optional[ScoredDataGroup]:
|
||||
"""Score the rollouts based on Pydantic validation or other structural checks."""
|
||||
if self.debug_logging:
|
||||
|
|
@ -729,13 +728,14 @@ class PydanticSchemaFollowingEnv(BaseEnv):
|
|||
scores_obj["masks"] = list()
|
||||
scores_obj["scores"] = list()
|
||||
scores_obj["messages"] = list()
|
||||
scores_obj["inference_logprobs"] = list()
|
||||
|
||||
if not rollout_group_data:
|
||||
if self.debug_logging:
|
||||
self.logger.warning("No rollout data to score")
|
||||
return None
|
||||
|
||||
dataset_item = rollout_group_data[0][1]
|
||||
dataset_item = rollout_group_data[0]["dataset_item"]
|
||||
problem_id = dataset_item.get("problem_id", "N/A")
|
||||
selected_structured_format = dataset_item["selected_structured_format"]
|
||||
selected_container_format = dataset_item["selected_container_format"]
|
||||
|
|
@ -779,7 +779,12 @@ class PydanticSchemaFollowingEnv(BaseEnv):
|
|||
|
||||
random.shuffle(rollout_group_data)
|
||||
|
||||
for i, (item_messages, _) in enumerate(rollout_group_data):
|
||||
for i, rollout_item in enumerate(rollout_group_data):
|
||||
|
||||
item_messages = rollout_item["messages"]
|
||||
tokens = rollout_item["tokens"]
|
||||
masks = rollout_item["masks"]
|
||||
logprobs = rollout_item["logprobs"]
|
||||
|
||||
messages_as_dicts = [dict(fs_message) for fs_message in item_messages]
|
||||
model_response_text = messages_as_dicts[-1]["content"]
|
||||
|
|
@ -906,63 +911,20 @@ class PydanticSchemaFollowingEnv(BaseEnv):
|
|||
f"Rollout {i}: Extraction failed for {selected_structured_format.value} with container {selected_container_format.value}" # noqa: E501
|
||||
)
|
||||
|
||||
try:
|
||||
if not isinstance(messages_as_dicts, list):
|
||||
if self.debug_logging:
|
||||
self.logger.error(
|
||||
f"Expected list for tokenization, got {type(messages_as_dicts)}"
|
||||
)
|
||||
continue
|
||||
for msg_idx, msg in enumerate(messages_as_dicts):
|
||||
if not isinstance(msg, dict):
|
||||
if self.debug_logging:
|
||||
self.logger.error(
|
||||
f"Message {msg_idx} is not a dict: {type(msg)}"
|
||||
)
|
||||
continue # Skip this rollout if message format is incorrect
|
||||
if "role" not in msg or "content" not in msg:
|
||||
if self.debug_logging:
|
||||
self.logger.error(
|
||||
f"Message {msg_idx} missing required keys: {msg.keys()}"
|
||||
)
|
||||
continue # Skip this rollout
|
||||
if not isinstance(msg["content"], str):
|
||||
if self.debug_logging:
|
||||
self.logger.warning(
|
||||
f"Converting content to string for message {msg_idx}"
|
||||
)
|
||||
msg["content"] = str(msg["content"])
|
||||
|
||||
out_dict = tokenize_for_trainer(
|
||||
self.tokenizer,
|
||||
messages_as_dicts,
|
||||
include_messages=self.config.include_messages,
|
||||
)
|
||||
tokens = out_dict["tokens"]
|
||||
masks = out_dict["masks"]
|
||||
|
||||
except Exception as e:
|
||||
if self.debug_logging:
|
||||
self.logger.error(
|
||||
f"Tokenization failed for rollout {i} (problem: {problem_id}): {e}"
|
||||
)
|
||||
self.logger.debug(
|
||||
f"Messages format: {[type(m) for m in messages_as_dicts]}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Remove examples with insufficient context
|
||||
if len([1 for m_val in masks if m_val != -100]) < 10: # Min context length
|
||||
if self.debug_logging:
|
||||
self.logger.debug(
|
||||
f"Skipping rollout {i} (problem: {problem_id}) due to insufficient context length after tokenization." # noqa: E501
|
||||
f"Skipping rollout {i} (problem: {problem_id}) due to insufficient context length." # noqa: E501
|
||||
)
|
||||
continue
|
||||
|
||||
scores_obj["tokens"].append(tokens)
|
||||
scores_obj["masks"].append(masks)
|
||||
scores_obj["inference_logprobs"].append(logprobs)
|
||||
scores_obj["scores"].append(reward)
|
||||
# Store original messages (converted to dicts) if available in out_dict, else the modified ones
|
||||
scores_obj["messages"].append(out_dict.get("messages", messages_as_dicts))
|
||||
# Store original messages (converted to dicts)
|
||||
scores_obj["messages"].append(messages_as_dicts)
|
||||
|
||||
self.percent_correct_buffer.append(1.0 if reward == 1.0 else 0.0)
|
||||
|
||||
|
|
@ -1036,12 +998,16 @@ class PydanticSchemaFollowingEnv(BaseEnv):
|
|||
f"Requesting {self.config.group_size} completions with max_tokens={self.config.max_token_length}, temperature=0.9" # noqa: E501
|
||||
)
|
||||
|
||||
completions = await self.server.completion(
|
||||
prompt=prompt_str,
|
||||
n=self.config.group_size,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.9,
|
||||
)
|
||||
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||
completions = await managed.completion(
|
||||
prompt=prompt_str,
|
||||
n=self.config.group_size,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.9,
|
||||
)
|
||||
|
||||
state = managed.get_state()
|
||||
nodes = state["nodes"]
|
||||
|
||||
if self.debug_logging:
|
||||
self.logger.debug(
|
||||
|
|
@ -1060,7 +1026,15 @@ class PydanticSchemaFollowingEnv(BaseEnv):
|
|||
current_trajectory_messages.append(
|
||||
frozenset({"role": "assistant", "content": choice.text}.items())
|
||||
)
|
||||
to_score_list.append((tuple(current_trajectory_messages), dataset_item))
|
||||
to_score_list.append(
|
||||
{
|
||||
"messages": tuple(current_trajectory_messages),
|
||||
"dataset_item": dataset_item,
|
||||
"tokens": nodes[i].tokens,
|
||||
"masks": nodes[i].masked_tokens,
|
||||
"logprobs": nodes[i].logprobs,
|
||||
}
|
||||
)
|
||||
|
||||
scored_data = await self.score(to_score_list)
|
||||
|
||||
|
|
@ -1251,13 +1225,14 @@ class PydanticSchemaFollowingEnv(BaseEnv):
|
|||
f"Eval prompt length for {problem_id}: {len(prompt)} characters"
|
||||
)
|
||||
|
||||
completion = await self.server.completion(
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.1,
|
||||
split="eval",
|
||||
)
|
||||
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||
completion = await managed.completion(
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.1,
|
||||
split="eval",
|
||||
)
|
||||
|
||||
model_response_text = completion.choices[0].text
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue