init-commit

This commit is contained in:
lilinyang 2025-05-23 15:27:15 +08:00
commit 18a552597a
3461 changed files with 1150579 additions and 0 deletions

View file

@ -0,0 +1,18 @@
# Copyright (c) InternLM. All rights reserved.
from .prompt import bootcampPromptDataset, PromptCollator, InfiniteDataLoaderIter
from .trajectory import (
InferDataset,
TrajectoryCollator,
TrajectoryDataset,
TrajectoryDatasetWithFilter,
)
# Public API of the datasets subpackage; keep in sync with the imports above.
__all__ = [
    "bootcampPromptDataset",
    "PromptCollator",
    "InferDataset",
    "TrajectoryDataset",
    "TrajectoryDatasetWithFilter",
    "TrajectoryCollator",
    "InfiniteDataLoaderIter",
]

View file

@ -0,0 +1,214 @@
# Copyright (c) InternLM. All rights reserved.
import json
import time
import numpy as np
import torch
from datasets import load_dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from xtuner._lite import get_logger
logger = get_logger()
def load_hf_datasets(repo, split="train"):
    """Load a HuggingFace dataset and convert each row into the prompt-sample
    dict format used by ``bootcampPromptDataset``.

    Each converted sample carries ``pass_rate``, a single-turn user message,
    and metadata (assumes the repo has ``pass_rate``/``question``/``gold_answer``
    columns — TODO confirm against the actual datasets used).
    """
    raw = load_dataset(repo, split=split)
    converted_ds = [
        {
            "pass_rate": row["pass_rate"],
            "message_data": [{"role": "user", "content": row["question"]}],
            "metadata": {
                # "data_source" tells the judger router which judger to use.
                "data_source": "math",
                "gold_answer": row["gold_answer"],
            },
        }
        for row in raw
    ]
    logger.info(f"Loaded {len(converted_ds)} samples from {repo}")
    return converted_ds
def load_jsonl_datasets(file_path):
    """Load prompt samples from a JSONL file.

    ``file_path`` may carry an optional ``::<ratio>`` suffix (e.g.
    ``data.jsonl::0.5``) requesting a random (seeded) subsample of that
    fraction. Rows without a ``message_data`` key are converted into the
    standard prompt-sample dict format.
    """
    subsample_ratio = 1.0
    if "::" in file_path:
        file_path, ratio_str = file_path.split("::")
        subsample_ratio = float(ratio_str)

    datasets = []
    with open(file_path, "r") as f:
        for line in f:
            sample = json.loads(line)
            if "message_data" in sample:
                # Already in the expected format; keep as-is.
                datasets.append(sample)
            else:
                datasets.append(
                    {
                        "pass_rate": sample["pass_rate"],
                        "message_data": [
                            {"role": "user", "content": sample["question"]}
                        ],
                        "metadata": {
                            # "data_source" routes the sample to a judger.
                            "data_source": "math",
                            "gold_answer": sample["gold_answer"],
                        },
                    }
                )

    if subsample_ratio < 1.0:
        # Fixed seed keeps the subsample reproducible across runs.
        np.random.seed(0)
        datasets = np.random.choice(
            datasets, int(len(datasets) * subsample_ratio), replace=False
        ).tolist()
    logger.info(f"Loaded {len(datasets)} samples from {file_path}")
    return datasets
def balance_difficulty_with_cfg(dataset, difficulty_balance_cfg):
    """Repeat samples according to pass-rate difficulty buckets.

    ``difficulty_balance_cfg`` is a sequence of ``((low, high), repeat)``
    pairs; a sample whose ``pass_rate`` falls in ``[low, high)`` is
    duplicated ``repeat`` times (first matching bucket wins). Samples
    matching no bucket are dropped.
    """
    balanced = []
    for item in dataset:
        rate = item["pass_rate"]
        # First matching bucket wins; no bucket means the sample is dropped.
        repeat = next(
            (times for (lo, hi), times in difficulty_balance_cfg if lo <= rate < hi),
            0,
        )
        balanced.extend([item] * repeat)
    logger.info(
        f"After difficulty balancing, the dataset size is {len(balanced)}"
    )
    return balanced
class bootcampPromptDataset(Dataset):
    """Prompt dataset mixing JSONL files and HuggingFace repos.

    Args:
        path: a single path or a list of paths; paths ending in ``.jsonl``
            are read as JSONL files, anything else is treated as a
            HuggingFace dataset repo.
        tokenizer: tokenizer providing ``apply_chat_template``.
        difficulty_balance_cfg: optional ``((low, high), repeat)`` bucket
            config forwarded to ``balance_difficulty_with_cfg``.
    """

    def __init__(self, path, tokenizer, difficulty_balance_cfg=None):
        if isinstance(path, str):
            path = [path]
        dataset = []
        for p in path:
            if p.endswith(".jsonl"):
                dataset.extend(load_jsonl_datasets(p))
            else:
                dataset.extend(load_hf_datasets(p))
        if difficulty_balance_cfg:
            dataset = balance_difficulty_with_cfg(dataset, difficulty_balance_cfg)
        self.dataset = dataset
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Copy the stored dict: the original wrote input_ids/labels back into
        # self.dataset, mutating the cached sample on every access.
        sample = dict(self.dataset[idx])
        input_ids = self.tokenizer.apply_chat_template(
            sample["message_data"], add_generation_prompt=True
        )
        sample["input_ids"] = input_ids
        # Copy so labels and input_ids are not the same list object; labels
        # are padded/masked independently downstream.
        sample["labels"] = list(input_ids)
        sample["num_tokens"] = len(input_ids)
        return sample
class PromptCollator:
    """Collate prompt samples into a batch dict of tensors.

    With ``pack_batch=True`` all sequences are concatenated into a single
    row; otherwise they are right-padded to the longest sequence in the
    batch. Metadata and message lists are passed through untouched.
    """

    def __init__(self, pad_token_id=0, ignore_id=-100, pack_batch=False):
        self.pack_batch = pack_batch
        self.pad_token_id = pad_token_id
        self.ignore_id = ignore_id

    def __call__(self, instances):
        # Flatten one level of nesting so callers may pass lists of samples.
        flat = []
        for entry in instances:
            if isinstance(entry, list):
                flat.extend(entry)
            else:
                flat.append(entry)
        instances = flat

        input_ids = [torch.LongTensor(item["input_ids"]) for item in instances]
        labels = [torch.LongTensor(item["labels"]) for item in instances]
        metadatas = [item["metadata"] for item in instances]
        message_datas = [item["message_data"] for item in instances]

        num_tokens = []
        for item in instances:
            counts = item["num_tokens"]
            if isinstance(counts, int):
                num_tokens.append(counts)
            else:
                num_tokens.extend(counts)
        num_tokens = torch.IntTensor(num_tokens)

        attention_mask = [torch.ones_like(ids) for ids in input_ids]

        if len(instances) == 1:
            input_ids = torch.stack(input_ids)
            labels = torch.stack(labels)
            attention_mask = torch.stack(attention_mask)
        elif self.pack_batch:
            # Concatenate every sequence into one long packed row.
            input_ids = torch.cat(input_ids, dim=0).unsqueeze(0)
            labels = torch.cat(labels, dim=0).unsqueeze(0)
            attention_mask = torch.cat(attention_mask, dim=0).unsqueeze(0)
        else:
            # Right-pad to the longest sequence in the batch.
            input_ids = pad_sequence(
                input_ids, batch_first=True, padding_value=self.pad_token_id
            )
            labels = pad_sequence(
                labels, batch_first=True, padding_value=self.ignore_id
            )
            attention_mask = pad_sequence(
                attention_mask, batch_first=True, padding_value=0
            )

        if input_ids.shape != labels.shape:
            logger.error(f"[instances] {instances}")
            logger.error(f"[num_tokens] {num_tokens}")
            logger.error(f"[input_ids] {input_ids}")
            logger.error(f"[labels] {labels}")
            raise RuntimeError(
                "The shape of input_ids and labels must be "
                f"equal, but found {input_ids.shape} and "
                f"{labels.shape}."
            )

        return {
            "input_ids": input_ids,
            "labels": labels,
            "num_tokens": num_tokens,
            "attention_mask": attention_mask.bool(),
            "metadata": metadatas,
            "message_data": message_datas,
        }
class InfiniteDataLoaderIter:
    """Wrap a dataloader so iteration never stops.

    When one epoch is exhausted a fresh iterator is created (updating the
    sampler's epoch when supported) and iteration continues seamlessly.
    """

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.iterator = iter(dataloader)
        self._epoch = 0

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.iterator)
        except StopIteration:
            return self._start_new_epoch()

    def _start_new_epoch(self):
        # Roll over to a fresh epoch and return its first batch.
        logger.info(f"Dataloader epoch {self._epoch} finished. Start a new epoch.")
        self._epoch += 1
        sampler = getattr(self.dataloader, 'sampler', None)
        if sampler is not None and hasattr(sampler, 'set_epoch'):
            # Some iterators have no sampler, and e.g. SequentialSampler
            # has no set_epoch; only call it when available.
            sampler.set_epoch(self._epoch)
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        self.iterator = iter(self.dataloader)
        return next(self.iterator)

View file

@ -0,0 +1,166 @@
# Copyright (c) InternLM. All rights reserved.
import json
import random
import numpy as np
import torch
from xtuner._lite import get_logger
from xtuner._lite.algorithms.sft.dataset import SftCollator
logger = get_logger()
class InferDataset(torch.utils.data.Dataset):
    """Join generated responses back onto their prompts for a training pass.

    Each item yields the full sequence (prompt + response); labels supervise
    only the response tokens, shifted one step for next-token prediction.
    """

    def __init__(self, prompts_input_ids, responses_ids, message_data, metadata):
        super().__init__()
        lengths = {
            len(prompts_input_ids),
            len(responses_ids),
            len(message_data),
            len(metadata),
        }
        assert (
            len(lengths) <= 1
        ), f"The length of prompts_input_ids, responses_ids, message_data, metadata should be the same, but got {len(prompts_input_ids)}, {len(responses_ids)}, {len(message_data)}, {len(metadata)}"
        self.prompts_input_ids = prompts_input_ids
        self.responses_ids = responses_ids
        self.message_data = message_data
        self.metadata = metadata

    def __len__(self):
        return len(self.prompts_input_ids)

    def __getitem__(self, item):
        prompt_ids = self.prompts_input_ids[item]
        response_ids = self.responses_ids[item]
        sequence = prompt_ids + response_ids
        # Mask the prompt (shifted by one for next-token prediction) and the
        # final position, supervising only the response tokens.
        labels = [-100] * (len(prompt_ids) - 1) + response_ids + [-100]
        return {
            "input_ids": sequence,
            "labels": labels,
            "num_tokens": len(sequence),
            "message_data": self.message_data[item],
            "metadata": self.metadata[item],
        }
class TrajectoryDataset(torch.utils.data.Dataset):
    """Holds the latest batch of rollout trajectories for a training step.

    ``update`` replaces the stored trajectories and recomputes token
    statistics; ``dump_jsonl``/``dump_log`` serialize them for inspection.
    """

    def __init__(self):
        super().__init__()
        # Number of supervised (label >= 0) tokens across all trajectories.
        self._num_action_tokens = 0
        # Total number of tokens across all trajectories.
        self._num_total_tokens = 0
        self._trajectories = []

    @property
    def num_action_tokens(self):
        # int() instead of .item(): the original crashed with AttributeError
        # when read before the first update(), while the counter was still a
        # plain Python int; int() handles both int and numpy scalar.
        return int(self._num_action_tokens)

    @property
    def num_total_tokens(self):
        return self._num_total_tokens

    def update(self, trajectories):
        """Replace the stored trajectories and refresh the token counters."""
        num_total_tokens = 0
        num_action_tokens = 0
        for data in trajectories:
            labels = np.array(data["labels"])
            num_total_tokens += labels.size
            # Tokens with label >= 0 are the supervised (action) tokens.
            num_action_tokens += int((labels >= 0).sum())
        self._num_action_tokens = num_action_tokens
        self._num_total_tokens = num_total_tokens
        self._trajectories = trajectories

    def dump_jsonl(self, path, tokenizer, debug=False):
        """Write one JSON object per trajectory to ``path``.

        ``debug=True`` additionally dumps the raw input_ids and labels.
        """
        with open(path, "w", encoding="utf8") as f:
            for data in self._trajectories:
                json_line = {
                    "sequence": (
                        data["sequence_text"]
                        if "sequence_text" in data
                        else tokenizer.decode(data["input_ids"])
                    ),
                    "num_tokens": data["num_tokens"],
                }
                json_line["judger_reward"] = data["judger_reward"]
                json_line["judger_advantage"] = data["judger_advantage"]
                if debug:
                    json_line["input_ids"] = data["input_ids"]
                    json_line["labels"] = data["labels"]
                json_str = json.dumps(json_line, ensure_ascii=False)
                f.write(json_str + "\n")

    def dump_log(self, path, tokenizer, debug=False):
        """Write a human-readable log of the stored trajectories to ``path``."""
        with open(path, "w", encoding="utf8") as f:
            for data in self._trajectories:
                log_string = f"[sequence]:\n{data['sequence_text'] if 'sequence_text' in data else tokenizer.decode(data['input_ids'])}\n\n"
                log_string += f"[num_tokens]: {data['num_tokens']}\n"
                log_string += f"[judger_reward]: {data['judger_reward']}\n"
                log_string += f"[judger_advantage]: {data['judger_advantage']}\n"
                f.write(log_string + "\n\n=======================\n")

    def __len__(self):
        return len(self._trajectories)

    def __getitem__(self, item):
        return self._trajectories[item]
class TrajectoryDatasetWithFilter(TrajectoryDataset):
    """TrajectoryDataset that drops uninformative rollout groups.

    Consecutive ``repeat_k`` trajectories are rollouts of the same prompt.
    Groups that are all-correct or all-incorrect carry no contrastive
    signal and are discarded. With ``only_keep_1_pair`` enabled, each
    surviving group is reduced to one random correct / incorrect pair.
    """

    def __init__(self, repeat_k=1, only_keep_1_pair=True):
        super().__init__()
        self.repeat_k = repeat_k
        self.only_keep_1_pair = only_keep_1_pair

    def update(self, trajectories):
        # (a, a, b, b, c, c) with repeat_k=2 -> groups (a, a), (b, b), (c, c)
        kept = []
        for start in range(0, len(trajectories), self.repeat_k):
            group = trajectories[start : start + self.repeat_k]
            correct = [d for d in group if d["judger_reward"] == 1]
            incorrect = [d for d in group if d["judger_reward"] != 1]
            pass_rate = len(correct) / len(group)
            if pass_rate in (0, 1):
                # All-correct or all-incorrect: no learning signal, drop.
                continue
            if self.only_keep_1_pair:
                # Keep at most one correct and one incorrect rollout.
                chosen = [random.choice(correct), random.choice(incorrect)]
            else:
                chosen = group
            for d in chosen:
                d["pass_rate"] = pass_rate
                kept.append(d)
        super().update(kept)
class TrajectoryCollator(SftCollator):
    """SftCollator that additionally carries judger rewards/advantages
    (and ``pass_rate`` when present) through to the batch dict."""

    def __call__(self, instances):
        batch = super().__call__(instances)
        batch["judger_rewards"] = [ins["judger_reward"] for ins in instances]
        batch["judger_advantages"] = [ins["judger_advantage"] for ins in instances]
        if "pass_rate" in instances[0]:
            batch["pass_rate"] = [ins["pass_rate"] for ins in instances]
        return batch

View file

@ -0,0 +1,19 @@
# Copyright (c) InternLM. All rights reserved.
from .base_judger import (
BaseJudger,
register_judger,
registered_judgers,
)
from .math_judger import MathJudger
from .router import InputData, ParallelRouter
from .bootcamp_judger import bootcampJudger
# Public API of the judgers subpackage; keep in sync with the imports above.
__all__ = [
    "register_judger",
    "registered_judgers",
    "BaseJudger",
    "MathJudger",
    "InputData",
    "ParallelRouter",
    "bootcampJudger",
]

View file

@ -0,0 +1,61 @@
# Copyright (c) InternLM. All rights reserved.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import (
Dict,
Generic,
List,
Optional,
Type,
TypedDict,
TypeVar,
Union,
)
# Generic payload type carried by JudgeStatus.handle.
T = TypeVar("T")
# A single chat message, e.g. {"role": "user", "content": "..."}.
MessageItem = TypedDict("MessageItem", {"role": str, "content": str})
# A judger verdict: a scalar score, per-item scores, or None when unavailable.
Reward = Union[float, List[float], None]
# Sample metadata; "data_source" is used to route the sample to a judger.
MetaData = TypedDict("MetaData", {"data_source": str})
@dataclass
class JudgeStatus(Generic[T]):
    """Result of submitting one sample to a judger.

    The ``handle`` is an opaque, judger-specific payload that is passed
    back to ``on_reward_required`` to resolve the final reward.
    """

    # Whether the submission was accepted.
    ok: bool = True
    # Optional human-readable reason, e.g. when ``ok`` is False.
    reason: Optional[str] = None
    # Judger-specific payload consumed later by ``on_reward_required``.
    handle: Optional[T] = None
class BaseJudger(ABC):
    """Abstract two-stage judger interface.

    ``on_data_received`` ingests one sample and returns a ``JudgeStatus``
    whose ``handle`` carries whatever the judger needs later;
    ``on_reward_required`` turns that status into a ``Reward``.
    """

    def __init__(self):
        pass

    @abstractmethod
    def on_data_received(
        self,
        prompt_messages: List[MessageItem],
        completion_messages: List[MessageItem],
        metadata: dict,
    ) -> JudgeStatus:
        """Accept one sample for judging and return a status with a handle."""
        raise NotImplementedError()

    @abstractmethod
    def on_reward_required(
        self,
        status: JudgeStatus,
        timeout: Optional[float] = None,
    ) -> Reward:
        """Resolve a status from ``on_data_received`` into a reward."""
        raise NotImplementedError()
# Global registry mapping judger-type names to judger classes. Populated by
# the @register_judger decorator. The annotation is a string (PEP 563 style)
# so it is not evaluated eagerly at import time.
registered_judgers: "Dict[str, Type[BaseJudger]]" = {}


def register_judger(name: str):
    """Return a class decorator that registers a judger class under ``name``.

    Raises:
        AssertionError: if ``name`` is already registered.
    """

    # No ``global`` needed: the dict is only mutated, never rebound.
    def wrapper(cls):
        assert name not in registered_judgers, f"{name} already in {registered_judgers}"
        registered_judgers[name] = cls
        return cls

    return wrapper

View file

@ -0,0 +1,81 @@
# Copyright (c) InternLM. All rights reserved.
import random
import re
import time
from typing import List, Optional, Tuple
import asyncio
from concurrent.futures import ThreadPoolExecutor
import requests
import internbootcamp
from .base_judger import BaseJudger, JudgeStatus, MessageItem, Reward, register_judger
@register_judger("bootcamp_judger")
class bootcampJudger(BaseJudger):
    """Judger that scores responses with the matching internbootcamp class.

    The ``data_source`` field of each sample's metadata selects the bootcamp
    (e.g. ``"game24"`` -> ``internbootcamp.Game24bootcamp``), whose
    ``verify_score`` produces the reward.
    """

    def __init__(
        self,
        stop_word="<|im_end|>",
        format_score=0,
        format_penalty=True,
        short_penalty=True,
        short_threshold=128,
    ):
        super().__init__()
        self.stop_word = stop_word
        self.format_score = format_score
        self.format_penalty = format_penalty
        self.short_penalty = short_penalty
        self.short_threshold = short_threshold

    def on_data_received(
        self,
        prompt_messages: List[MessageItem],
        completion_messages: List[MessageItem],
        metadata: dict,  # dataset-specific payload; parse whatever you stored
    ) -> JudgeStatus:
        """Stage 1: cache everything needed for scoring in the handle."""
        question = prompt_messages[-1]["content"]
        response = completion_messages[-1]["content"]
        identity = metadata["ground_truth"]
        data_source = metadata["data_source"]
        verify_label = None
        if not response.strip().endswith(self.stop_word):
            # A response that does not end with the stop word is truncated /
            # incomplete; treat it as incorrect without scoring.
            verify_label = False
        return JudgeStatus(
            ok=True,
            handle={
                "data_source": data_source,
                "question": question,
                "response": response,
                "identity": identity,
                "verify_label": verify_label,
            },
        )

    def on_reward_required(  # convert the judger's verdict into a reward score
        self, status: JudgeStatus, timeout: Optional[float] = None
    ) -> Reward:
        if status.handle["verify_label"] is False:
            # Incomplete response detected in stage 1: score 0 outright.
            return 0.0
        data_source = status.handle["data_source"]
        response = status.handle["response"]
        identity = status.handle["identity"]
        # Resolve e.g. data_source "game24" -> internbootcamp.Game24bootcamp.
        bootcamp_cls = getattr(
            internbootcamp, data_source[0].upper() + data_source[1:] + "bootcamp"
        )
        try:
            score = bootcamp_cls.verify_score(
                response,
                identity,
                format_score=self.format_score,
                format_penalty=self.format_penalty,
                short_penalty=self.short_penalty,
                short_threshold=self.short_threshold,
            )
        except Exception:
            # Presumably some bootcamps do not accept the penalty kwargs;
            # retry with the minimal signature. (Narrowed from a bare
            # ``except:`` so KeyboardInterrupt/SystemExit propagate.)
            score = bootcamp_cls.verify_score(
                response, identity, format_score=self.format_score
            )
        return score

View file

@ -0,0 +1,198 @@
# Copyright (c) InternLM. All rights reserved.
import random
import re
import time
from typing import List, Optional, Tuple
import requests
from .base_judger import BaseJudger, JudgeStatus, MessageItem, Reward, register_judger
from .utils import extract_answer, math_equal
@register_judger("math_judger")
class MathJudger(BaseJudger):
    """Judger for math answers: rule-based check first, judge-LLM fallback.

    A response is first verified with ``extract_answer``/``math_equal``;
    only when the rule-based check does not confirm correctness is the
    question sent to an OpenAI-compatible judge service on ``hosts``.
    """

    # Prompt sent to the judge LLM; expects an "A" (correct) / "B"
    # (incorrect) style reply parsed by ``_verify_from_string``.
    verify_prompt = """You are a helpful assistant who evaluates the correctness and quality of models' outputs.
Please as a grading expert, judge whether the final answers given by the candidates below are consistent with the standard answers, that is, whether the candidates answered correctly.
Here are some evaluation criteria:
1. Please refer to the given standard answer. You don't need to re-generate the answer to the question because the standard answer has been given. You only need to judge whether the candidate's answer is consistent with the standard answer according to the form of the question. Don't try to answer the original question. You can assume that the standard answer is definitely correct.
2. Because the candidate's answer may be different from the standard answer in the form of expression, before making a judgment, please understand the question and the standard answer first, and then judge whether the candidate's answer is correct, but be careful not to try to answer the original question.
3. Some answers may contain multiple items, such as multiple-choice questions, multiple-select questions, fill-in-the-blank questions, etc. As long as the answer is the same as the standard answer, it is enough. For multiple-select questions and multiple-blank fill-in-the-blank questions, the candidate needs to answer all the corresponding options or blanks correctly to be considered correct.
4. Some answers may be expressed in different ways, such as some answers may be a mathematical expression, some answers may be a textual description, as long as the meaning expressed is the same. And some formulas are expressed in different ways, but they are equivalent and correct.
5. If the prediction is given with \\boxed{{}}, please ignore the \\boxed{{}} and only judge whether the candidate's answer is consistent with the standard answer.
Please judge whether the following answers are consistent with the standard answer based on the above criteria. Grade the predicted answer of this new question as one of:
A: CORRECT
B: INCORRECT
Just return the letters \"A\" or \"B\", with no text around it.
Here is your task. Simply reply with either CORRECT, INCORRECT. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer.
<Original Question Begin>:
{question}
<Original Question End>
<Gold Target Begin>:
{gold_answer}
<Gold Target End>
<Predicted Answer Begin>:
{answer}
<Predicted End>
Judging the correctness of candidates' answers:"""

    def __init__(
        self,
        hosts: List[str],
        max_retries: int = 1,
        retry_delay: float = 1.0,
        stop_word="<|im_end|>",
        thinking_finish_words=["<conclude>", "**Final Answer**", "</think>"],
    ):
        """Set up the judge-service connection.

        Note: queries one of ``hosts`` for its served model name, so
        construction requires the judge service to be reachable.
        """
        super().__init__()
        self.hosts = hosts
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.stop_word = stop_word
        # Read-only; the shared default list is never mutated.
        self.thinking_finish_words = thinking_finish_words
        # Start the round-robin at a random host to spread load.
        self.host_ip_idx = random.randint(0, len(hosts) - 1)
        self.model_name = requests.get(
            f"http://{self.hosts[self.host_ip_idx]}/v1/models",
            headers={"Authorization": "Bearer "},
        ).json()["data"][0]["id"]

    def on_data_received(
        self,
        prompt_messages: List[MessageItem],
        completion_messages: List[MessageItem],
        metadata: dict,
    ) -> JudgeStatus:
        """Verify a response against the gold answer and stash the verdict."""
        question = prompt_messages[-1]["content"]
        response = completion_messages[-1]["content"]
        question_type = metadata.get("question_type", None)
        gold_answer = metadata["gold_answer"]

        def _status(label):
            # All exits share the same handle layout; ``response`` is read at
            # call time, so later mutations (stripping markers) are included.
            return JudgeStatus(
                ok=True,
                handle={
                    "question": question,
                    "question_type": question_type,
                    "response": response,
                    "gold_answer": gold_answer,
                    "verify_label": label,
                },
            )

        if not response.strip().endswith(self.stop_word):
            # A response not ending with the stop word is truncated /
            # incomplete; treat it as incorrect.
            return _status(False)
        for thinking_finish_word in self.thinking_finish_words:
            if thinking_finish_word in response:
                # Keep only the text after the last "thinking finished" marker.
                response = response.split(thinking_finish_word)[-1]
        response = response.replace(self.stop_word, "")
        # First try the cheap rule-based check; if it confirms correctness,
        # skip the LLM call entirely.
        _, verify_label = self._extract_and_verify_with_logic(response, gold_answer)
        if verify_label is True:
            return _status(verify_label)
        # Otherwise let the judge model decide.
        _, verify_label = self._evaluate_answer_with_llm(
            question, question_type, response, gold_answer
        )
        return _status(verify_label)

    def on_reward_required(
        self, status: JudgeStatus, timeout: Optional[float] = None
    ) -> Reward:
        """Map the stored verdict to +1.0 / -1.0, or None when undecided."""
        if status.handle is None:
            return None
        if status.handle["verify_label"] is not None:
            return 1.0 if status.handle["verify_label"] else -1.0
        return None

    def _evaluate_answer_with_llm(
        self, question: str, question_type: str, answer: str, gold_answer: str
    ) -> Tuple[str, bool]:
        """Ask the judge LLM whether ``answer`` matches ``gold_answer``.

        Returns (raw_reply, label); (None, None) after exhausting retries.
        """
        for _ in range(self.max_retries):
            # Round-robin across the configured judge hosts.
            host = self.hosts[self.host_ip_idx]
            self.host_ip_idx = (self.host_ip_idx + 1) % len(self.hosts)
            # The template has only named fields; the original passed two
            # pointless positional "" arguments, dropped here.
            prompt = self.verify_prompt.format(
                question=question, answer=answer, gold_answer=gold_answer
            )
            try:
                res = requests.post(
                    f"http://{host}/v1/chat/completions",
                    json={
                        "model": self.model_name,
                        "messages": [
                            {
                                "role": "user",
                                "content": prompt,
                            }
                        ],
                        "temperature": 0.0,
                        "top_p": 0.8,
                        "top_k": 20,
                        "repetition_penalty": 1.05,
                        "max_tokens": 100,
                        "stop": ["<|im_end|>", "<|endoftext|>"],
                    },
                )
                res_string = res.json()["choices"][0]["message"]["content"]
                print(f"Evaluate result: {res_string}")
                verify_label = self._verify_from_string(res_string)
                if verify_label is None:
                    raise ValueError(
                        f"Evaluate result is None, judger prediction: {res_string}"
                    )
                return res_string, verify_label
            except Exception as e:
                print(f"Error verifying answer: {e}")
                time.sleep(self.retry_delay)
                continue
        print(f"Failed to verify answer after {self.max_retries} retries.")
        return None, None

    def _verify_from_string(self, verification: str):
        """Parse the judge reply: "A" -> True, "B" -> False, else None.

        NOTE(review): a literal "CORRECT"/"INCORRECT" reply contains neither
        letter and yields None, which counts as a failed judgement upstream.
        """
        if "A" in verification and "B" not in verification:
            label = True
        elif "B" in verification and "A" not in verification:
            label = False
        else:  # judger model failed to predict A or B
            label = None
        return label

    def _extract_and_verify_with_logic(
        self, response: str, gold_answer: str
    ) -> Tuple[str, bool]:
        """Rule-based check: extract the final answer and compare to gold."""
        extracted_answer = extract_answer(response)
        verify_label = math_equal(extracted_answer, gold_answer)
        return extracted_answer, verify_label

View file

@ -0,0 +1,473 @@
# Copyright (c) InternLM. All rights reserved.
import atexit
import functools
import os
import queue
import time
import traceback
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass
from multiprocessing import Event, Process, Queue, connection
from multiprocessing.synchronize import Event as EventClass
from typing import (
Callable,
Dict,
Generic,
List,
Optional,
Tuple,
TypedDict,
TypeVar,
cast,
)
from uuid import uuid4
import loguru
from typing_extensions import NotRequired
from .base_judger import (
JudgeStatus,
MessageItem,
MetaData,
Reward,
registered_judgers,
)
class InputData(TypedDict):
    """One sample submitted to the router: prompt and completion messages,
    plus optional metadata whose "data_source" selects the judger(s)."""

    prompt_messages: List[MessageItem]
    completion_messages: List[MessageItem]
    metadata: NotRequired[MetaData]
# Payload type of a GenericTask (input data, judge handle, or reward).
T = TypeVar("T")


@dataclass
class GenericTask(Generic[T]):
    """Unit of work flowing through the router's queues.

    ``token`` identifies the submit() batch and ``index`` the sample's
    position within it; ``content`` varies by pipeline stage.
    """

    # Batch token (uuid string) returned by submit().
    token: str
    # Position of the sample inside the submitted batch.
    index: int
    # Name of the judger that should process / produced this task.
    judger: str
    # Stage-dependent payload.
    content: T
@dataclass
class SubprocessConfig:
    """Options applied inside each judger worker subprocess."""

    # Extra loguru handler configs re-added (with enqueue=True) in workers.
    loguru_handlers: Optional[List[dict]] = None
    # Optional callable run once at worker startup.
    worker_init_func: Optional[Callable] = None
class ParallelRouter:
def __init__(
self,
judgers_config: Dict[str, dict],
data_judger_mapping: Dict[str, Optional[List[str]]],
logger: Optional["loguru.Logger"] = None,
subprocess_config: Optional[SubprocessConfig] = None,
):
if logger is not None:
self.logger = logger
else:
import mock
self.logger = mock.Mock()
if subprocess_config is not None:
self.subprocess_config = subprocess_config
else:
self.subprocess_config = SubprocessConfig()
if not (
isinstance(judgers_config, dict)
and all(
isinstance(k, str) and isinstance(v, dict)
for k, v in judgers_config.items()
)
):
raise TypeError(
f"Illegal judgers_config: {judgers_config}\n"
"Should be Dict[str, dict]"
)
if "RM" in judgers_config.keys():
raise KeyError(
f"'RM' is a reserved judger keywork for {self.__class__.__name__}, "
f"please remove it from judgers_config: {judgers_config}"
)
self.judgers_config = judgers_config
data_judger_mapping: Dict[str, List[str]] = {
k: v or [] for k, v in data_judger_mapping.items()
} # change None to empty list []
if not (
isinstance(data_judger_mapping, dict)
and all(
isinstance(k, str)
and isinstance(v, (list, tuple, set))
and all(isinstance(vv, str) for vv in v)
for k, v in data_judger_mapping.items()
)
):
raise TypeError(
f"Illegal data_judger_mapping: {data_judger_mapping}\n"
"Should be Dict[str, List[str]]"
)
self.data_judger_mapping = data_judger_mapping
avail_judgers = set(self.judgers_config.keys()) | {"RM"}
_used_judgers: List[str] = []
for v in data_judger_mapping.values():
_used_judgers.extend(v)
used_judgers: set = set(_used_judgers)
if unused := avail_judgers - used_judgers:
self.logger.warning(
"Following judgers are available but not "
f"used in data mapping: {unused}\n"
"Please make sure this is intended"
)
# remove unused configs
for judger_name in unused:
self.judgers_config.pop(judger_name, None)
if missing := used_judgers - avail_judgers:
self.logger.warning(
"Following judgers are configured to be used "
f"but not built in data mapping: {missing}\n"
"Please make sure this is intended"
)
# remove missing judgers from mapping, to prevent potential errors
for source in list(self.data_judger_mapping.keys()):
before = set(self.data_judger_mapping[source])
self.data_judger_mapping[source] = list(before - missing)
# then filter out data_mapping without available judgers
self.data_judger_mapping = {
source: judgers
for source, judgers in self.data_judger_mapping.items()
if len(judgers) > 0
}
# Try build judgers in __init__ so that raise Exceptions earlly
for judger_name, judger_conf in self.judgers_config.items():
_ = self._build_judger(judger_name, judger_conf)
self._processes: List[Process] = []
self._stop_event = Event()
atexit.register(self.shutdown)
self._input_queues: Dict[str, Queue[GenericTask[InputData]]] = {
judger_name: Queue() for judger_name in self.judgers_config.keys()
}
self._output_queue: Queue[GenericTask[Reward]] = Queue()
self._exc_queue: Queue[Tuple[str, Exception]] = Queue()
self._num_tasks: Dict[str, int] = {} # for each token
self._num_indexes: Dict[str, int] = {} # for each token
self._results_buffer: Dict[str, List[GenericTask[Reward]]] = defaultdict(
list
) # results buffer grouped by the key "token"
def submit(self, data_batch: List[InputData]):
indexes_for_ext: List[int] = []
indexes_for_local: List[int] = []
tasks_input: List[GenericTask[InputData]] = []
token = str(uuid4())
for index, data_item in enumerate(data_batch):
if (
not isinstance(data_item, dict)
or "metadata" not in data_item
or "prompt_messages" not in data_item
or "completion_messages" not in data_item
):
indexes_for_local.append(index)
continue
source = data_item["metadata"].get("data_source", None)
if source is None or source not in self.data_judger_mapping:
indexes_for_local.append(index)
continue
indexes_for_ext.append(index)
for judger in self.data_judger_mapping[source]:
if judger == "RM":
indexes_for_local.append(index)
else:
tasks_input.append(
GenericTask(
token=token,
index=index,
judger=judger,
content=data_item,
)
)
self._num_tasks[token] = len(tasks_input)
self._num_indexes[token] = len(data_batch)
for task in tasks_input:
self._input_queues[task.judger].put(task, block=True, timeout=1)
if not self._processes:
self.logger.debug("Starting processes...")
for judger_name, judger_conf in self.judgers_config.items():
num_proc = judger_conf.pop("num_processes", 1)
self._processes.extend(
[
Process(
target=ParallelRouter._safe_process_worker,
kwargs={
"stop_event": self._stop_event,
"judger_name": judger_name,
"judger_conf": judger_conf,
"input_queue": self._input_queues[judger_name],
"output_queue": self._output_queue,
"exc_queue": self._exc_queue,
"config": self.subprocess_config,
},
daemon=True,
)
for _ in range(num_proc)
]
)
for p in self._processes:
p.start()
self.logger.debug(f"Start processes done, total {len(self._processes)}")
return token, indexes_for_local
def query(
self, token: str, timeout: float = 0
) -> Optional[List[Optional[Dict[str, Reward]]]]:
start = time.time()
while True:
self._try_catch_subprocess_exceptions()
try:
result = self._output_queue.get(timeout=0.1)
self._results_buffer[result.token].append(result)
except queue.Empty:
pass
if len(self._results_buffer[token]) == self._num_tasks[token]:
results = self._results_buffer.pop(token)
num_tasks = self._num_tasks.pop(token)
num_indexes = self._num_indexes.pop(token)
rewards: List[Dict[str, Reward]] = [{} for _ in range(num_indexes)]
for result in results:
reward = result.content
if result.judger in rewards[result.index]:
self.logger.warning(
f"{result.judger} already exists: {rewards[result.index]}, "
f"will replace --> {reward}"
)
rewards[result.index][result.judger] = reward
# convert empty dicts to None
return [r or None for r in rewards]
if timeout > 0 and (time.time() - start) > timeout:
raise TimeoutError(
f"Timeout after {timeout} seconds, got {len(self._results_buffer[token])} results, expected {self._num_tasks[token]}"
)
@staticmethod
def _safe_process_worker(
stop_event: EventClass,
judger_name: str,
judger_conf: dict,
input_queue: "Queue[GenericTask[InputData]]",
output_queue: "Queue[GenericTask[Reward]]",
exc_queue: "Queue[Tuple[str, Exception]]",
config: SubprocessConfig,
):
try:
ParallelRouter._process_worker(
stop_event=stop_event,
judger_name=judger_name,
judger_conf=judger_conf,
input_queue=input_queue,
output_queue=output_queue,
exc_queue=exc_queue,
config=config,
)
except Exception as e:
exc_queue.put((judger_name, e), timeout=1)
@staticmethod
def _process_worker(
stop_event: EventClass,
judger_name: str,
judger_conf: dict,
input_queue: "Queue[GenericTask[InputData]]",
output_queue: "Queue[GenericTask[Reward]]",
exc_queue: "Queue[Tuple[str, Exception]]",
config: SubprocessConfig,
):
from xtuner._lite import get_logger
logger = get_logger()
if config.loguru_handlers is not None:
for handler in config.loguru_handlers:
handler["enqueue"] = True
logger.add(*handler)
if config.worker_init_func is not None:
config.worker_init_func()
# Infer num threads for each stage according to configs
_num_threads = judger_conf.pop("concurrency_per_proc", (1, 1))
if isinstance(_num_threads, (tuple, list)) and len(_num_threads) == 2:
num_threads_s1, num_threads_s2 = _num_threads
elif isinstance(_num_threads, int):
num_threads_s1 = max(1, _num_threads // 2)
num_threads_s2 = max(1, _num_threads - num_threads_s1)
else:
raise TypeError(
"`concurrency_per_proc` in judger_conf should be int or "
f"Tuple[int, int], got {type(_num_threads)}: {_num_threads}"
)
# Lazy build judgers in subprocesses to avoid serialization errors
judger = ParallelRouter._build_judger(judger_name, judger_conf)
# input_queue = self._input_queues[judger_name]
# output_queue = self._output_queue
handle_queue: queue.Queue[GenericTask[JudgeStatus]] = queue.Queue()
log_prefix = f"[pid={os.getpid()},{judger_name}]"
def report_exc_wrapper(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
stack_trace = traceback.format_exc()
logger.error(
f"{log_prefix} "
f"Thread worker of {judger_name} raised "
f"{type(e).__name__}: {e}",
f"Stack trace: {stack_trace}",
)
exc_queue.put((judger_name, e), timeout=1)
return wrapper
# Stage 1: input_queue -> judger.on_data_received -> handle_queue
@report_exc_wrapper
def thread_worker_s1():
while not stop_event.is_set():
try:
task = input_queue.get(timeout=0.1)
logger.debug(f"{log_prefix} dequeue input: {task}")
except queue.Empty:
logger.debug(f"{log_prefix} input queue empty")
time.sleep(0.1)
continue
data = task.content
if "metadata" not in data:
raise RuntimeError(
f"'metadata' not in data.keys(): {list(data.keys())}"
)
logger.debug(f"{log_prefix} on_data_received")
handle = judger.on_data_received(
data["prompt_messages"],
data["completion_messages"],
cast(dict, data["metadata"]),
)
logger.debug(f"{log_prefix} got handle")
new_task = GenericTask(
token=task.token,
index=task.index,
judger=task.judger,
content=handle,
)
while True:
try:
handle_queue.put(
new_task,
timeout=0.1,
)
logger.debug(f"{log_prefix} enqueue handle: {new_task}")
break
except queue.Full:
time.sleep(0.1)
# Stage 2: handle_queue -> judger.on_reward_required -> output_queue
@report_exc_wrapper
def thread_worker_s2():
while not stop_event.is_set():
try:
task = handle_queue.get(timeout=0.1)
logger.debug(f"{log_prefix} dequeue handle: {task}")
except queue.Empty:
logger.debug(f"{log_prefix} handle queue empty")
time.sleep(0.1)
continue
logger.debug(f"{log_prefix} on_reward_required")
reward = judger.on_reward_required(task.content)
logger.info(f"{log_prefix} got result")
new_task = GenericTask(
token=task.token,
index=task.index,
judger=task.judger,
content=reward,
)
while True:
try:
output_queue.put(
new_task,
timeout=0.1,
)
logger.debug(f"{log_prefix} enqueue output: {new_task}")
break
except queue.Full:
time.sleep(0.1)
from threading import Thread
threads: List[Thread] = []
for _ in range(num_threads_s1):
threads.append(Thread(target=thread_worker_s1, daemon=True))
for _ in range(num_threads_s2):
threads.append(Thread(target=thread_worker_s2, daemon=True))
for t in threads:
t.start()
for t in threads:
t.join()
@staticmethod
def _build_judger(judger_name: str, judger_conf: dict):
judger_conf = deepcopy(judger_conf)
judger_conf.pop("num_processes", None)
judger_conf.pop("concurrency_per_proc", None)
_type = judger_conf.pop("type", None)
if _type is None:
_type = judger_name
if _type not in registered_judgers:
raise KeyError(
f"{judger_name} use unregistered judger type: {_type}. "
f"Available judgers are: {list(registered_judgers.keys())}"
)
cls = registered_judgers[_type]
return cls(**judger_conf)
def _try_catch_subprocess_exceptions(self):
exc_handles: List[Tuple[str, Exception]] = []
while True:
try:
exc_handle = self._exc_queue.get(timeout=0.001)
exc_handles.append(exc_handle)
except queue.Empty:
break
if exc_handles:
error_message = "\n".join(
[
f"- [{judger_name}] {type(exc).__name__}: {exc}"
for judger_name, exc in exc_handles
]
)
raise RuntimeError(
"Following threads/processes raise exceptions unexpectedly:\n"
f"{error_message}\n"
"Program terminated"
)
def shutdown(self, timeout: float = 2.0):
if not hasattr(self, "_processes") or not self._processes:
return
if not self._stop_event.is_set():
self._stop_event.set()
connection.wait([p.sentinel for p in self._processes], timeout=timeout)
for p in self._processes:
if p.is_alive():
p.kill()
p.join()
self._processes = []

View file

@ -0,0 +1,485 @@
# flake8: noqa
# isort: skip_file
import multiprocessing
import re
from math import isclose
from typing import Optional, Union
from collections import defaultdict, Counter
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
def extract_answer(pred_str: str, execute: bool = False) -> str:
    """Extract the final short answer from a model completion.

    Tries, in order: the last ``boxed{...}`` span, a trailing
    "the answer is ..." phrase, and finally the last number in the text.
    The candidate is then reduced (multiple-choice letters, ``=`` chains,
    multi-line output) and normalized via ``strip_string``.

    NOTE(review): shadowed by a second ``extract_answer`` defined later in
    this file; importers only ever see that later definition. The
    ``execute`` parameter is unused in this version.
    """
    # Case 1: answer wrapped in a boxed/box command.
    # NOTE(review): the patterns are not raw strings, so "\\boxed" is the
    # regex \b + "oxed" (word boundary, not a backslash); in practice the
    # plain "boxed"/"box" alternatives do the matching. Likely unintended
    # escaping -- confirm before changing.
    if re.search("\\boxed|boxed|\\box|box", pred_str):
        answer = re.split("\\boxed|boxed|\\box|box", pred_str)[-1]
        if len(answer) == 0:
            return ""
        elif answer[0] == "{":
            # Collect the brace-balanced group following "boxed{".
            stack = 1
            a = ""
            for c in answer[1:]:
                if c == "{":
                    stack += 1
                    a += c
                elif c == "}":
                    stack -= 1
                    if stack == 0:
                        break
                    a += c
                else:
                    a += c
        else:
            # No braces: take text up to the next math delimiter.
            a = answer.split("$")[0].strip()
    # Case 2: an explicit "the answer is ..." phrase; keep trailing text.
    elif re.search("[Tt]he (final )?answer is:?", pred_str):
        a = re.split("[Tt]he (final )?answer is:?", pred_str)[-1].strip().rstrip(".")
    else:  # use the last number
        pred = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
        if len(pred) >= 1:
            a = pred[-1]
        else:
            a = ""
    # Reduce multiple-choice style answers ("A: ..." or "(A) ...") to the
    # bare option letter.
    choice = re.findall(r"([A-E]):\s*(.*)", a)
    if len(choice) > 0:
        for option, content in choice:
            a = option
    choice = re.findall(r"\(([A-E])\)\s*(.*)", a)
    if len(choice) > 0:
        for option, content in choice:
            a = option
    # Keep only the right-hand side of an equality/approximation chain.
    a = re.split(r"=|\\approx|≈", a)[-1]
    # multiple lines: return the first line that survives normalization.
    answer = ""
    preds = re.split("\n", a)
    for pred in preds:
        # Skip alignment environments and lines that only introduce content.
        if "\\begin{align" in pred or pred.endswith(":"):
            continue
        if pred != "" and pred[0] == ":":
            pred = pred[1:]
        if pred != "" and pred[-1] == ".":
            pred = pred[:-1]
        if pred != "" and pred[-1] == "/":
            pred = pred[:-1]
        pred = strip_string(pred)
        # Drop enumeration prefixes like "1)" or "a)".
        pred = re.sub(r"^[a-zA-Z0-9]+[\)]\s*", "", pred)
        # If normalization left "{}"-separated fragments, keep the first
        # non-empty one.
        for p in pred.split("{}"):
            if p != "":
                pred = p
                break
        # Unwrap a lone choice letter: "{A}" or "(A)" -> "A".
        pred = re.sub(r"^\{([A-Z])\}|\(([A-Z])\)", r"\1\2", pred)
        if pred != "":
            answer = pred
            break
    return answer
def _fix_fracs(string):
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if len(substr) > 0 and substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except Exception:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def _fix_a_slash_b(string):
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
if "sqrt" not in a:
a = int(a)
if "sqrt" not in b:
b = int(b)
assert string == f"{a}/{b}"
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except Exception:
return string
def _fix_sqrt(string):
_string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
return _string
def strip_string(string):
    """Normalize a raw answer string into canonical LaTeX-ish form.

    Applies the MATH-benchmark cleanup rules: strips whitespace, units,
    percent/dollar signs, spacing commands, degrees, quotes, etc., and
    rewrites fractions and square roots into the braced
    ``\\frac{a}{b}`` / ``\\sqrt{x}`` form via the ``_fix_*`` helpers.
    """
    string = str(string).strip()
    # linebreaks
    string = string.replace("\n", "")
    # trailing period(s)
    string = string.rstrip(".")
    # remove inverse spaces
    string = string.replace("\\!", "")
    string = string.replace("\\ ", "")
    # replace \\ with \ (applied twice to collapse longer backslash runs)
    string = string.replace("\\\\", "\\")
    string = string.replace("\\\\", "\\")
    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    # Remove a trailing unit like "\text{ miles}" if anything else is left.
    _string = re.sub(r"\\text{.*?}$", "", string).strip()
    if _string != "" and _string != string:
        string = _string
    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")
    # remove dollar signs
    string = string.replace("\\$", "")
    string = string.replace("$", "")
    string = string.replace("\\text", "")
    string = string.replace("x\\in", "")
    # remove percentage
    string = string.replace("\\%", "")
    string = string.replace(r"\%", "")
    string = string.replace("%", "")
    # " .5" -> " 0.5" and "{.5" -> "{0.5"; a leading "." gets "0" below.
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # cdot
    string = string.replace("\\cdot", "")
    # inf
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    string = string.replace("+\\inity", "\\infty")
    # and  (NOTE(review): also hits "and" inside longer words -- confirm)
    string = string.replace("and", "")
    string = string.replace("\\mathbf", "")
    # use regex to remove \mbox{...}
    string = re.sub(r"\\mbox{.*?}", "", string)
    # quote -- BUGFIX: str.replace returns a new string, it does not mutate
    # in place; the results were previously discarded, so quotes were never
    # actually stripped. Assign the results back.
    string = string.replace("'", "")
    string = string.replace('"', "")
    # i, j: treat a lone imaginary-unit "j" as "i"
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")
    # replace a.000b where b is not a digit or b is the end, with ab
    string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0+$", r"\1", string)
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string
    # drop a short left-hand side like "k = " or "q = "
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]
    string = _fix_sqrt(string)
    string = string.replace(" ", "")
    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}; works with
    # \frac1{72} too (but not \frac{72}1)
    string = _fix_fracs(string)
    # X/Y --> \frac{X}{Y} in simple cases
    string = _fix_a_slash_b(string)
    return string
def last_boxed_only_string(string):
    """Return the last ``\\boxed{...}`` (or ``\\fbox{...}``) span in *string*.

    Returns None when no such command exists or when its braces never
    balance.
    """
    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
        if start < 0:
            return None

    end = None
    depth = 0
    # Scan forward from the command, tracking brace nesting until the group
    # that opened after "boxed"/"fbox" closes again.
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        if char == "}":
            depth -= 1
            if depth == 0:
                end = pos
                break

    return None if end is None else string[start : end + 1]
def extract_answer(pred_str: str, execute: bool = False) -> str:
    """Extract the final short answer from a model completion.

    Resolution order: the last ``boxed{...}`` span, a trailing
    "the answer is ..." phrase, an executable ```python fence (only when
    *execute* is True), and finally the last number in the text. The first
    line of the candidate is then normalized via ``strip_string``.

    NOTE(review): this redefinition shadows the earlier ``extract_answer``
    in this module; importers only ever see this version.
    """
    # BUGFIX: the original pattern "\boxed|boxed" was not a raw string, so
    # "\b" was a literal backspace character and the first alternative
    # could never match -- matching silently relied on the bare "boxed"
    # alternative. The raw-string pattern matches "\boxed" or "boxed"
    # explicitly; both forms select the same split tail, so observable
    # behavior is unchanged.
    if re.search(r"\\boxed|boxed", pred_str):
        answer = re.split(r"\\boxed|boxed", pred_str)[-1]
        if len(answer) == 0:
            return ""
        elif answer[0] == "{":
            # Collect the brace-balanced group following "boxed{".
            stack = 1
            a = ""
            for c in answer[1:]:
                if c == "{":
                    stack += 1
                    a += c
                elif c == "}":
                    stack -= 1
                    if stack == 0:
                        break
                    a += c
                else:
                    a += c
        else:
            # No braces: take text up to the next math delimiter.
            a = answer.split("$")[0].strip()
    elif re.search("[Tt]he (final )?answer is:?", pred_str):
        a = re.split("[Tt]he (final )?answer is:?", pred_str)[-1].strip().rstrip(".")
    elif pred_str.startswith("```python") and execute:
        # fall back to program execution (third-party lagent tool)
        from lagent import get_tool

        a = get_tool("IPythonInteractive").exec(pred_str).value or ""
    else:  # use the last number
        pred = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
        if len(pred) >= 1:
            a = pred[-1]
        else:
            a = ""
    # Keep only the first line of the candidate, trim stray leading ":" and
    # trailing "." or "/", then normalize.
    pred = a.split("\n")[0]
    if pred != "" and pred[0] == ":":
        pred = pred[1:]
    if pred != "" and pred[-1] == ".":
        pred = pred[:-1]
    if pred != "" and pred[-1] == "/":
        pred = pred[:-1]
    pred = strip_string(pred)
    return pred
def is_digit(s):
    """Return True when *s* (commas stripped) parses as a float."""
    try:
        float(str(s).replace(",", ""))
    except ValueError:
        return False
    return True
def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    is_close: bool = True,
    tolerance: float = 1e-4,
    timeout: bool = False,
) -> bool:
    """Return True when *prediction* and *reference* denote the same value.

    Exact match of math if and only if:

    1. numerical equal: both parse as floats (commas ignored) and match
       within *tolerance* (or exactly, when *is_close* is False). With
       *include_percentage*, the reference is also tried at 1/100x and
       100x to tolerate percent/fraction ambiguity.
    2. symbolic equal: after bracket normalization, the two strings are
       equal, element-wise equal as tuple/list literals, or equivalent
       sympy expressions.

    When *timeout* is True, the sympy comparison runs in a subprocess via
    ``call_with_timeout`` so a pathological ``simplify`` cannot hang the
    caller.
    """
    try:  # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = float(str(prediction).replace(",", ""))
            reference = float(str(reference).replace(",", ""))
            # number questions: optionally accept percent-scaled variants
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if isclose(item, prediction, rel_tol=tolerance):
                            return True
                    else:
                        if item == prediction:
                            return True
                except Exception:
                    continue
            # Both sides parsed as numbers but no candidate matched: do not
            # fall through to the symbolic comparison.
            return False
    except Exception:
        pass
    # Empty/None prediction can never match (0 and False are legitimate
    # answers, so they are explicitly exempted from the falsy check).
    if not prediction and prediction not in [0, False]:
        return False
    # 2. symbolic equal: work on stripped string forms from here on.
    reference = str(reference).strip()
    prediction = str(prediction).strip()
    ## deal with [], (), {}: compare with all brackets stripped, but only
    ## strip interval-style brackets when the two sides do not use
    ## conflicting ones ("[" vs "(").
    pred_str, ref_str = prediction, reference
    if (
        prediction.startswith("[")
        and prediction.endswith("]")
        and not reference.startswith("(")
    ) or (
        prediction.startswith("(")
        and prediction.endswith(")")
        and not reference.startswith("[")
    ):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    for s in ["{", "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str == ref_str:
        return True
    ## [a, b] vs. [c, d]: recurse element-wise, return a==c and b==d
    if (
        (prediction.startswith("[") and prediction.endswith("]"))
        and (reference.startswith("[") and reference.endswith("]"))
        or (prediction.startswith("(") and prediction.endswith(")"))
        and (reference.startswith("(") and reference.endswith(")"))
    ):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts):
            if all(
                [
                    math_equal(
                        pred_parts[i], ref_parts[i], include_percentage, is_close
                    )
                    for i in range(len(pred_parts))
                ]
            ):
                return True
    # symbolic equal with sympy (optionally guarded by a subprocess timeout)
    if timeout:
        if call_with_timeout(symbolic_equal_process, prediction, reference):
            return True
    else:
        if symbolic_equal(prediction, reference):
            return True
    return False
def math_equal_process(param):
    """Pool-worker helper: compare the last two entries of *param*.

    Guards against a ``None`` prediction (e.g. a failed answer extraction)
    instead of crashing inside ``math_equal`` -- consistent with the
    redefinition of this function immediately below, which otherwise
    shadows this one with different ``None`` behavior.
    """
    if param[-2] is None:
        return False
    return math_equal(param[-2], param[-1])
def math_equal_process(param):
    """Pool-worker helper: compare the last two entries of *param*.

    A ``None`` prediction (e.g. a failed answer extraction) is never equal.
    """
    pred, ref = param[-2], param[-1]
    return False if pred is None else math_equal(pred, ref)
def symbolic_equal(a, b):
    """True when *a* and *b* are equivalent as sympy expressions.

    Each side is parsed as LaTeX first, then plain sympy syntax, falling
    back to the raw string. Equivalence is exact (``simplify(a - b) == 0``)
    or numeric within a loose relative tolerance.
    """
    def _to_expr(text):
        for parser in (parse_latex, parse_expr):
            try:
                return parser(text)
            except Exception:
                continue
        return text

    lhs = _to_expr(a)
    rhs = _to_expr(b)
    # Exact symbolic equality.
    try:
        if simplify(lhs - rhs) == 0:
            return True
    except Exception:
        pass
    # Numeric fallback.
    try:
        if isclose(N(lhs), N(rhs), rel_tol=1e-3):
            return True
    except Exception:
        pass
    return False
def symbolic_equal_process(a, b, output_queue):
    """Process target: run ``symbolic_equal(a, b)`` and report the boolean
    result via *output_queue* (used by ``call_with_timeout``)."""
    output_queue.put(symbolic_equal(a, b))
def call_with_timeout(func, *args, timeout=1, **kwargs):
    """Run *func* in a subprocess and return its queued result, or False.

    *func* must accept a trailing ``output_queue`` argument and put exactly
    one result on it (see ``symbolic_equal_process``). Returns False when
    the subprocess exceeds *timeout* seconds (it is then terminated) or
    exits without producing a result (e.g. it raised).
    """
    import queue as _queue  # local import: only needed for the Empty sentinel

    output_queue = multiprocessing.Queue()
    process_args = args + (output_queue,)
    process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
    process.start()
    process.join(timeout)
    if process.is_alive():
        process.terminate()
        process.join()
        return False
    try:
        # BUGFIX: the original called output_queue.get() unconditionally,
        # which blocks forever when the worker died without putting a
        # result. A bounded get() degrades that to a False return.
        return output_queue.get(timeout=1)
    except _queue.Empty:
        return False
def math_majority_vote(answers: list, majority: Optional[int] = None):
    """Majority vote over *answers*, merging math-equivalent entries.

    Non-string and blank answers are ignored. Each remaining answer is
    bucketed with the first previously-seen answer it is ``math_equal`` to,
    otherwise it starts a new bucket. Returns ``(winner, indices)`` for the
    most common bucket when its count reaches *majority* (default 1), else
    ``(None, [])``.
    """
    votes = Counter()
    indices = defaultdict(list)
    for position, candidate in enumerate(answers):
        if not (isinstance(candidate, str) and candidate.strip()):
            continue  # skip empty / non-string candidates
        matched_key = None
        for seen in votes:
            if math_equal(candidate, seen):
                matched_key = seen
                break
        bucket = candidate if matched_key is None else matched_key
        votes[bucket] += 1
        indices[bucket].append(position)
    if votes:
        winner, count = votes.most_common(1)[0]
        if winner and count >= (majority or 1):
            return winner, indices[winner]
    return None, []

View file

@ -0,0 +1,53 @@
import importlib.util
import os
import types
class ConfigDict(dict):
    """A dict whose keys are also readable and writable as attributes."""

    def __getattr__(self, name):
        # Invoked only when regular attribute lookup fails; defer to the
        # mapping itself.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(
                f"'ConfigDict' object has no attribute '{name}'"
            ) from None

    def __setattr__(self, name, value):
        # Attribute assignment writes straight into the mapping.
        self[name] = value
class Config:
@staticmethod
def fromfile(file_path):
config_dict = ConfigDict()
if not os.path.isfile(file_path):
raise FileNotFoundError(f"Config file not found: {file_path}")
# Load the configuration file as a module
spec = importlib.util.spec_from_file_location("config_module", file_path)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
# Function to convert nested dictionaries to ConfigDict recursively
def convert_to_config_dict(d):
if isinstance(d, dict):
config_dict = ConfigDict()
for key, value in d.items():
if isinstance(value, dict):
config_dict[key] = convert_to_config_dict(value)
else:
config_dict[key] = value
return config_dict
else:
return d
# Retrieve all attributes (variables) from the module
for attribute_name in dir(config_module):
if not attribute_name.startswith("__"):
config_dict[attribute_name] = convert_to_config_dict(
getattr(config_module, attribute_name)
)
for key, value in list(config_dict.items()):
if isinstance(value, (types.FunctionType, types.ModuleType)):
config_dict.pop(key)
return config_dict