mirror of
https://github.com/InternLM/InternBootcamp.git
synced 2026-04-22 16:49:04 +00:00
166 lines
5.8 KiB
Python
Executable file
# Copyright (c) InternLM. All rights reserved.
|
|
import json
|
|
import random
|
|
|
|
import numpy as np
|
|
import torch
|
|
from xtuner._lite import get_logger
|
|
from xtuner._lite.algorithms.sft.dataset import SftCollator
|
|
|
|
logger = get_logger()
|
|
|
|
|
|
class InferDataset(torch.utils.data.Dataset):
    """Dataset over pre-generated prompt/response token-id pairs.

    Each item concatenates a prompt with its response and builds
    next-token-prediction labels in which only the response tokens are
    supervised; prompt positions and the final position are masked with -100.
    """

    def __init__(self, prompts_input_ids, responses_ids, message_data, metadata):
        super().__init__()

        # All four parallel lists must describe the same set of samples.
        assert (
            len(prompts_input_ids)
            == len(responses_ids)
            == len(message_data)
            == len(metadata)
        ), f"The length of prompts_input_ids, responses_ids, message_data, metadata should be the same, but got {len(prompts_input_ids)}, {len(responses_ids)}, {len(message_data)}, {len(metadata)}"

        self.prompts_input_ids = prompts_input_ids
        self.responses_ids = responses_ids
        self.message_data = message_data
        self.metadata = metadata

    def __len__(self):
        return len(self.prompts_input_ids)

    def __getitem__(self, item):
        prompt = self.prompts_input_ids[item]
        response = self.responses_ids[item]
        prefill_len = len(prompt)

        full_ids = prompt + response
        # Shift-by-one supervision: the token at position i predicts position
        # i+1, so the last prompt token already predicts the first response
        # token (hence prefill_len - 1 masked), and the final position has no
        # target (masked with -100).
        full_labels = [-100] * (prefill_len - 1) + response + [-100]

        return {
            "input_ids": full_ids,
            "labels": full_labels,
            "num_tokens": len(full_ids),
            "message_data": self.message_data[item],
            "metadata": self.metadata[item],
        }
|
|
|
|
|
|
class TrajectoryDataset(torch.utils.data.Dataset):
    """Holds rollout trajectories plus supervised/total token statistics.

    ``update`` replaces the stored trajectories and recomputes how many
    tokens are supervised (label >= 0) versus total. Each trajectory is a
    dict containing at least ``labels``; the dump helpers additionally read
    ``input_ids``/``sequence_text``, ``num_tokens``, ``judger_reward`` and
    ``judger_advantage``.
    """

    def __init__(self):
        super().__init__()
        # Counters start as plain Python ints; ``update`` may overwrite
        # them with numpy scalars.
        self._num_action_tokens = 0
        self._num_total_tokens = 0
        self._trajectories = []

    @property
    def num_action_tokens(self):
        # BUGFIX: the original called ``.item()``, which raised
        # AttributeError when accessed before the first ``update`` (the
        # initial value is a plain int, not a numpy scalar). ``int(...)``
        # handles both cases and still returns a Python int.
        return int(self._num_action_tokens)

    @property
    def num_total_tokens(self):
        return self._num_total_tokens

    def update(self, trajectories):
        """Replace stored trajectories and recompute token statistics.

        Args:
            trajectories: iterable of dicts, each with a ``labels`` list
                where unsupervised positions are marked with -100.
        """
        num_total_tokens = 0
        num_action_tokens = 0
        for data in trajectories:
            labels = np.array(data["labels"])
            num_total_tokens += labels.size
            # Action (supervised) tokens are those not masked with -100.
            num_action_tokens += (labels >= 0).sum()

        self._num_action_tokens = num_action_tokens
        self._num_total_tokens = num_total_tokens

        self._trajectories = trajectories

    def dump_jsonl(self, path, tokenizer, debug=False):
        """Write one JSON object per trajectory to ``path``.

        Uses the pre-rendered ``sequence_text`` when present, otherwise
        decodes ``input_ids`` with ``tokenizer``. With ``debug=True`` the
        raw token ids and labels are included as well.
        """
        with open(path, "w", encoding="utf8") as f:
            for data in self._trajectories:
                json_line = {
                    "sequence": (
                        data["sequence_text"]
                        if "sequence_text" in data
                        else tokenizer.decode(data["input_ids"])
                    ),
                    "num_tokens": data["num_tokens"],
                }
                json_line["judger_reward"] = data["judger_reward"]
                json_line["judger_advantage"] = data["judger_advantage"]

                if debug:
                    json_line["input_ids"] = data["input_ids"]
                    json_line["labels"] = data["labels"]

                json_str = json.dumps(json_line, ensure_ascii=False)
                f.write(json_str + "\n")

    def dump_log(self, path, tokenizer, debug=False):
        """Write a human-readable log of all trajectories to ``path``."""
        with open(path, "w", encoding="utf8") as f:
            for data in self._trajectories:
                log_string = f"[sequence]:\n{data['sequence_text'] if 'sequence_text' in data else tokenizer.decode(data['input_ids'])}\n\n"
                log_string += f"[num_tokens]: {data['num_tokens']}\n"
                log_string += f"[judger_reward]: {data['judger_reward']}\n"
                log_string += f"[judger_advantage]: {data['judger_advantage']}\n"
                f.write(log_string + "\n\n=======================\n")

    def __len__(self):
        return len(self._trajectories)

    def __getitem__(self, item):
        return self._trajectories[item]
|
|
|
|
|
|
class TrajectoryDatasetWithFilter(TrajectoryDataset):
    """TrajectoryDataset that drops uninformative rollout groups.

    Trajectories are assumed to arrive as contiguous groups of ``repeat_k``
    samples generated from the same prompt. Groups where every sample is
    correct (pass rate 1) or every sample is wrong (pass rate 0) carry no
    comparative signal and are discarded. When ``only_keep_1_pair`` is set,
    at most one correct and one incorrect sample are kept per group.
    """

    def __init__(self, repeat_k=1, only_keep_1_pair=True):
        super().__init__()
        self.repeat_k = repeat_k
        self.only_keep_1_pair = only_keep_1_pair

    def update(self, trajectories):
        """Filter ``trajectories`` group-wise, then delegate to the base class.

        Every kept trajectory dict is annotated in place with its group's
        ``pass_rate``. A trailing partial group (when ``len(trajectories)``
        is not a multiple of ``repeat_k``) is rated over its actual size.
        """
        # split trajectories into k groups: (a, a, b, b, c, c) -> [(a, a), (b, b), (c, c)]
        groups = [
            trajectories[i : i + self.repeat_k]
            for i in range(0, len(trajectories), self.repeat_k)
        ]
        keeped_trajectories = []
        for group in groups:
            correct = [data for data in group if data["judger_reward"] == 1]
            incorrect = [data for data in group if data["judger_reward"] != 1]
            pass_rate = len(correct) / len(group)

            # Hoisted from both branches of the original: all-correct or
            # all-wrong groups provide no learning signal, skip them.
            if pass_rate == 1 or pass_rate == 0:
                continue

            if self.only_keep_1_pair:
                # max keep 1 correct and 1 incorrect
                kept = [random.choice(correct), random.choice(incorrect)]
            else:
                kept = group

            for data in kept:
                data["pass_rate"] = pass_rate
                keeped_trajectories.append(data)

        super().update(keeped_trajectories)
|
|
|
|
|
|
class TrajectoryCollator(SftCollator):
    """SftCollator that additionally batches judger rewards and advantages.

    ``pass_rate`` is forwarded only when the instances carry it (i.e. when
    the dataset annotated trajectories with a group pass rate).
    """

    def __call__(self, instances):
        batch = super().__call__(instances)

        batch["judger_rewards"] = [item["judger_reward"] for item in instances]
        batch["judger_advantages"] = [
            item["judger_advantage"] for item in instances
        ]

        if "pass_rate" in instances[0]:
            batch["pass_rate"] = [item["pass_rate"] for item in instances]

        return batch
|