diff --git a/environments/instruction_following_algorithm_environment.py b/environments/instruction_following_algorithm_environment.py index 69b2b3ec..c59f43ec 100644 --- a/environments/instruction_following_algorithm_environment.py +++ b/environments/instruction_following_algorithm_environment.py @@ -19,7 +19,6 @@ from atroposlib.envs.base import ( Item, ScoredDataGroup, ) -from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer # System prompt can be reused or adapted for instruction following tasks system_prompt = ( @@ -592,13 +591,14 @@ class InstructionFollowingEnv(BaseEnv): messages, add_generation_prompt=True, tokenize=False ) - completion = await self.server.completion( - prompt=prompt_str, - n=1, - max_tokens=self.config.max_token_length, # Use config for max_tokens - temperature=0.2, # Temperature for eval, can be 0 for deterministic - split="eval", - ) + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + completion = await managed.completion( + prompt=prompt_str, + n=1, + max_tokens=self.config.max_token_length, # Use config for max_tokens + temperature=0.2, # Temperature for eval, can be 0 for deterministic + split="eval", + ) model_response_text = completion.choices[0].text score_value = await self._get_score_from_verifier( @@ -746,15 +746,16 @@ class InstructionFollowingEnv(BaseEnv): await super().close() async def score( - self, rollout_group_data: List[Tuple[tuple, Dict]] + self, rollout_group_data: List[Dict] ) -> Optional[ScoredDataGroup]: - # rollout_group_data is a list of (trajectory_messages_tuple, answer_info_dict) + # rollout_group_data is a list of dicts with messages, answer_info, tokens, masks, logprobs # answer_info_dict = {"func_name": ..., "args": ...} scores_container = ScoredDataGroup() scores_container["tokens"] = list() scores_container["masks"] = list() scores_container["scores"] = list() + scores_container["inference_logprobs"] = list() if not rollout_group_data: return None @@ -770,8 +771,8 @@ class InstructionFollowingEnv(BaseEnv): failed_rollouts_for_this_group = [] for trajectory_item in rollout_group_data: - full_trajectory_messages = trajectory_item[0] - answer_info = trajectory_item[1] # {"func_name": ..., "args": ...} + full_trajectory_messages = trajectory_item["messages"] + answer_info = trajectory_item["answer_info"] # {"func_name": ..., "args": ...} model_response_text = full_trajectory_messages[-1]["content"] func_name = answer_info["func_name"] @@ -803,12 +804,9 @@ class InstructionFollowingEnv(BaseEnv): elif self.config.dump_failed_rollouts and reward == 0: failed_rollouts_for_this_group.append(rollout_dict) - # Tokenize the conversation for PPO training - # Ensure full_trajectory_messages is a list of dicts - list_of_dicts_trajectory = [dict(msg) for msg in full_trajectory_messages] - out_dict = tokenize_for_trainer(self.tokenizer, list_of_dicts_trajectory) - tokens = out_dict["tokens"] - masks = out_dict["masks"] + tokens = trajectory_item["tokens"] + masks = trajectory_item["masks"] + logprobs = trajectory_item["logprobs"] # Filter out examples with insufficient context (too short) if ( @@ -818,6 +816,7 @@ class InstructionFollowingEnv(BaseEnv): scores_container["tokens"].append(tokens) scores_container["masks"].append(masks) + scores_container["inference_logprobs"].append(logprobs) scores_container["scores"].append(reward) # Stop if we have enough examples for the group @@ -832,7 +831,7 @@ class InstructionFollowingEnv(BaseEnv): if current_scores: average_score = sum(current_scores) / len(current_scores) # Get task info from the first rollout's answer_info - answer_info = rollout_group_data[0][1] if rollout_group_data else {} + answer_info = rollout_group_data[0]["answer_info"] if rollout_group_data else {} func_name = answer_info.get("func_name", "unknown_task") # Check if group is too easy for training (but still allow data dumping) @@ -849,7 +848,7 @@ class InstructionFollowingEnv(BaseEnv): <= self.config.max_group_average_for_training + 0.1 ): # Small buffer for data collection # Extract item info for the group - get from first rollout's answer_info - answer_info = rollout_group_data[0][1] + answer_info = rollout_group_data[0]["answer_info"] item_id = f"allenai_RLVR-IFeval_train_item_{answer_info.get('func_name', 'unknown')}_{hash(str(answer_info)) % 100000}" # noqa group_data_to_save = { @@ -863,7 +862,7 @@ class InstructionFollowingEnv(BaseEnv): if failed_rollouts_for_this_group: # Extract item info for the failed group - answer_info = rollout_group_data[0][1] + answer_info = rollout_group_data[0]["answer_info"] item_id = f"allenai_RLVR-IFeval_train_item_{answer_info.get('func_name', 'unknown')}_{hash(str(answer_info)) % 100000}" # noqa failed_group_data_to_save = { @@ -906,7 +905,7 @@ class InstructionFollowingEnv(BaseEnv): # Create group data structure and add to buffers for data dumping (for training groups) if rollouts_for_this_group: # Extract item info for the group - get from first rollout's answer_info - answer_info = rollout_group_data[0][1] + answer_info = rollout_group_data[0]["answer_info"] item_id = f"allenai_RLVR-IFeval_train_item_{answer_info.get('func_name', 'unknown')}_{hash(str(answer_info)) % 100000}" # noqa group_data_to_save = { @@ -924,7 +923,7 @@ class InstructionFollowingEnv(BaseEnv): if failed_rollouts_for_this_group: # Extract item info for the failed group - answer_info = rollout_group_data[0][1] + answer_info = rollout_group_data[0]["answer_info"] item_id = f"allenai_RLVR-IFeval_train_item_{answer_info.get('func_name', 'unknown')}_{hash(str(answer_info)) % 100000}" # noqa failed_group_data_to_save = {