diff --git a/environments/instruction_following_algorithm_environment.py b/environments/instruction_following_algorithm_environment.py index c59f43ec..61a9ae6d 100644 --- a/environments/instruction_following_algorithm_environment.py +++ b/environments/instruction_following_algorithm_environment.py @@ -745,9 +745,7 @@ class InstructionFollowingEnv(BaseEnv): await super().close() - async def score( - self, rollout_group_data: List[Dict] - ) -> Optional[ScoredDataGroup]: + async def score(self, rollout_group_data: List[Dict]) -> Optional[ScoredDataGroup]: # rollout_group_data is a list of dicts with messages, answer_info, tokens, masks, logprobs # answer_info_dict = {"func_name": ..., "args": ...} @@ -772,7 +770,9 @@ class InstructionFollowingEnv(BaseEnv): for trajectory_item in rollout_group_data: full_trajectory_messages = trajectory_item["messages"] - answer_info = trajectory_item["answer_info"] # {"func_name": ..., "args": ...} + answer_info = trajectory_item[ + "answer_info" + ] # {"func_name": ..., "args": ...} model_response_text = full_trajectory_messages[-1]["content"] func_name = answer_info["func_name"] @@ -831,7 +831,9 @@ class InstructionFollowingEnv(BaseEnv): if current_scores: average_score = sum(current_scores) / len(current_scores) # Get task info from the first rollout's answer_info - answer_info = rollout_group_data[0]["answer_info"] if rollout_group_data else {} + answer_info = ( + rollout_group_data[0]["answer_info"] if rollout_group_data else {} + ) func_name = answer_info.get("func_name", "unknown_task") # Check if group is too easy for training (but still allow data dumping)