diff --git a/environments/pydantic_schema_following_environment/pydantic_schema_following_environment.py b/environments/pydantic_schema_following_environment/pydantic_schema_following_environment.py index f3711c7c..6bee739c 100644 --- a/environments/pydantic_schema_following_environment/pydantic_schema_following_environment.py +++ b/environments/pydantic_schema_following_environment/pydantic_schema_following_environment.py @@ -69,7 +69,6 @@ from atroposlib.envs.base import ( Item, ScoredDataGroup, ) -from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer # Import editing functionality try: @@ -718,7 +717,7 @@ class PydanticSchemaFollowingEnv(BaseEnv): async def score( self, - rollout_group_data: List[Tuple[Tuple[Dict[str, str], ...], Dict[str, Any]]], + rollout_group_data: List[Dict[str, Any]], ) -> Optional[ScoredDataGroup]: """Score the rollouts based on Pydantic validation or other structural checks.""" if self.debug_logging: @@ -729,13 +728,14 @@ class PydanticSchemaFollowingEnv(BaseEnv): scores_obj["masks"] = list() scores_obj["scores"] = list() scores_obj["messages"] = list() + scores_obj["inference_logprobs"] = list() if not rollout_group_data: if self.debug_logging: self.logger.warning("No rollout data to score") return None - dataset_item = rollout_group_data[0][1] + dataset_item = rollout_group_data[0]["dataset_item"] problem_id = dataset_item.get("problem_id", "N/A") selected_structured_format = dataset_item["selected_structured_format"] selected_container_format = dataset_item["selected_container_format"] @@ -779,7 +779,12 @@ class PydanticSchemaFollowingEnv(BaseEnv): random.shuffle(rollout_group_data) - for i, (item_messages, _) in enumerate(rollout_group_data): + for i, rollout_item in enumerate(rollout_group_data): + + item_messages = rollout_item["messages"] + tokens = rollout_item["tokens"] + masks = rollout_item["masks"] + logprobs = rollout_item["logprobs"] messages_as_dicts = [dict(fs_message) for fs_message in item_messages] model_response_text = messages_as_dicts[-1]["content"] @@ -906,63 +911,20 @@ class PydanticSchemaFollowingEnv(BaseEnv): f"Rollout {i}: Extraction failed for {selected_structured_format.value} with container {selected_container_format.value}" # noqa: E501 ) - try: - if not isinstance(messages_as_dicts, list): - if self.debug_logging: - self.logger.error( - f"Expected list for tokenization, got {type(messages_as_dicts)}" - ) - continue - for msg_idx, msg in enumerate(messages_as_dicts): - if not isinstance(msg, dict): - if self.debug_logging: - self.logger.error( - f"Message {msg_idx} is not a dict: {type(msg)}" - ) - continue # Skip this rollout if message format is incorrect - if "role" not in msg or "content" not in msg: - if self.debug_logging: - self.logger.error( - f"Message {msg_idx} missing required keys: {msg.keys()}" - ) - continue # Skip this rollout - if not isinstance(msg["content"], str): - if self.debug_logging: - self.logger.warning( - f"Converting content to string for message {msg_idx}" - ) - msg["content"] = str(msg["content"]) - - out_dict = tokenize_for_trainer( - self.tokenizer, - messages_as_dicts, - include_messages=self.config.include_messages, - ) - tokens = out_dict["tokens"] - masks = out_dict["masks"] - - except Exception as e: - if self.debug_logging: - self.logger.error( - f"Tokenization failed for rollout {i} (problem: {problem_id}): {e}" - ) - self.logger.debug( - f"Messages format: {[type(m) for m in messages_as_dicts]}" - ) - continue - + # Remove examples with insufficient context if len([1 for m_val in masks if m_val != -100]) < 10: # Min context length if self.debug_logging: self.logger.debug( - f"Skipping rollout {i} (problem: {problem_id}) due to insufficient context length after tokenization." # noqa: E501 + f"Skipping rollout {i} (problem: {problem_id}) due to insufficient context length." # noqa: E501 ) continue scores_obj["tokens"].append(tokens) scores_obj["masks"].append(masks) + scores_obj["inference_logprobs"].append(logprobs) scores_obj["scores"].append(reward) - # Store original messages (converted to dicts) if available in out_dict, else the modified ones - scores_obj["messages"].append(out_dict.get("messages", messages_as_dicts)) + # Store original messages (converted to dicts) + scores_obj["messages"].append(messages_as_dicts) self.percent_correct_buffer.append(1.0 if reward == 1.0 else 0.0) @@ -1036,12 +998,16 @@ class PydanticSchemaFollowingEnv(BaseEnv): f"Requesting {self.config.group_size} completions with max_tokens={self.config.max_token_length}, temperature=0.9" # noqa: E501 ) - completions = await self.server.completion( - prompt=prompt_str, - n=self.config.group_size, - max_tokens=self.config.max_token_length, - temperature=0.9, - ) + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + completions = await managed.completion( + prompt=prompt_str, + n=self.config.group_size, + max_tokens=self.config.max_token_length, + temperature=0.9, + ) + + state = managed.get_state() + nodes = state["nodes"] if self.debug_logging: self.logger.debug( @@ -1060,7 +1026,15 @@ class PydanticSchemaFollowingEnv(BaseEnv): current_trajectory_messages.append( frozenset({"role": "assistant", "content": choice.text}.items()) ) - to_score_list.append((tuple(current_trajectory_messages), dataset_item)) + to_score_list.append( + { + "messages": tuple(current_trajectory_messages), + "dataset_item": dataset_item, + "tokens": nodes[i].tokens, + "masks": nodes[i].masked_tokens, + "logprobs": nodes[i].logprobs, + } + ) scored_data = await self.score(to_score_list) @@ -1251,13 +1225,14 @@ class PydanticSchemaFollowingEnv(BaseEnv): f"Eval prompt length for {problem_id}: {len(prompt)} characters" ) - completion = await self.server.completion( - prompt=prompt, - n=1, - max_tokens=self.config.max_token_length, - temperature=0.1, - split="eval", - ) + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + completion = await managed.completion( + prompt=prompt, + n=1, + max_tokens=self.config.max_token_length, + temperature=0.1, + split="eval", + ) model_response_text = completion.choices[0].text