convert pydantic schema env to use managed server

This commit is contained in:
teknium 2025-11-14 10:09:43 +00:00
parent 8cc83db6ee
commit c4ecc42139

View file

@ -69,7 +69,6 @@ from atroposlib.envs.base import (
Item, Item,
ScoredDataGroup, ScoredDataGroup,
) )
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
# Import editing functionality # Import editing functionality
try: try:
@ -718,7 +717,7 @@ class PydanticSchemaFollowingEnv(BaseEnv):
async def score( async def score(
self, self,
rollout_group_data: List[Tuple[Tuple[Dict[str, str], ...], Dict[str, Any]]], rollout_group_data: List[Dict[str, Any]],
) -> Optional[ScoredDataGroup]: ) -> Optional[ScoredDataGroup]:
"""Score the rollouts based on Pydantic validation or other structural checks.""" """Score the rollouts based on Pydantic validation or other structural checks."""
if self.debug_logging: if self.debug_logging:
@ -729,13 +728,14 @@ class PydanticSchemaFollowingEnv(BaseEnv):
scores_obj["masks"] = list() scores_obj["masks"] = list()
scores_obj["scores"] = list() scores_obj["scores"] = list()
scores_obj["messages"] = list() scores_obj["messages"] = list()
scores_obj["inference_logprobs"] = list()
if not rollout_group_data: if not rollout_group_data:
if self.debug_logging: if self.debug_logging:
self.logger.warning("No rollout data to score") self.logger.warning("No rollout data to score")
return None return None
dataset_item = rollout_group_data[0][1] dataset_item = rollout_group_data[0]["dataset_item"]
problem_id = dataset_item.get("problem_id", "N/A") problem_id = dataset_item.get("problem_id", "N/A")
selected_structured_format = dataset_item["selected_structured_format"] selected_structured_format = dataset_item["selected_structured_format"]
selected_container_format = dataset_item["selected_container_format"] selected_container_format = dataset_item["selected_container_format"]
@ -779,7 +779,12 @@ class PydanticSchemaFollowingEnv(BaseEnv):
random.shuffle(rollout_group_data) random.shuffle(rollout_group_data)
for i, (item_messages, _) in enumerate(rollout_group_data): for i, rollout_item in enumerate(rollout_group_data):
item_messages = rollout_item["messages"]
tokens = rollout_item["tokens"]
masks = rollout_item["masks"]
logprobs = rollout_item["logprobs"]
messages_as_dicts = [dict(fs_message) for fs_message in item_messages] messages_as_dicts = [dict(fs_message) for fs_message in item_messages]
model_response_text = messages_as_dicts[-1]["content"] model_response_text = messages_as_dicts[-1]["content"]
@ -906,63 +911,20 @@ class PydanticSchemaFollowingEnv(BaseEnv):
f"Rollout {i}: Extraction failed for {selected_structured_format.value} with container {selected_container_format.value}" # noqa: E501 f"Rollout {i}: Extraction failed for {selected_structured_format.value} with container {selected_container_format.value}" # noqa: E501
) )
try: # Remove examples with insufficient context
if not isinstance(messages_as_dicts, list):
if self.debug_logging:
self.logger.error(
f"Expected list for tokenization, got {type(messages_as_dicts)}"
)
continue
for msg_idx, msg in enumerate(messages_as_dicts):
if not isinstance(msg, dict):
if self.debug_logging:
self.logger.error(
f"Message {msg_idx} is not a dict: {type(msg)}"
)
continue # Skip this rollout if message format is incorrect
if "role" not in msg or "content" not in msg:
if self.debug_logging:
self.logger.error(
f"Message {msg_idx} missing required keys: {msg.keys()}"
)
continue # Skip this rollout
if not isinstance(msg["content"], str):
if self.debug_logging:
self.logger.warning(
f"Converting content to string for message {msg_idx}"
)
msg["content"] = str(msg["content"])
out_dict = tokenize_for_trainer(
self.tokenizer,
messages_as_dicts,
include_messages=self.config.include_messages,
)
tokens = out_dict["tokens"]
masks = out_dict["masks"]
except Exception as e:
if self.debug_logging:
self.logger.error(
f"Tokenization failed for rollout {i} (problem: {problem_id}): {e}"
)
self.logger.debug(
f"Messages format: {[type(m) for m in messages_as_dicts]}"
)
continue
if len([1 for m_val in masks if m_val != -100]) < 10: # Min context length if len([1 for m_val in masks if m_val != -100]) < 10: # Min context length
if self.debug_logging: if self.debug_logging:
self.logger.debug( self.logger.debug(
f"Skipping rollout {i} (problem: {problem_id}) due to insufficient context length after tokenization." # noqa: E501 f"Skipping rollout {i} (problem: {problem_id}) due to insufficient context length." # noqa: E501
) )
continue continue
scores_obj["tokens"].append(tokens) scores_obj["tokens"].append(tokens)
scores_obj["masks"].append(masks) scores_obj["masks"].append(masks)
scores_obj["inference_logprobs"].append(logprobs)
scores_obj["scores"].append(reward) scores_obj["scores"].append(reward)
# Store original messages (converted to dicts) if available in out_dict, else the modified ones # Store original messages (converted to dicts)
scores_obj["messages"].append(out_dict.get("messages", messages_as_dicts)) scores_obj["messages"].append(messages_as_dicts)
self.percent_correct_buffer.append(1.0 if reward == 1.0 else 0.0) self.percent_correct_buffer.append(1.0 if reward == 1.0 else 0.0)
@ -1036,12 +998,16 @@ class PydanticSchemaFollowingEnv(BaseEnv):
f"Requesting {self.config.group_size} completions with max_tokens={self.config.max_token_length}, temperature=0.9" # noqa: E501 f"Requesting {self.config.group_size} completions with max_tokens={self.config.max_token_length}, temperature=0.9" # noqa: E501
) )
completions = await self.server.completion( async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
prompt=prompt_str, completions = await managed.completion(
n=self.config.group_size, prompt=prompt_str,
max_tokens=self.config.max_token_length, n=self.config.group_size,
temperature=0.9, max_tokens=self.config.max_token_length,
) temperature=0.9,
)
state = managed.get_state()
nodes = state["nodes"]
if self.debug_logging: if self.debug_logging:
self.logger.debug( self.logger.debug(
@ -1060,7 +1026,15 @@ class PydanticSchemaFollowingEnv(BaseEnv):
current_trajectory_messages.append( current_trajectory_messages.append(
frozenset({"role": "assistant", "content": choice.text}.items()) frozenset({"role": "assistant", "content": choice.text}.items())
) )
to_score_list.append((tuple(current_trajectory_messages), dataset_item)) to_score_list.append(
{
"messages": tuple(current_trajectory_messages),
"dataset_item": dataset_item,
"tokens": nodes[i].tokens,
"masks": nodes[i].masked_tokens,
"logprobs": nodes[i].logprobs,
}
)
scored_data = await self.score(to_score_list) scored_data = await self.score(to_score_list)
@ -1251,13 +1225,14 @@ class PydanticSchemaFollowingEnv(BaseEnv):
f"Eval prompt length for {problem_id}: {len(prompt)} characters" f"Eval prompt length for {problem_id}: {len(prompt)} characters"
) )
completion = await self.server.completion( async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
prompt=prompt, completion = await managed.completion(
n=1, prompt=prompt,
max_tokens=self.config.max_token_length, n=1,
temperature=0.1, max_tokens=self.config.max_token_length,
split="eval", temperature=0.1,
) split="eval",
)
model_response_text = completion.choices[0].text model_response_text = completion.choices[0].text