mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
46f05673aa
commit
77e14199ce
1 changed files with 7 additions and 5 deletions
|
|
@ -745,9 +745,7 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
|
||||
await super().close()
|
||||
|
||||
async def score(
|
||||
self, rollout_group_data: List[Dict]
|
||||
) -> Optional[ScoredDataGroup]:
|
||||
async def score(self, rollout_group_data: List[Dict]) -> Optional[ScoredDataGroup]:
|
||||
# rollout_group_data is a list of dicts with messages, answer_info, tokens, masks, logprobs
|
||||
# answer_info_dict = {"func_name": ..., "args": ...}
|
||||
|
||||
|
|
@ -772,7 +770,9 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
|
||||
for trajectory_item in rollout_group_data:
|
||||
full_trajectory_messages = trajectory_item["messages"]
|
||||
answer_info = trajectory_item["answer_info"] # {"func_name": ..., "args": ...}
|
||||
answer_info = trajectory_item[
|
||||
"answer_info"
|
||||
] # {"func_name": ..., "args": ...}
|
||||
|
||||
model_response_text = full_trajectory_messages[-1]["content"]
|
||||
func_name = answer_info["func_name"]
|
||||
|
|
@ -831,7 +831,9 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
if current_scores:
|
||||
average_score = sum(current_scores) / len(current_scores)
|
||||
# Get task info from the first rollout's answer_info
|
||||
answer_info = rollout_group_data[0]["answer_info"] if rollout_group_data else {}
|
||||
answer_info = (
|
||||
rollout_group_data[0]["answer_info"] if rollout_group_data else {}
|
||||
)
|
||||
func_name = answer_info.get("func_name", "unknown_task")
|
||||
|
||||
# Check if group is too easy for training (but still allow data dumping)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue