Mirror of https://github.com/InternLM/InternBootcamp.git (synced 2026-04-28 17:29:37 +00:00)

init-commit (commit 18a552597a)
3461 changed files with 1,150,579 additions and 0 deletions
201  examples/xpuyu_usage/LICENSE  Executable file

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
91  examples/xpuyu_usage/README.md  Executable file

@@ -0,0 +1,91 @@
# bootcamp Training with XTuner

## 🚄 Training Tutorial

### 1. Install Dependencies

We use [XTuner](https://github.com/InternLM/xtuner/tree/main) as the training engine.

Make sure that InternBootcamp is successfully installed:

```bash
pip install -e $InternBootcamp_path
```

Then install xtuner and its dependencies:

```bash
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
pip install flash-attn --no-build-isolation
pip install xtuner[all]==0.2.0rc0
```

### 2. Prepare Data

The bootcamp data can be converted into the training format with examples/xpuyu_usage/xpuyu_data_preprocess.py.

**Example usage:**

```bash
python examples/xpuyu_usage/xpuyu_preprocess.py --src examples/bootcamp_generator_outputs/{%Y-%m-%d-%H:%M:%S}
```

### 3. Prepare Your Training Config

Prepare your training config for starting GRPO training. An example config is provided at

```
examples/xpuyu_usage/bootcamp_rl/configs/example_training_config.py
```

### 4. Start Training

```bash
cd examples/xpuyu_usage

GPUS_PER_NODE=$(python -c 'import torch; print(torch.cuda.device_count())')

# Number of GPU workers; for single-worker training, set to 1
NNODES=${WORLD_SIZE:-1}  # modify to fit your cluster

# The rank of this worker, in {0, ..., WORKER_CNT-1}; for single-worker training, set to 0
NODE_RANK=${RANK:-0}  # modify to fit your cluster

# The IP address of the rank-0 worker; for single-worker training, set to localhost
MASTER_ADDR=${MASTER_ADDR:-localhost}

# The port for communication
MASTER_PORT=${MASTER_PORT:-6001}

DISTRIBUTED_ARGS="
    --nproc_per_node $GPUS_PER_NODE \
    --nnodes $NNODES \
    --node_rank $NODE_RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT
"

echo $DISTRIBUTED_ARGS

torchrun $DISTRIBUTED_ARGS train_grpo.py ./bootcamp_rl/configs/example_training_config.py --work_dir examples/xpuyu_usage/ckpts/experiment_name
```

### 5. Training Curve Visualization

You can use examples/xpuyu_usage/report_to_wandb.py to visualize your training curves:

```bash
python examples/xpuyu_usage/report_to_wandb.py examples/xpuyu_usage/ckpts/{experiment_name}/{timestamp}/rank0.log.jsonl {wandb_project_name}
```
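Note on the data format for Step 2: the prompt loader added below in bootcamp_rl/datasets/prompt.py accepts either a Hugging Face dataset name or a .jsonl file. A minimal sketch of a compatible JSONL line, assuming the field names read by load_jsonl_datasets (the values and file name are made up for illustration):

```python
import json

# Hypothetical sample in the shape load_jsonl_datasets expects when no
# "message_data" field is present: question, gold_answer and pass_rate.
sample = {
    "question": "What is 2 + 3?",
    "gold_answer": "5",
    "pass_rate": 0.5,  # used by the optional difficulty-balancing config
}

with open("bootcamp_train.jsonl", "w", encoding="utf8") as f:
    f.write(json.dumps(sample, ensure_ascii=False) + "\n")

# load_jsonl_datasets also understands a ::ratio suffix, e.g.
# "bootcamp_train.jsonl::0.5", to subsample a random 50% of the lines.
```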
18  examples/xpuyu_usage/bootcamp_rl/datasets/__init__.py  Executable file

@@ -0,0 +1,18 @@
# Copyright (c) InternLM. All rights reserved.
from .prompt import bootcampPromptDataset, PromptCollator, InfiniteDataLoaderIter
from .trajectory import (
    InferDataset,
    TrajectoryCollator,
    TrajectoryDataset,
    TrajectoryDatasetWithFilter,
)

__all__ = [
    "bootcampPromptDataset",
    "PromptCollator",
    "InferDataset",
    "TrajectoryDataset",
    "TrajectoryDatasetWithFilter",
    "TrajectoryCollator",
    "InfiniteDataLoaderIter",
]
214  examples/xpuyu_usage/bootcamp_rl/datasets/prompt.py  Executable file

@@ -0,0 +1,214 @@
# Copyright (c) InternLM. All rights reserved.
|
||||
import json
|
||||
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from torch.utils.data import Dataset
|
||||
from xtuner._lite import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def load_hf_datasets(repo, split="train"):
|
||||
dataset = load_dataset(repo, split=split)
|
||||
converted_ds = []
|
||||
for sample in dataset:
|
||||
converted_ds.append(
|
||||
{
|
||||
"pass_rate": sample["pass_rate"],
|
||||
"message_data": [{"role": "user", "content": sample["question"]}],
|
||||
"metadata": {
|
||||
"data_source": "math", # for the router to know which judger to use
|
||||
"gold_answer": sample["gold_answer"],
|
||||
},
|
||||
}
|
||||
)
|
||||
logger.info(f"Loaded {len(converted_ds)} samples from {repo}")
|
||||
return converted_ds
|
||||
|
||||
|
||||
def load_jsonl_datasets(file_path):
|
||||
subsample_ratio = 1.0
|
||||
if "::" in file_path:
|
||||
file_path, subsample_ratio = file_path.split("::")
|
||||
subsample_ratio = float(subsample_ratio)
|
||||
with open(file_path, "r") as f:
|
||||
lines = f.readlines()
|
||||
datasets = []
|
||||
for line in lines:
|
||||
sample = json.loads(line)
|
||||
if "message_data" not in sample:
|
||||
datasets.append(
|
||||
{
|
||||
"pass_rate": sample["pass_rate"],
|
||||
"message_data": [{"role": "user", "content": sample["question"]}],
|
||||
"metadata": {
|
||||
"data_source": "math", # for the router to know which judger to use
|
||||
"gold_answer": sample["gold_answer"],
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
datasets.append(sample)
|
||||
if subsample_ratio < 1.0:
|
||||
np.random.seed(0)
|
||||
datasets = np.random.choice(
|
||||
datasets, int(len(datasets) * subsample_ratio), replace=False
|
||||
).tolist()
|
||||
|
||||
logger.info(f"Loaded {len(datasets)} samples from {file_path}")
|
||||
return datasets
|
||||
|
||||
|
||||
def balance_difficulty_with_cfg(dataset, difficulty_balance_cfg):
|
||||
balanced_dataset = []
|
||||
for sample in dataset:
|
||||
pass_rate = sample["pass_rate"]
|
||||
for (low, high), repeat in difficulty_balance_cfg:
|
||||
if low <= pass_rate < high:
|
||||
balanced_dataset.extend([sample] * repeat)
|
||||
break
|
||||
logger.info(
|
||||
f"After difficulty balancing, the dataset size is {len(balanced_dataset)}"
|
||||
)
|
||||
return balanced_dataset
|
||||
|
||||
|
||||
class bootcampPromptDataset(Dataset):
|
||||
def __init__(self, path, tokenizer, difficulty_balance_cfg=None):
|
||||
if isinstance(path, str):
|
||||
path = [path]
|
||||
dataset = []
|
||||
for p in path:
|
||||
if p.endswith(".jsonl"):
|
||||
dataset.extend(load_jsonl_datasets(p))
|
||||
else:
|
||||
dataset.extend(load_hf_datasets(p))
|
||||
if difficulty_balance_cfg:
|
||||
dataset = balance_difficulty_with_cfg(dataset, difficulty_balance_cfg)
|
||||
self.dataset = dataset
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
sample = self.dataset[idx]
|
||||
input_ids = self.tokenizer.apply_chat_template(
|
||||
sample["message_data"], add_generation_prompt=True
|
||||
)
|
||||
sample["input_ids"] = input_ids
|
||||
sample["labels"] = input_ids
|
||||
sample["num_tokens"] = len(input_ids)
|
||||
return sample
|
||||
|
||||
|
||||
class PromptCollator:
|
||||
|
||||
def __init__(self, pad_token_id=0, ignore_id=-100, pack_batch=False):
|
||||
self.pack_batch = pack_batch
|
||||
self.pad_token_id = pad_token_id
|
||||
self.ignore_id = ignore_id
|
||||
|
||||
def __call__(self, instances):
|
||||
|
||||
_instances = []
|
||||
for ins in instances:
|
||||
if isinstance(ins, list):
|
||||
_instances.extend(ins)
|
||||
else:
|
||||
_instances.append(ins)
|
||||
|
||||
instances = _instances
|
||||
|
||||
input_ids = []
|
||||
labels = []
|
||||
num_tokens = []
|
||||
metadatas = []
|
||||
message_datas = []
|
||||
|
||||
for data in instances:
|
||||
|
||||
input_ids.append(torch.LongTensor(data["input_ids"]))
|
||||
labels.append(torch.LongTensor(data["labels"]))
|
||||
metadatas.append(data["metadata"])
|
||||
message_datas.append(data["message_data"])
|
||||
|
||||
if isinstance(data["num_tokens"], int):
|
||||
num_tokens.append(data["num_tokens"])
|
||||
else:
|
||||
num_tokens.extend(data["num_tokens"])
|
||||
|
||||
attention_mask = [torch.ones_like(ids) for ids in input_ids]
|
||||
num_tokens = torch.IntTensor(num_tokens)
|
||||
|
||||
if len(instances) > 1 and self.pack_batch:
|
||||
|
||||
input_ids = torch.cat(input_ids, dim=0).unsqueeze(0)
|
||||
labels = torch.cat(labels, dim=0).unsqueeze(0)
|
||||
attention_mask = torch.cat(attention_mask, dim=0).unsqueeze(0)
|
||||
|
||||
elif len(instances) > 1 and not self.pack_batch:
|
||||
|
||||
input_ids = pad_sequence(
|
||||
input_ids, batch_first=True, padding_value=self.pad_token_id
|
||||
)
|
||||
labels = pad_sequence(
|
||||
labels, batch_first=True, padding_value=self.ignore_id
|
||||
)
|
||||
attention_mask = pad_sequence(
|
||||
attention_mask, batch_first=True, padding_value=0
|
||||
)
|
||||
else:
|
||||
input_ids = torch.stack(input_ids)
|
||||
labels = torch.stack(labels)
|
||||
attention_mask = torch.stack(attention_mask)
|
||||
|
||||
if input_ids.shape != labels.shape:
|
||||
logger.error(f"[instances] {instances}")
|
||||
logger.error(f"[num_tokens] {num_tokens}")
|
||||
logger.error(f"[input_ids] {input_ids}")
|
||||
logger.error(f"[labels] {labels}")
|
||||
raise RuntimeError(
|
||||
"The shape of input_ids and labels must be "
|
||||
f"equal, but found {input_ids.shape} and "
|
||||
f"{labels.shape}."
|
||||
)
|
||||
data_dict = {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"num_tokens": num_tokens,
|
||||
"attention_mask": attention_mask.bool(),
|
||||
"metadata": metadatas,
|
||||
"message_data": message_datas,
|
||||
}
|
||||
|
||||
return data_dict
|
||||
|
||||
class InfiniteDataLoaderIter:
|
||||
def __init__(self, dataloader):
|
||||
self.dataloader = dataloader
|
||||
self.iterator = iter(dataloader)
|
||||
self._epoch = 0
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
data = next(self.iterator)
|
||||
except StopIteration:
|
||||
logger.info(f"Dataloader epoch {self._epoch} finished. Start a new epoch.")
|
||||
self._epoch += 1
|
||||
if hasattr(self.dataloader, 'sampler') and hasattr(
|
||||
self.dataloader.sampler, 'set_epoch'):
|
||||
# In case the `_SingleProcessDataLoaderIter` has no sampler,
# or the dataloader uses `SequentialSampler` in PyTorch.
|
||||
self.dataloader.sampler.set_epoch(self._epoch)
|
||||
time.sleep(2) # Prevent possible deadlock during epoch transition
|
||||
self.iterator = iter(self.dataloader)
|
||||
data = next(self.iterator)
|
||||
return data
|
||||
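A minimal usage sketch for the classes above, assuming examples/xpuyu_usage is on PYTHONPATH and a chat-template tokenizer is available; the model id, file path, and balance config are placeholders:

```python
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from bootcamp_rl.datasets import (
    InfiniteDataLoaderIter,
    PromptCollator,
    bootcampPromptDataset,
)

# Placeholder tokenizer; any tokenizer exposing apply_chat_template works.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

dataset = bootcampPromptDataset(
    path="bootcamp_train.jsonl",  # hypothetical path; an HF repo id also works
    tokenizer=tokenizer,
    # repeat samples whose pass_rate falls in [low, high) `repeat` times;
    # samples matching no range are dropped
    difficulty_balance_cfg=[((0.0, 0.3), 2), ((0.3, 0.8), 1)],
)
collator = PromptCollator(pad_token_id=tokenizer.pad_token_id, pack_batch=False)
loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collator)

# Wrap the loader so iteration never stops at epoch boundaries.
prompt_iter = InfiniteDataLoaderIter(loader)
batch = next(prompt_iter)
print(batch["input_ids"].shape, batch["metadata"][0])
```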
166  examples/xpuyu_usage/bootcamp_rl/datasets/trajectory.py  Executable file

@@ -0,0 +1,166 @@
# Copyright (c) InternLM. All rights reserved.
|
||||
import json
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from xtuner._lite import get_logger
|
||||
from xtuner._lite.algorithms.sft.dataset import SftCollator
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class InferDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(self, prompts_input_ids, responses_ids, message_data, metadata):
|
||||
super().__init__()
|
||||
|
||||
assert (
|
||||
len(prompts_input_ids)
|
||||
== len(responses_ids)
|
||||
== len(message_data)
|
||||
== len(metadata)
|
||||
), f"The length of prompts_input_ids, responses_ids, message_data, metadata should be the same, but got {len(prompts_input_ids)}, {len(responses_ids)}, {len(message_data)}, {len(metadata)}"
|
||||
self.prompts_input_ids = prompts_input_ids
|
||||
self.responses_ids = responses_ids
|
||||
self.message_data = message_data
|
||||
self.metadata = metadata
|
||||
|
||||
def __len__(self):
|
||||
return len(self.prompts_input_ids)
|
||||
|
||||
def __getitem__(self, item):
|
||||
|
||||
prompt_input_ids = self.prompts_input_ids[item]
|
||||
response_ids = self.responses_ids[item]
|
||||
num_prefill_tokens = len(prompt_input_ids)
|
||||
|
||||
input_ids = prompt_input_ids + response_ids
|
||||
labels = [-100] * (num_prefill_tokens - 1) + response_ids + [-100]
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"num_tokens": len(input_ids),
|
||||
"message_data": self.message_data[item],
|
||||
"metadata": self.metadata[item],
|
||||
}
|
||||
|
||||
|
||||
class TrajectoryDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._num_action_tokens = 0
|
||||
self._num_total_tokens = 0
|
||||
self._trajectories = []
|
||||
|
||||
@property
|
||||
def num_action_tokens(self):
|
||||
return self._num_action_tokens.item()
|
||||
|
||||
@property
|
||||
def num_total_tokens(self):
|
||||
return self._num_total_tokens
|
||||
|
||||
def update(self, trajectories):
|
||||
num_total_tokens = 0
|
||||
num_action_tokens = 0
|
||||
for data in trajectories:
|
||||
labels = np.array(data["labels"])
|
||||
num_total_tokens += labels.size
|
||||
num_action_tokens += (labels >= 0).sum()
|
||||
|
||||
self._num_action_tokens = num_action_tokens
|
||||
self._num_total_tokens = num_total_tokens
|
||||
|
||||
self._trajectories = trajectories
|
||||
|
||||
def dump_jsonl(self, path, tokenizer, debug=False):
|
||||
|
||||
with open(path, "w", encoding="utf8") as f:
|
||||
for data in self._trajectories:
|
||||
json_line = {
|
||||
"sequence": (
|
||||
data["sequence_text"]
|
||||
if "sequence_text" in data
|
||||
else tokenizer.decode(data["input_ids"])
|
||||
),
|
||||
"num_tokens": data["num_tokens"],
|
||||
}
|
||||
json_line["judger_reward"] = data["judger_reward"]
|
||||
json_line["judger_advantage"] = data["judger_advantage"]
|
||||
|
||||
if debug:
|
||||
json_line["input_ids"] = data["input_ids"]
|
||||
json_line["labels"] = data["labels"]
|
||||
|
||||
json_str = json.dumps(json_line, ensure_ascii=False)
|
||||
f.write(json_str + "\n")
|
||||
|
||||
def dump_log(self, path, tokenizer, debug=False):
|
||||
|
||||
with open(path, "w", encoding="utf8") as f:
|
||||
for data in self._trajectories:
|
||||
log_string = f"[sequence]:\n{data['sequence_text'] if 'sequence_text' in data else tokenizer.decode(data['input_ids'])}\n\n"
|
||||
log_string += f"[num_tokens]: {data['num_tokens']}\n"
|
||||
log_string += f"[judger_reward]: {data['judger_reward']}\n"
|
||||
log_string += f"[judger_advantage]: {data['judger_advantage']}\n"
|
||||
f.write(log_string + "\n\n=======================\n")
|
||||
|
||||
def __len__(self):
|
||||
return len(self._trajectories)
|
||||
|
||||
def __getitem__(self, item):
|
||||
|
||||
return self._trajectories[item]
|
||||
|
||||
|
||||
class TrajectoryDatasetWithFilter(TrajectoryDataset):
|
||||
def __init__(self, repeat_k=1, only_keep_1_pair=True):
|
||||
super().__init__()
|
||||
self.repeat_k = repeat_k
|
||||
self.only_keep_1_pair = only_keep_1_pair
|
||||
|
||||
def update(self, trajectories):
|
||||
# split trajectories into groups of size repeat_k: (a, a, b, b, c, c) -> [(a, a), (b, b), (c, c)]
|
||||
groups = [
|
||||
trajectories[i : i + self.repeat_k]
|
||||
for i in range(0, len(trajectories), self.repeat_k)
|
||||
]
|
||||
keeped_trajectories = []
|
||||
for group in groups:
|
||||
correctness = [1 if data["judger_reward"] == 1 else 0 for data in group]
|
||||
correct = [data for data in group if data["judger_reward"] == 1]
|
||||
incorrect = [data for data in group if data["judger_reward"] != 1]
|
||||
pass_rate = sum(correctness) / len(correctness)
|
||||
if self.only_keep_1_pair:
|
||||
if pass_rate == 1 or pass_rate == 0:
|
||||
continue
|
||||
# keep at most 1 correct and 1 incorrect
|
||||
correct = random.choice(correct)
|
||||
incorrect = random.choice(incorrect)
|
||||
correct["pass_rate"] = pass_rate
|
||||
incorrect["pass_rate"] = pass_rate
|
||||
keeped_trajectories.append(correct)
|
||||
keeped_trajectories.append(incorrect)
|
||||
else:
|
||||
if pass_rate == 1 or pass_rate == 0:
|
||||
continue
|
||||
for data in group:
|
||||
data["pass_rate"] = pass_rate
|
||||
keeped_trajectories.append(data)
|
||||
|
||||
super().update(keeped_trajectories)
|
||||
|
||||
|
||||
class TrajectoryCollator(SftCollator):
|
||||
|
||||
def __call__(self, instances):
|
||||
|
||||
data = super().__call__(instances)
|
||||
data["judger_rewards"] = [item["judger_reward"] for item in instances]
|
||||
data["judger_advantages"] = [item["judger_advantage"] for item in instances]
|
||||
if "pass_rate" in instances[0]:
|
||||
data["pass_rate"] = [item["pass_rate"] for item in instances]
|
||||
return data
|
||||
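A small, self-contained sketch of the filtering behaviour of TrajectoryDatasetWithFilter; the trajectory dicts are hand-written stubs carrying only the fields the filter and the base class actually read:

```python
from bootcamp_rl.datasets import TrajectoryDatasetWithFilter

# Four rollouts of one prompt (repeat_k=4): two solved it, two did not.
# Real trajectories also carry input_ids / num_tokens / sequence_text.
trajectories = [
    {"judger_reward": 1, "judger_advantage": 0.5, "labels": [-100, 5, 6]},
    {"judger_reward": 0, "judger_advantage": -0.5, "labels": [-100, 7, 8]},
    {"judger_reward": 1, "judger_advantage": 0.5, "labels": [-100, 5, 6]},
    {"judger_reward": 0, "judger_advantage": -0.5, "labels": [-100, 7, 8]},
]

ds = TrajectoryDatasetWithFilter(repeat_k=4, only_keep_1_pair=True)
ds.update(trajectories)

# Groups with pass rate 0 or 1 are dropped entirely; otherwise one correct and
# one incorrect rollout are kept, each annotated with the group pass_rate.
print(len(ds), ds[0]["pass_rate"])  # expected: 2 0.5
```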
19  examples/xpuyu_usage/bootcamp_rl/judgers/__init__.py  Executable file

@@ -0,0 +1,19 @@
# Copyright (c) InternLM. All rights reserved.
from .base_judger import (
    BaseJudger,
    register_judger,
    registered_judgers,
)
from .math_judger import MathJudger
from .router import InputData, ParallelRouter
from .bootcamp_judger import bootcampJudger

__all__ = [
    "register_judger",
    "registered_judgers",
    "BaseJudger",
    "MathJudger",
    "InputData",
    "ParallelRouter",
    "bootcampJudger",
]
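Since the decorator-based registry is populated when these modules are imported, a quick way to sanity-check which judgers are available (assuming examples/xpuyu_usage is on PYTHONPATH):

```python
from bootcamp_rl.judgers import registered_judgers

# The decorators in math_judger.py and bootcamp_judger.py fill this mapping at
# import time; expect at least "math_judger" and "bootcamp_judger" here.
print(sorted(registered_judgers.keys()))
```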
61  examples/xpuyu_usage/bootcamp_rl/judgers/base_judger.py  Executable file

@@ -0,0 +1,61 @@
# Copyright (c) InternLM. All rights reserved.
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
MessageItem = TypedDict("MessageItem", {"role": str, "content": str})
|
||||
Reward = Union[float, List[float], None]
|
||||
MetaData = TypedDict("MetaData", {"data_source": str})
|
||||
|
||||
|
||||
@dataclass
|
||||
class JudgeStatus(Generic[T]):
|
||||
ok: bool = True
|
||||
reason: Optional[str] = None
|
||||
handle: Optional[T] = None
|
||||
|
||||
|
||||
class BaseJudger(ABC):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_data_received(
|
||||
self,
|
||||
prompt_messages: List[MessageItem],
|
||||
completion_messages: List[MessageItem],
|
||||
metadata: dict,
|
||||
) -> JudgeStatus:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def on_reward_required(
|
||||
self,
|
||||
status: JudgeStatus,
|
||||
timeout: Optional[float] = None,
|
||||
) -> Reward:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
registered_judgers: Dict[str, Type[BaseJudger]] = {}
|
||||
|
||||
|
||||
def register_judger(name: str):
|
||||
global registered_judgers
|
||||
|
||||
def wrapper(cls):
|
||||
assert name not in registered_judgers, f"{name} already in {registered_judgers}"
|
||||
registered_judgers[name] = cls
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
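To illustrate the two-phase interface, here is a hedged sketch of a custom judger built on BaseJudger; the judger name, class name, and matching logic are made up, while the registration mechanism and method signatures come from the file above:

```python
from typing import List, Optional

from bootcamp_rl.judgers import BaseJudger, register_judger
from bootcamp_rl.judgers.base_judger import JudgeStatus, MessageItem, Reward


@register_judger("exact_match_judger")  # hypothetical judger name
class ExactMatchJudger(BaseJudger):
    """Rewards 1.0 when the completion equals metadata['gold_answer'] verbatim."""

    def on_data_received(
        self,
        prompt_messages: List[MessageItem],
        completion_messages: List[MessageItem],
        metadata: dict,
    ) -> JudgeStatus:
        # Phase 1: cheap bookkeeping; heavy work can be deferred to phase 2.
        response = completion_messages[-1]["content"].strip()
        gold = str(metadata.get("gold_answer", "")).strip()
        return JudgeStatus(ok=True, handle={"correct": response == gold})

    def on_reward_required(
        self, status: JudgeStatus, timeout: Optional[float] = None
    ) -> Reward:
        # Phase 2: turn the stored verdict into a scalar reward.
        return 1.0 if status.handle["correct"] else 0.0
```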
81  examples/xpuyu_usage/bootcamp_rl/judgers/bootcamp_judger.py  Executable file

@@ -0,0 +1,81 @@
# Copyright (c) InternLM. All rights reserved.
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from typing import List, Optional, Tuple
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import requests
|
||||
import internbootcamp
|
||||
|
||||
from .base_judger import BaseJudger, JudgeStatus, MessageItem, Reward, register_judger
|
||||
|
||||
|
||||
|
||||
@register_judger("bootcamp_judger")
|
||||
class bootcampJudger(BaseJudger):
|
||||
def __init__(
|
||||
self,
|
||||
stop_word="<|im_end|>",
|
||||
format_score=0,
|
||||
format_penalty=True,
|
||||
short_penalty=True,
|
||||
short_threshold=128,
|
||||
|
||||
):
|
||||
super().__init__()
|
||||
self.stop_word = stop_word
|
||||
self.format_score = format_score
|
||||
self.format_penalty = format_penalty
|
||||
self.short_penalty = short_penalty
|
||||
self.short_threshold = short_threshold
|
||||
|
||||
def on_data_received(
|
||||
self,
|
||||
prompt_messages: List[MessageItem],
|
||||
completion_messages: List[MessageItem],
|
||||
metadata: dict,  # stored in the dataset's metadata field; put whatever you need there and parse it yourself
|
||||
) -> JudgeStatus:
|
||||
question = prompt_messages[-1]["content"]
|
||||
response = completion_messages[-1]["content"]
|
||||
identity = metadata["ground_truth"]
|
||||
data_source = metadata["data_source"]
|
||||
verify_label = None
|
||||
if not response.strip().endswith(self.stop_word):
|
||||
# If the response does not end with the stop word, it is not a complete response, treat as incorrect
|
||||
verify_label = False
|
||||
return JudgeStatus(
|
||||
ok=True,
|
||||
handle={
|
||||
"data_source": data_source,
|
||||
"question": question,
|
||||
"response": response,
|
||||
"identity": identity,
|
||||
"verify_label": verify_label,
|
||||
},
|
||||
)
|
||||
|
||||
def on_reward_required(  # convert the judger's verdict into a reward score
|
||||
self, status: JudgeStatus, timeout: Optional[float] = None
|
||||
) -> Reward:
|
||||
if status.handle["verify_label"] is False:
|
||||
score = 0.0
|
||||
return score
|
||||
# convert the judger's verdict into a reward score
|
||||
data_source = status.handle["data_source"]
|
||||
response = status.handle["response"]
|
||||
identity = status.handle["identity"]
|
||||
prompt = status.handle["question"]
|
||||
bootcamp_cls = getattr(internbootcamp, data_source[0].upper() + data_source[1:] + "bootcamp")
try:
    score = bootcamp_cls.verify_score(
        response,
        identity,
        format_score=self.format_score,
        format_penalty=self.format_penalty,
        short_penalty=self.short_penalty,
        short_threshold=self.short_threshold,
    )
except Exception:
    # Some bootcamps do not accept the penalty-related kwargs; fall back
    # to the minimal call signature.
    score = bootcamp_cls.verify_score(response, identity, format_score=self.format_score)
return score
|
||||
# print(f"[Debug] Prompt: {prompt}")
|
||||
# print(f"[Debug]: score: {score}, response: {response}")
|
||||
# if type(score) == int:
|
||||
# assert score >= 0 and score <= 1
|
||||
# return score
|
||||
# return 0
|
||||
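When a completion does end with the stop word, the judger resolves `internbootcamp.<Data_source>bootcamp` from metadata["data_source"] and scores it with that class's verify_score. A minimal sketch of the short-circuit path, which needs no real bootcamp at all; the message contents and metadata values are placeholders, but both metadata keys must be present:

```python
from bootcamp_rl.judgers import bootcampJudger

judger = bootcampJudger(stop_word="<|im_end|>")

# A truncated completion (missing the stop word) is labelled incorrect
# immediately, without calling into internbootcamp.
status = judger.on_data_received(
    prompt_messages=[{"role": "user", "content": "Solve the puzzle."}],
    completion_messages=[{"role": "assistant", "content": "I think the answer is"}],
    metadata={"data_source": "game24", "ground_truth": None},  # placeholders
)
print(judger.on_reward_required(status))  # -> 0.0
```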
198  examples/xpuyu_usage/bootcamp_rl/judgers/math_judger.py  Executable file

@@ -0,0 +1,198 @@
# Copyright (c) InternLM. All rights reserved.
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
from .base_judger import BaseJudger, JudgeStatus, MessageItem, Reward, register_judger
|
||||
from .utils import extract_answer, math_equal
|
||||
|
||||
|
||||
@register_judger("math_judger")
|
||||
class MathJudger(BaseJudger):
|
||||
verify_prompt = """You are a helpful assistant who evaluates the correctness and quality of models' outputs.
|
||||
Please as a grading expert, judge whether the final answers given by the candidates below are consistent with the standard answers, that is, whether the candidates answered correctly.
|
||||
|
||||
Here are some evaluation criteria:
|
||||
1. Please refer to the given standard answer. You don't need to re-generate the answer to the question because the standard answer has been given. You only need to judge whether the candidate's answer is consistent with the standard answer according to the form of the question. Don't try to answer the original question. You can assume that the standard answer is definitely correct.
|
||||
2. Because the candidate's answer may be different from the standard answer in the form of expression, before making a judgment, please understand the question and the standard answer first, and then judge whether the candidate's answer is correct, but be careful not to try to answer the original question.
|
||||
3. Some answers may contain multiple items, such as multiple-choice questions, multiple-select questions, fill-in-the-blank questions, etc. As long as the answer is the same as the standard answer, it is enough. For multiple-select questions and multiple-blank fill-in-the-blank questions, the candidate needs to answer all the corresponding options or blanks correctly to be considered correct.
|
||||
4. Some answers may be expressed in different ways, such as some answers may be a mathematical expression, some answers may be a textual description, as long as the meaning expressed is the same. And some formulas are expressed in different ways, but they are equivalent and correct.
|
||||
5. If the prediction is given with \\boxed{{}}, please ignore the \\boxed{{}} and only judge whether the candidate's answer is consistent with the standard answer.
|
||||
|
||||
Please judge whether the following answers are consistent with the standard answer based on the above criteria. Grade the predicted answer of this new question as one of:
|
||||
A: CORRECT
|
||||
B: INCORRECT
|
||||
Just return the letters \"A\" or \"B\", with no text around it.
|
||||
|
||||
Here is your task. Simply reply with either CORRECT, INCORRECT. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer.
|
||||
|
||||
|
||||
<Original Question Begin>:
|
||||
{question}
|
||||
<Original Question End>
|
||||
|
||||
|
||||
<Gold Target Begin>:
|
||||
{gold_answer}
|
||||
<Gold Target End>
|
||||
|
||||
|
||||
<Predicted Answer Begin>:
|
||||
{answer}
|
||||
<Predicted End>
|
||||
|
||||
|
||||
Judging the correctness of candidates' answers:"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hosts: List[str],
|
||||
max_retries: int = 1,
|
||||
retry_delay: float = 1.0,
|
||||
stop_word="<|im_end|>",
|
||||
thinking_finish_words=["<conclude>", "**Final Answer**", "</think>"],
|
||||
):
|
||||
super().__init__()
|
||||
self.hosts = hosts
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.stop_word = stop_word
|
||||
self.thinking_finish_words = thinking_finish_words
|
||||
|
||||
self.host_ip_idx = random.randint(0, len(hosts) - 1)
|
||||
self.model_name = requests.get(
|
||||
f"http://{self.hosts[self.host_ip_idx]}/v1/models",
|
||||
headers={"Authorization": "Bearer "},
|
||||
).json()["data"][0]["id"]
|
||||
|
||||
def on_data_received(
|
||||
self,
|
||||
prompt_messages: List[MessageItem],
|
||||
completion_messages: List[MessageItem],
|
||||
metadata: dict,
|
||||
) -> JudgeStatus:
|
||||
question = prompt_messages[-1]["content"]
|
||||
response = completion_messages[-1]["content"]
|
||||
question_type = metadata.get("question_type", None)
|
||||
gold_answer = metadata["gold_answer"]
|
||||
if not response.strip().endswith(self.stop_word):
|
||||
# If the response does not end with the stop word, it is not a complete response, treat as incorrect
|
||||
return JudgeStatus(
|
||||
ok=True,
|
||||
handle={
|
||||
"question": question,
|
||||
"question_type": question_type,
|
||||
"response": response,
|
||||
"gold_answer": gold_answer,
|
||||
"verify_label": False,
|
||||
},
|
||||
)
|
||||
|
||||
for thinking_finish_word in self.thinking_finish_words:
|
||||
if thinking_finish_word in response:
|
||||
response = response.split(thinking_finish_word)[-1]
|
||||
|
||||
response = response.replace(self.stop_word, "")
|
||||
|
||||
# first try to extract and verify with rule, if correct, return
|
||||
extracted_answer, verify_label = self._extract_and_verify_with_logic(
|
||||
response, gold_answer
|
||||
)
|
||||
if verify_label is True:
|
||||
return JudgeStatus(
|
||||
ok=True,
|
||||
handle={
|
||||
"question": question,
|
||||
"question_type": question_type,
|
||||
"response": response,
|
||||
"gold_answer": gold_answer,
|
||||
"verify_label": verify_label,
|
||||
},
|
||||
)
|
||||
|
||||
# then try to evaluate with model
|
||||
res_string, verify_label = self._evaluate_answer_with_llm(
|
||||
question, question_type, response, gold_answer
|
||||
)
|
||||
return JudgeStatus(
|
||||
ok=True,
|
||||
handle={
|
||||
"question": question,
|
||||
"question_type": question_type,
|
||||
"response": response,
|
||||
"gold_answer": gold_answer,
|
||||
"verify_label": verify_label,
|
||||
},
|
||||
)
|
||||
|
||||
def on_reward_required(
|
||||
self, status: JudgeStatus, timeout: Optional[float] = None
|
||||
) -> Reward:
|
||||
if status.handle is None:
|
||||
return None
|
||||
if status.handle["verify_label"] is not None:
|
||||
return 1.0 if status.handle["verify_label"] else -1.0
|
||||
return None
|
||||
|
||||
def _evaluate_answer_with_llm(
|
||||
self, question: str, question_type: str, answer: str, gold_answer: str
|
||||
) -> Tuple[str, bool]:
|
||||
for i in range(self.max_retries):
|
||||
host = self.hosts[self.host_ip_idx]
|
||||
self.host_ip_idx = (self.host_ip_idx + 1) % len(self.hosts)
|
||||
prompt = self.verify_prompt.format(
|
||||
"", "", question=question, answer=answer, gold_answer=gold_answer
|
||||
)
|
||||
try:
|
||||
res = requests.post(
|
||||
f"http://{host}/v1/chat/completions",
|
||||
json={
|
||||
"model": self.model_name,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
],
|
||||
"temperature": 0.0,
|
||||
"top_p": 0.8,
|
||||
"top_k": 20,
|
||||
"repetition_penalty": 1.05,
|
||||
"max_tokens": 100,
|
||||
"stop": ["<|im_end|>", "<|endoftext|>"],
|
||||
},
|
||||
)
|
||||
res_string = res.json()["choices"][0]["message"]["content"]
|
||||
print(f"Evaluate result: {res_string}")
|
||||
verify_label = self._verify_from_string(res_string)
|
||||
if verify_label is None:
|
||||
raise ValueError(
|
||||
f"Evaluate result is None, judger prediction: {res_string}"
|
||||
)
|
||||
return res_string, verify_label
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error verifying answer: {e}")
|
||||
time.sleep(self.retry_delay)
|
||||
continue
|
||||
print(f"Failed to verify answer after {self.max_retries} retries.")
|
||||
return None, None
|
||||
|
||||
def _verify_from_string(self, verification: str):
|
||||
if "A" in verification and "B" not in verification:
|
||||
label = True
|
||||
elif "B" in verification and "A" not in verification:
|
||||
label = False
|
||||
else: # judger model failed to predict A or B
|
||||
label = None
|
||||
return label
|
||||
|
||||
def _extract_and_verify_with_logic(
|
||||
self, response: str, gold_answer: str
|
||||
) -> Tuple[str, bool]:
|
||||
extracted_answer = extract_answer(response)
|
||||
verify_label = math_equal(extracted_answer, gold_answer)
|
||||
return extracted_answer, verify_label
|
||||
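MathJudger grades in two stages: a rule-based check with extract_answer and math_equal from judgers/utils.py, then an LLM grader over HTTP only when the rule-based check does not succeed. A rough sketch of the first stage in isolation; the strings are made up, and the exact normalisation is whatever utils.py implements:

```python
from bootcamp_rl.judgers.utils import extract_answer, math_equal

response = "Adding the two legs gives 2 + 4 = 6. \\boxed{6}"  # made-up model output
gold_answer = "6"

extracted = extract_answer(response)  # expected to recover "6" from \boxed{...}
print(extracted, math_equal(extracted, gold_answer))  # expected: 6 True
```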
473  examples/xpuyu_usage/bootcamp_rl/judgers/router.py  Executable file

@@ -0,0 +1,473 @@
# Copyright (c) InternLM. All rights reserved.
|
||||
import atexit
|
||||
import functools
|
||||
import os
|
||||
import queue
|
||||
import time
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import Event, Process, Queue, connection
|
||||
from multiprocessing.synchronize import Event as EventClass
|
||||
from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
from uuid import uuid4
|
||||
|
||||
import loguru
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from .base_judger import (
|
||||
JudgeStatus,
|
||||
MessageItem,
|
||||
MetaData,
|
||||
Reward,
|
||||
registered_judgers,
|
||||
)
|
||||
|
||||
|
||||
class InputData(TypedDict):
|
||||
prompt_messages: List[MessageItem]
|
||||
completion_messages: List[MessageItem]
|
||||
metadata: NotRequired[MetaData]
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenericTask(Generic[T]):
|
||||
token: str
|
||||
index: int
|
||||
judger: str
|
||||
content: T
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubprocessConfig:
|
||||
loguru_handlers: Optional[List[dict]] = None
|
||||
worker_init_func: Optional[Callable] = None
|
||||
|
||||
|
||||
class ParallelRouter:
|
||||
def __init__(
|
||||
self,
|
||||
judgers_config: Dict[str, dict],
|
||||
data_judger_mapping: Dict[str, Optional[List[str]]],
|
||||
logger: Optional["loguru.Logger"] = None,
|
||||
subprocess_config: Optional[SubprocessConfig] = None,
|
||||
):
|
||||
if logger is not None:
|
||||
self.logger = logger
|
||||
else:
|
||||
import mock
|
||||
|
||||
self.logger = mock.Mock()
|
||||
|
||||
if subprocess_config is not None:
|
||||
self.subprocess_config = subprocess_config
|
||||
else:
|
||||
self.subprocess_config = SubprocessConfig()
|
||||
|
||||
if not (
|
||||
isinstance(judgers_config, dict)
|
||||
and all(
|
||||
isinstance(k, str) and isinstance(v, dict)
|
||||
for k, v in judgers_config.items()
|
||||
)
|
||||
):
|
||||
raise TypeError(
|
||||
f"Illegal judgers_config: {judgers_config}\n"
|
||||
"Should be Dict[str, dict]"
|
||||
)
|
||||
if "RM" in judgers_config.keys():
|
||||
raise KeyError(
|
||||
f"'RM' is a reserved judger keywork for {self.__class__.__name__}, "
|
||||
f"please remove it from judgers_config: {judgers_config}"
|
||||
)
|
||||
self.judgers_config = judgers_config
|
||||
|
||||
data_judger_mapping: Dict[str, List[str]] = {
|
||||
k: v or [] for k, v in data_judger_mapping.items()
|
||||
} # change None to empty list []
|
||||
if not (
|
||||
isinstance(data_judger_mapping, dict)
|
||||
and all(
|
||||
isinstance(k, str)
|
||||
and isinstance(v, (list, tuple, set))
|
||||
and all(isinstance(vv, str) for vv in v)
|
||||
for k, v in data_judger_mapping.items()
|
||||
)
|
||||
):
|
||||
raise TypeError(
|
||||
f"Illegal data_judger_mapping: {data_judger_mapping}\n"
|
||||
"Should be Dict[str, List[str]]"
|
||||
)
|
||||
self.data_judger_mapping = data_judger_mapping
|
||||
|
||||
avail_judgers = set(self.judgers_config.keys()) | {"RM"}
|
||||
_used_judgers: List[str] = []
|
||||
for v in data_judger_mapping.values():
|
||||
_used_judgers.extend(v)
|
||||
used_judgers: set = set(_used_judgers)
|
||||
if unused := avail_judgers - used_judgers:
|
||||
self.logger.warning(
|
||||
"Following judgers are available but not "
|
||||
f"used in data mapping: {unused}\n"
|
||||
"Please make sure this is intended"
|
||||
)
|
||||
# remove unused configs
|
||||
for judger_name in unused:
|
||||
self.judgers_config.pop(judger_name, None)
|
||||
if missing := used_judgers - avail_judgers:
|
||||
self.logger.warning(
|
||||
"Following judgers are configured to be used "
|
||||
f"but not built in data mapping: {missing}\n"
|
||||
"Please make sure this is intended"
|
||||
)
|
||||
# remove missing judgers from mapping, to prevent potential errors
|
||||
for source in list(self.data_judger_mapping.keys()):
|
||||
before = set(self.data_judger_mapping[source])
|
||||
self.data_judger_mapping[source] = list(before - missing)
|
||||
# then filter out data_mapping without available judgers
|
||||
self.data_judger_mapping = {
|
||||
source: judgers
|
||||
for source, judgers in self.data_judger_mapping.items()
|
||||
if len(judgers) > 0
|
||||
}
|
||||
|
||||
# Try to build judgers in __init__ so that exceptions are raised early
|
||||
for judger_name, judger_conf in self.judgers_config.items():
|
||||
_ = self._build_judger(judger_name, judger_conf)
|
||||
|
||||
self._processes: List[Process] = []
|
||||
self._stop_event = Event()
|
||||
atexit.register(self.shutdown)
|
||||
|
||||
self._input_queues: Dict[str, Queue[GenericTask[InputData]]] = {
|
||||
judger_name: Queue() for judger_name in self.judgers_config.keys()
|
||||
}
|
||||
self._output_queue: Queue[GenericTask[Reward]] = Queue()
|
||||
self._exc_queue: Queue[Tuple[str, Exception]] = Queue()
|
||||
self._num_tasks: Dict[str, int] = {} # for each token
|
||||
self._num_indexes: Dict[str, int] = {} # for each token
|
||||
self._results_buffer: Dict[str, List[GenericTask[Reward]]] = defaultdict(
|
||||
list
|
||||
) # results buffer grouped by the key "token"
|
||||
|
||||
def submit(self, data_batch: List[InputData]):
|
||||
indexes_for_ext: List[int] = []
|
||||
indexes_for_local: List[int] = []
|
||||
tasks_input: List[GenericTask[InputData]] = []
|
||||
token = str(uuid4())
|
||||
for index, data_item in enumerate(data_batch):
|
||||
if (
|
||||
not isinstance(data_item, dict)
|
||||
or "metadata" not in data_item
|
||||
or "prompt_messages" not in data_item
|
||||
or "completion_messages" not in data_item
|
||||
):
|
||||
indexes_for_local.append(index)
|
||||
continue
|
||||
source = data_item["metadata"].get("data_source", None)
|
||||
if source is None or source not in self.data_judger_mapping:
|
||||
indexes_for_local.append(index)
|
||||
continue
|
||||
indexes_for_ext.append(index)
|
||||
for judger in self.data_judger_mapping[source]:
|
||||
if judger == "RM":
|
||||
indexes_for_local.append(index)
|
||||
else:
|
||||
tasks_input.append(
|
||||
GenericTask(
|
||||
token=token,
|
||||
index=index,
|
||||
judger=judger,
|
||||
content=data_item,
|
||||
)
|
||||
)
|
||||
|
||||
self._num_tasks[token] = len(tasks_input)
|
||||
self._num_indexes[token] = len(data_batch)
|
||||
for task in tasks_input:
|
||||
self._input_queues[task.judger].put(task, block=True, timeout=1)
|
||||
|
||||
if not self._processes:
|
||||
self.logger.debug("Starting processes...")
|
||||
for judger_name, judger_conf in self.judgers_config.items():
|
||||
num_proc = judger_conf.pop("num_processes", 1)
|
||||
self._processes.extend(
|
||||
[
|
||||
Process(
|
||||
target=ParallelRouter._safe_process_worker,
|
||||
kwargs={
|
||||
"stop_event": self._stop_event,
|
||||
"judger_name": judger_name,
|
||||
"judger_conf": judger_conf,
|
||||
"input_queue": self._input_queues[judger_name],
|
||||
"output_queue": self._output_queue,
|
||||
"exc_queue": self._exc_queue,
|
||||
"config": self.subprocess_config,
|
||||
},
|
||||
daemon=True,
|
||||
)
|
||||
for _ in range(num_proc)
|
||||
]
|
||||
)
|
||||
for p in self._processes:
|
||||
p.start()
|
||||
self.logger.debug(f"Start processes done, total {len(self._processes)}")
|
||||
|
||||
return token, indexes_for_local
|
||||
|
||||
def query(
|
||||
self, token: str, timeout: float = 0
|
||||
) -> Optional[List[Optional[Dict[str, Reward]]]]:
|
||||
start = time.time()
|
||||
while True:
|
||||
self._try_catch_subprocess_exceptions()
|
||||
try:
|
||||
result = self._output_queue.get(timeout=0.1)
|
||||
self._results_buffer[result.token].append(result)
|
||||
except queue.Empty:
|
||||
pass
|
||||
if len(self._results_buffer[token]) == self._num_tasks[token]:
|
||||
results = self._results_buffer.pop(token)
|
||||
num_tasks = self._num_tasks.pop(token)
|
||||
num_indexes = self._num_indexes.pop(token)
|
||||
rewards: List[Dict[str, Reward]] = [{} for _ in range(num_indexes)]
|
||||
for result in results:
|
||||
reward = result.content
|
||||
if result.judger in rewards[result.index]:
|
||||
self.logger.warning(
|
||||
f"{result.judger} already exists: {rewards[result.index]}, "
|
||||
f"will replace --> {reward}"
|
||||
)
|
||||
rewards[result.index][result.judger] = reward
|
||||
# convert empty dicts to None
|
||||
return [r or None for r in rewards]
|
||||
if timeout > 0 and (time.time() - start) > timeout:
|
||||
raise TimeoutError(
|
||||
f"Timeout after {timeout} seconds, got {len(self._results_buffer[token])} results, expected {self._num_tasks[token]}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _safe_process_worker(
|
||||
stop_event: EventClass,
|
||||
judger_name: str,
|
||||
judger_conf: dict,
|
||||
input_queue: "Queue[GenericTask[InputData]]",
|
||||
output_queue: "Queue[GenericTask[Reward]]",
|
||||
exc_queue: "Queue[Tuple[str, Exception]]",
|
||||
config: SubprocessConfig,
|
||||
):
|
||||
try:
|
||||
ParallelRouter._process_worker(
|
||||
stop_event=stop_event,
|
||||
judger_name=judger_name,
|
||||
judger_conf=judger_conf,
|
||||
input_queue=input_queue,
|
||||
output_queue=output_queue,
|
||||
exc_queue=exc_queue,
|
||||
config=config,
|
||||
)
|
||||
except Exception as e:
|
||||
exc_queue.put((judger_name, e), timeout=1)
|
||||
|
||||
@staticmethod
|
||||
def _process_worker(
|
||||
stop_event: EventClass,
|
||||
judger_name: str,
|
||||
judger_conf: dict,
|
||||
input_queue: "Queue[GenericTask[InputData]]",
|
||||
output_queue: "Queue[GenericTask[Reward]]",
|
||||
exc_queue: "Queue[Tuple[str, Exception]]",
|
||||
config: SubprocessConfig,
|
||||
):
|
||||
from xtuner._lite import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
if config.loguru_handlers is not None:
|
||||
for handler in config.loguru_handlers:
|
||||
handler["enqueue"] = True
|
||||
logger.add(*handler)
|
||||
if config.worker_init_func is not None:
|
||||
config.worker_init_func()
|
||||
|
||||
# Infer num threads for each stage according to configs
|
||||
_num_threads = judger_conf.pop("concurrency_per_proc", (1, 1))
|
||||
if isinstance(_num_threads, (tuple, list)) and len(_num_threads) == 2:
|
||||
num_threads_s1, num_threads_s2 = _num_threads
|
||||
elif isinstance(_num_threads, int):
|
||||
num_threads_s1 = max(1, _num_threads // 2)
|
||||
num_threads_s2 = max(1, _num_threads - num_threads_s1)
|
||||
else:
|
||||
raise TypeError(
|
||||
"`concurrency_per_proc` in judger_conf should be int or "
|
||||
f"Tuple[int, int], got {type(_num_threads)}: {_num_threads}"
|
||||
)
|
||||
|
||||
# Lazy build judgers in subprocesses to avoid serialization errors
|
||||
judger = ParallelRouter._build_judger(judger_name, judger_conf)
|
||||
# input_queue = self._input_queues[judger_name]
|
||||
# output_queue = self._output_queue
|
||||
handle_queue: queue.Queue[GenericTask[JudgeStatus]] = queue.Queue()
|
||||
log_prefix = f"[pid={os.getpid()},{judger_name}]"
|
||||
|
||||
def report_exc_wrapper(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
stack_trace = traceback.format_exc()
|
||||
logger.error(
|
||||
f"{log_prefix} "
|
||||
f"Thread worker of {judger_name} raised "
|
||||
f"{type(e).__name__}: {e}",
|
||||
f"Stack trace: {stack_trace}",
|
||||
)
|
||||
exc_queue.put((judger_name, e), timeout=1)
|
||||
|
||||
return wrapper
|
||||
|
||||
# Stage 1: input_queue -> judger.on_data_received -> handle_queue
|
||||
@report_exc_wrapper
|
||||
def thread_worker_s1():
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
task = input_queue.get(timeout=0.1)
|
||||
logger.debug(f"{log_prefix} dequeue input: {task}")
|
||||
except queue.Empty:
|
||||
logger.debug(f"{log_prefix} input queue empty")
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
data = task.content
|
||||
if "metadata" not in data:
|
||||
raise RuntimeError(
|
||||
f"'metadata' not in data.keys(): {list(data.keys())}"
|
||||
)
|
||||
logger.debug(f"{log_prefix} on_data_received")
|
||||
handle = judger.on_data_received(
|
||||
data["prompt_messages"],
|
||||
data["completion_messages"],
|
||||
cast(dict, data["metadata"]),
|
||||
)
|
||||
logger.debug(f"{log_prefix} got handle")
|
||||
new_task = GenericTask(
|
||||
token=task.token,
|
||||
index=task.index,
|
||||
judger=task.judger,
|
||||
content=handle,
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
handle_queue.put(
|
||||
new_task,
|
||||
timeout=0.1,
|
||||
)
|
||||
logger.debug(f"{log_prefix} enqueue handle: {new_task}")
|
||||
break
|
||||
except queue.Full:
|
||||
time.sleep(0.1)
|
||||
|
||||
# Stage 2: handle_queue -> judger.on_reward_required -> output_queue
|
||||
@report_exc_wrapper
|
||||
def thread_worker_s2():
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
task = handle_queue.get(timeout=0.1)
|
||||
logger.debug(f"{log_prefix} dequeue handle: {task}")
|
||||
except queue.Empty:
|
||||
logger.debug(f"{log_prefix} handle queue empty")
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
logger.debug(f"{log_prefix} on_reward_required")
|
||||
reward = judger.on_reward_required(task.content)
|
||||
logger.info(f"{log_prefix} got result")
|
||||
new_task = GenericTask(
|
||||
token=task.token,
|
||||
index=task.index,
|
||||
judger=task.judger,
|
||||
content=reward,
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
output_queue.put(
|
||||
new_task,
|
||||
timeout=0.1,
|
||||
)
|
||||
logger.debug(f"{log_prefix} enqueue output: {new_task}")
|
||||
break
|
||||
except queue.Full:
|
||||
time.sleep(0.1)
|
||||
|
||||
from threading import Thread
|
||||
|
||||
threads: List[Thread] = []
|
||||
for _ in range(num_threads_s1):
|
||||
threads.append(Thread(target=thread_worker_s1, daemon=True))
|
||||
for _ in range(num_threads_s2):
|
||||
threads.append(Thread(target=thread_worker_s2, daemon=True))
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
@staticmethod
|
||||
def _build_judger(judger_name: str, judger_conf: dict):
|
||||
judger_conf = deepcopy(judger_conf)
|
||||
judger_conf.pop("num_processes", None)
|
||||
judger_conf.pop("concurrency_per_proc", None)
|
||||
_type = judger_conf.pop("type", None)
|
||||
if _type is None:
|
||||
_type = judger_name
|
||||
if _type not in registered_judgers:
|
||||
raise KeyError(
|
||||
f"{judger_name} use unregistered judger type: {_type}. "
|
||||
f"Available judgers are: {list(registered_judgers.keys())}"
|
||||
)
|
||||
cls = registered_judgers[_type]
|
||||
return cls(**judger_conf)
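# NOTE (illustrative sketch, not part of the original file): the shape of a
# single judger entry that `_build_judger` expects. The judger name
# "math_judger" and its kwargs are hypothetical; "num_processes" and
# "concurrency_per_proc" are router-level settings that are popped here, and
# every remaining key is forwarded to the judger class registered under
# "type" (or under the judger name itself when "type" is omitted).
#
# judgers_config = {
#     "math_judger": {
#         "type": "math_judger",       # must be a key of `registered_judgers`
#         "num_processes": 2,          # consumed by the router, not the judger
#         "concurrency_per_proc": 4,   # consumed by the router, not the judger
#         "tolerance": 1e-4,           # remaining kwargs -> judger __init__
#     },
# }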
|
||||
|
||||
def _try_catch_subprocess_exceptions(self):
|
||||
exc_handles: List[Tuple[str, Exception]] = []
|
||||
while True:
|
||||
try:
|
||||
exc_handle = self._exc_queue.get(timeout=0.001)
|
||||
exc_handles.append(exc_handle)
|
||||
except queue.Empty:
|
||||
break
|
||||
if exc_handles:
|
||||
error_message = "\n".join(
|
||||
[
|
||||
f"- [{judger_name}] {type(exc).__name__}: {exc}"
|
||||
for judger_name, exc in exc_handles
|
||||
]
|
||||
)
|
||||
raise RuntimeError(
|
||||
"Following threads/processes raise exceptions unexpectedly:\n"
|
||||
f"{error_message}\n"
|
||||
"Program terminated"
|
||||
)
|
||||
|
||||
def shutdown(self, timeout: float = 2.0):
|
||||
if not hasattr(self, "_processes") or not self._processes:
|
||||
return
|
||||
if not self._stop_event.is_set():
|
||||
self._stop_event.set()
|
||||
connection.wait([p.sentinel for p in self._processes], timeout=timeout)
|
||||
for p in self._processes:
|
||||
if p.is_alive():
|
||||
p.kill()
|
||||
p.join()
|
||||
self._processes = []
|
||||
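# NOTE (illustrative sketch, not part of the original file): how the router in
# this file is driven from train_grpo.py further below; variable names other
# than the ParallelRouter methods are placeholders.
#
# router = ParallelRouter(judgers_config=judgers_config,
#                         data_judger_mapping=data_judger_mapping,
#                         logger=logger)
# token, local_indexes = router.submit(submit_batch)  # items carry prompt/completion messages + metadata
# results = router.query(token, timeout=3)            # raises TimeoutError until all rewards are ready
# router.shutdown()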
485
examples/xpuyu_usage/bootcamp_rl/judgers/utils.py
Executable file
@ -0,0 +1,485 @@
|
|||
# flake8: noqa
|
||||
# isort: skip_file
|
||||
|
||||
import multiprocessing
|
||||
import re
|
||||
from math import isclose
|
||||
from typing import Optional, Union
|
||||
from collections import defaultdict, Counter
|
||||
|
||||
from sympy import N, simplify
|
||||
from sympy.parsing.latex import parse_latex
|
||||
from sympy.parsing.sympy_parser import parse_expr
|
||||
|
||||
|
||||
def extract_answer(pred_str: str, execute: bool = False) -> str:
|
||||
if re.search("\\boxed|boxed|\\box|box", pred_str):
|
||||
answer = re.split("\\boxed|boxed|\\box|box", pred_str)[-1]
|
||||
if len(answer) == 0:
|
||||
return ""
|
||||
elif answer[0] == "{":
|
||||
stack = 1
|
||||
a = ""
|
||||
for c in answer[1:]:
|
||||
if c == "{":
|
||||
stack += 1
|
||||
a += c
|
||||
elif c == "}":
|
||||
stack -= 1
|
||||
if stack == 0:
|
||||
break
|
||||
a += c
|
||||
else:
|
||||
a += c
|
||||
else:
|
||||
a = answer.split("$")[0].strip()
|
||||
elif re.search("[Tt]he (final )?answer is:?", pred_str):
|
||||
a = re.split("[Tt]he (final )?answer is:?", pred_str)[-1].strip().rstrip(".")
|
||||
else: # use the last number
|
||||
pred = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
|
||||
if len(pred) >= 1:
|
||||
a = pred[-1]
|
||||
else:
|
||||
a = ""
|
||||
choice = re.findall(r"([A-E]):\s*(.*)", a)
|
||||
if len(choice) > 0:
|
||||
for option, content in choice:
|
||||
a = option
|
||||
choice = re.findall(r"\(([A-E])\)\s*(.*)", a)
|
||||
if len(choice) > 0:
|
||||
for option, content in choice:
|
||||
a = option
|
||||
|
||||
a = re.split(r"=|\\approx|≈", a)[-1]
|
||||
|
||||
# multiple lines
|
||||
answer = ""
|
||||
preds = re.split("\n", a)
|
||||
for pred in preds:
|
||||
if "\\begin{align" in pred or pred.endswith(":"):
|
||||
continue
|
||||
if pred != "" and pred[0] == ":":
|
||||
pred = pred[1:]
|
||||
if pred != "" and pred[-1] == ".":
|
||||
pred = pred[:-1]
|
||||
if pred != "" and pred[-1] == "/":
|
||||
pred = pred[:-1]
|
||||
pred = strip_string(pred)
|
||||
pred = re.sub(r"^[a-zA-Z0-9]+[\)]\s*", "", pred)
|
||||
for p in pred.split("{}"):
|
||||
if p != "":
|
||||
pred = p
|
||||
break
|
||||
|
||||
pred = re.sub(r"^\{([A-Z])\}|\(([A-Z])\)", r"\1\2", pred)
|
||||
if pred != "":
|
||||
answer = pred
|
||||
break
|
||||
return answer
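# NOTE (illustrative, not part of the original file): rough behaviour of the
# extractor above; exact outputs also depend on the `strip_string`
# normalisation defined below.
#
# extract_answer("... so the result is \\boxed{\\frac{1}{2}}")  # -> "\\frac{1}{2}"
# extract_answer("The final answer is: 42.")                    # -> "42"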
|
||||
|
||||
|
||||
def _fix_fracs(string):
|
||||
substrs = string.split("\\frac")
|
||||
new_str = substrs[0]
|
||||
if len(substrs) > 1:
|
||||
substrs = substrs[1:]
|
||||
for substr in substrs:
|
||||
new_str += "\\frac"
|
||||
if len(substr) > 0 and substr[0] == "{":
|
||||
new_str += substr
|
||||
else:
|
||||
try:
|
||||
assert len(substr) >= 2
|
||||
except Exception:
|
||||
return string
|
||||
a = substr[0]
|
||||
b = substr[1]
|
||||
if b != "{":
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}{" + b + "}" + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}{" + b + "}"
|
||||
else:
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}" + b + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}" + b
|
||||
string = new_str
|
||||
return string
|
||||
|
||||
|
||||
def _fix_a_slash_b(string):
|
||||
if len(string.split("/")) != 2:
|
||||
return string
|
||||
a = string.split("/")[0]
|
||||
b = string.split("/")[1]
|
||||
try:
|
||||
if "sqrt" not in a:
|
||||
a = int(a)
|
||||
if "sqrt" not in b:
|
||||
b = int(b)
|
||||
assert string == f"{a}/{b}"
|
||||
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
|
||||
return new_string
|
||||
except Exception:
|
||||
return string
|
||||
|
||||
|
||||
def _fix_sqrt(string):
|
||||
_string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
|
||||
return _string
|
||||
|
||||
|
||||
def strip_string(string):
|
||||
string = str(string).strip()
|
||||
# linebreaks
|
||||
string = string.replace("\n", "")
|
||||
|
||||
# right "."
|
||||
string = string.rstrip(".")
|
||||
|
||||
# remove inverse spaces
|
||||
string = string.replace("\\!", "")
|
||||
string = string.replace("\\ ", "")
|
||||
|
||||
# replace \\ with \
|
||||
string = string.replace("\\\\", "\\")
|
||||
string = string.replace("\\\\", "\\")
|
||||
|
||||
# replace tfrac and dfrac with frac
|
||||
string = string.replace("tfrac", "frac")
|
||||
string = string.replace("dfrac", "frac")
|
||||
|
||||
# remove \left and \right
|
||||
string = string.replace("\\left", "")
|
||||
string = string.replace("\\right", "")
|
||||
|
||||
# Remove unit: miles, dollars if after is not none
|
||||
_string = re.sub(r"\\text{.*?}$", "", string).strip()
|
||||
if _string != "" and _string != string:
|
||||
# print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
|
||||
string = _string
|
||||
|
||||
# Remove circ (degrees)
|
||||
string = string.replace("^{\\circ}", "")
|
||||
string = string.replace("^\\circ", "")
|
||||
|
||||
# remove dollar signs
|
||||
string = string.replace("\\$", "")
|
||||
string = string.replace("$", "")
|
||||
|
||||
string = string.replace("\\text", "")
|
||||
string = string.replace("x\\in", "")
|
||||
|
||||
# remove percentage
|
||||
string = string.replace("\\%", "")
|
||||
string = string.replace(r"\%", "")
|
||||
string = string.replace("%", "")
|
||||
|
||||
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
|
||||
string = string.replace(" .", " 0.")
|
||||
string = string.replace("{.", "{0.")
|
||||
|
||||
# cdot
|
||||
string = string.replace("\\cdot", "")
|
||||
|
||||
# inf
|
||||
string = string.replace("infinity", "\\infty")
|
||||
if "\\infty" not in string:
|
||||
string = string.replace("inf", "\\infty")
|
||||
string = string.replace("+\\inity", "\\infty")
|
||||
|
||||
# and
|
||||
string = string.replace("and", "")
|
||||
string = string.replace("\\mathbf", "")
|
||||
|
||||
# use regex to remove \mbox{...}
|
||||
string = re.sub(r"\\mbox{.*?}", "", string)
|
||||
|
||||
# quote
|
||||
string = string.replace("'", "")
string = string.replace('"', "")
|
||||
|
||||
# i, j
|
||||
if "j" in string and "i" not in string:
|
||||
string = string.replace("j", "i")
|
||||
|
||||
# replace a.000b where b is not number or b is end, with ab, use regex
|
||||
string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
|
||||
string = re.sub(r"(\d+)\.0+$", r"\1", string)
|
||||
|
||||
# if empty, return empty string
|
||||
if len(string) == 0:
|
||||
return string
|
||||
if string[0] == ".":
|
||||
string = "0" + string
|
||||
|
||||
# to consider: get rid of e.g. "k = " or "q = " at beginning
|
||||
if len(string.split("=")) == 2:
|
||||
if len(string.split("=")[0]) <= 2:
|
||||
string = string.split("=")[1]
|
||||
|
||||
string = _fix_sqrt(string)
|
||||
string = string.replace(" ", "")
|
||||
|
||||
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
|
||||
string = _fix_fracs(string)
|
||||
|
||||
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
|
||||
string = _fix_a_slash_b(string)
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def last_boxed_only_string(string):
|
||||
idx = string.rfind("\\boxed")
|
||||
if idx < 0:
|
||||
idx = string.rfind("\\fbox")
|
||||
if idx < 0:
|
||||
return None
|
||||
|
||||
i = idx
|
||||
right_brace_idx = None
|
||||
num_left_braces_open = 0
|
||||
while i < len(string):
|
||||
if string[i] == "{":
|
||||
num_left_braces_open += 1
|
||||
if string[i] == "}":
|
||||
num_left_braces_open -= 1
|
||||
if num_left_braces_open == 0:
|
||||
right_brace_idx = i
|
||||
break
|
||||
i += 1
|
||||
|
||||
if right_brace_idx is None:
|
||||
retval = None
|
||||
else:
|
||||
retval = string[idx : right_brace_idx + 1]
|
||||
|
||||
return retval
|
||||
|
||||
|
||||
def extract_answer(pred_str: str, execute: bool = False) -> str:
|
||||
if re.search("\boxed|boxed", pred_str):
|
||||
answer = re.split("\boxed|boxed", pred_str)[-1]
|
||||
if len(answer) == 0:
|
||||
return ""
|
||||
elif answer[0] == "{":
|
||||
stack = 1
|
||||
a = ""
|
||||
for c in answer[1:]:
|
||||
if c == "{":
|
||||
stack += 1
|
||||
a += c
|
||||
elif c == "}":
|
||||
stack -= 1
|
||||
if stack == 0:
|
||||
break
|
||||
a += c
|
||||
else:
|
||||
a += c
|
||||
else:
|
||||
a = answer.split("$")[0].strip()
|
||||
elif re.search("[Tt]he (final )?answer is:?", pred_str):
|
||||
a = re.split("[Tt]he (final )?answer is:?", pred_str)[-1].strip().rstrip(".")
|
||||
elif pred_str.startswith("```python") and execute:
|
||||
# fall back to program
|
||||
from lagent import get_tool
|
||||
|
||||
a = get_tool("IPythonInteractive").exec(pred_str).value or ""
|
||||
else: # use the last number
|
||||
pred = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
|
||||
if len(pred) >= 1:
|
||||
a = pred[-1]
|
||||
else:
|
||||
a = ""
|
||||
# multiple lines
|
||||
pred = a.split("\n")[0]
|
||||
if pred != "" and pred[0] == ":":
|
||||
pred = pred[1:]
|
||||
if pred != "" and pred[-1] == ".":
|
||||
pred = pred[:-1]
|
||||
if pred != "" and pred[-1] == "/":
|
||||
pred = pred[:-1]
|
||||
pred = strip_string(pred)
|
||||
return pred
|
||||
|
||||
|
||||
def is_digit(s):
|
||||
try:
|
||||
float(str(s).replace(",", ""))
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def math_equal(
|
||||
prediction: Union[bool, float, str],
|
||||
reference: Union[float, str],
|
||||
include_percentage: bool = True,
|
||||
is_close: bool = True,
|
||||
tolerance: float = 1e-4,
|
||||
timeout: bool = False,
|
||||
) -> bool:
|
||||
"""Exact match of math if and only if:
|
||||
|
||||
1. numerical equal: both can convert to float and are equal
|
||||
2. symbolic equal: both can convert to sympy expression and are equal
|
||||
"""
|
||||
try: # 1. numerical equal
|
||||
if is_digit(prediction) and is_digit(reference):
|
||||
prediction = float(str(prediction).replace(",", ""))
|
||||
reference = float(str(reference).replace(",", ""))
|
||||
# number questions
|
||||
if include_percentage:
|
||||
gt_result = [reference / 100, reference, reference * 100]
|
||||
else:
|
||||
gt_result = [reference]
|
||||
for item in gt_result:
|
||||
try:
|
||||
if is_close:
|
||||
if isclose(item, prediction, rel_tol=tolerance):
|
||||
return True
|
||||
else:
|
||||
if item == prediction:
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
return False
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not prediction and prediction not in [0, False]:
|
||||
return False
|
||||
|
||||
# 2. symbolic equal
|
||||
reference = str(reference).strip()
|
||||
prediction = str(prediction).strip()
|
||||
|
||||
## deal with [], (), {}
|
||||
pred_str, ref_str = prediction, reference
|
||||
if (
|
||||
prediction.startswith("[")
|
||||
and prediction.endswith("]")
|
||||
and not reference.startswith("(")
|
||||
) or (
|
||||
prediction.startswith("(")
|
||||
and prediction.endswith(")")
|
||||
and not reference.startswith("[")
|
||||
):
|
||||
pred_str = pred_str.strip("[]()")
|
||||
ref_str = ref_str.strip("[]()")
|
||||
for s in ["{", "}", "(", ")"]:
|
||||
ref_str = ref_str.replace(s, "")
|
||||
pred_str = pred_str.replace(s, "")
|
||||
if pred_str == ref_str:
|
||||
return True
|
||||
|
||||
## [a, b] vs. [c, d], return a==c and b==d
|
||||
if (
|
||||
(prediction.startswith("[") and prediction.endswith("]"))
|
||||
and (reference.startswith("[") and reference.endswith("]"))
|
||||
or (prediction.startswith("(") and prediction.endswith(")"))
|
||||
and (reference.startswith("(") and reference.endswith(")"))
|
||||
):
|
||||
pred_parts = prediction[1:-1].split(",")
|
||||
ref_parts = reference[1:-1].split(",")
|
||||
if len(pred_parts) == len(ref_parts):
|
||||
if all(
|
||||
[
|
||||
math_equal(
|
||||
pred_parts[i], ref_parts[i], include_percentage, is_close
|
||||
)
|
||||
for i in range(len(pred_parts))
|
||||
]
|
||||
):
|
||||
return True
|
||||
|
||||
# symbolic equal with sympy
|
||||
if timeout:
|
||||
if call_with_timeout(symbolic_equal_process, prediction, reference):
|
||||
return True
|
||||
else:
|
||||
if symbolic_equal(prediction, reference):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def math_equal_process(param):
if param[-2] is None:
return False
return math_equal(param[-2], param[-1])
|
||||
|
||||
|
||||
def symbolic_equal(a, b):
|
||||
|
||||
def _parse(s):
|
||||
for f in [parse_latex, parse_expr]:
|
||||
try:
|
||||
return f(s)
|
||||
except Exception:
|
||||
pass
|
||||
return s
|
||||
|
||||
a = _parse(a)
|
||||
b = _parse(b)
|
||||
|
||||
try:
|
||||
if simplify(a - b) == 0:
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
if isclose(N(a), N(b), rel_tol=1e-3):
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def symbolic_equal_process(a, b, output_queue):
|
||||
result = symbolic_equal(a, b)
|
||||
output_queue.put(result)
|
||||
|
||||
|
||||
def call_with_timeout(func, *args, timeout=1, **kwargs):
|
||||
output_queue = multiprocessing.Queue()
|
||||
process_args = args + (output_queue,)
|
||||
process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
|
||||
process.start()
|
||||
process.join(timeout)
|
||||
|
||||
if process.is_alive():
|
||||
process.terminate()
|
||||
process.join()
|
||||
return False
|
||||
|
||||
return output_queue.get()
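# NOTE (illustrative, not part of the original file): how the helper above is
# used by `math_equal` when `timeout=True`; `call_with_timeout` appends the
# result queue as the last positional argument, which `symbolic_equal_process`
# fills in.
#
# equal = call_with_timeout(symbolic_equal_process, "\\frac{1}{2}", "0.5", timeout=1)  # -> True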
|
||||
|
||||
|
||||
def math_majority_vote(answers: list, majority: Optional[int] = None):
|
||||
# threshold = len(answers) // 2 + 1
|
||||
ans2cnt, ans2idx = Counter(), defaultdict(list)
|
||||
for i, ans in enumerate(answers):
|
||||
if isinstance(ans, str) and ans.strip():
|
||||
for key in ans2cnt.keys():
|
||||
if math_equal(ans, key):
|
||||
ans2cnt[key] += 1
|
||||
ans2idx[key].append(i)
|
||||
break
|
||||
else:
|
||||
ans2cnt[ans] += 1
|
||||
ans2idx[ans].append(i)
|
||||
if ans2cnt:
|
||||
maj, cnt = ans2cnt.most_common(1)[0]
|
||||
if maj and cnt >= (majority or 1):
|
||||
return maj, ans2idx[maj]
|
||||
return None, []
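# NOTE (illustrative, not part of the original file): a minimal sketch of the
# majority vote above; answers that `math_equal` treats as equivalent are
# pooled into one bucket before the most common bucket is returned.
#
# maj, idx = math_majority_vote(["1/2", "0.5", "0.75"], majority=2)
# # -> maj == "1/2", idx == [0, 1]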
|
||||
53
examples/xpuyu_usage/bootcamp_rl/utils.py
Executable file
@ -0,0 +1,53 @@
|
|||
import importlib.util
|
||||
import os
|
||||
import types
|
||||
|
||||
|
||||
class ConfigDict(dict):
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item in self:
|
||||
return self[item]
|
||||
raise AttributeError(f"'ConfigDict' object has no attribute '{item}'")
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
self[key] = value
|
||||
|
||||
|
||||
class Config:
|
||||
|
||||
@staticmethod
|
||||
def fromfile(file_path):
|
||||
config_dict = ConfigDict()
|
||||
if not os.path.isfile(file_path):
|
||||
raise FileNotFoundError(f"Config file not found: {file_path}")
|
||||
|
||||
# Load the configuration file as a module
|
||||
spec = importlib.util.spec_from_file_location("config_module", file_path)
|
||||
config_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(config_module)
|
||||
|
||||
# Function to convert nested dictionaries to ConfigDict recursively
|
||||
def convert_to_config_dict(d):
|
||||
if isinstance(d, dict):
|
||||
|
||||
config_dict = ConfigDict()
|
||||
for key, value in d.items():
|
||||
if isinstance(value, dict):
|
||||
config_dict[key] = convert_to_config_dict(value)
|
||||
else:
|
||||
config_dict[key] = value
|
||||
return config_dict
|
||||
else:
|
||||
return d
|
||||
|
||||
# Retrieve all attributes (variables) from the module
|
||||
for attribute_name in dir(config_module):
|
||||
if not attribute_name.startswith("__"):
|
||||
config_dict[attribute_name] = convert_to_config_dict(
|
||||
getattr(config_module, attribute_name)
|
||||
)
|
||||
for key, value in list(config_dict.items()):
|
||||
if isinstance(value, (types.FunctionType, types.ModuleType)):
|
||||
config_dict.pop(key)
|
||||
return config_dict
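# NOTE (illustrative, not part of the original file): `Config.fromfile` simply
# executes a plain Python config and exposes its module-level variables via
# attribute access; the file name and values below are hypothetical.
#
#   # grpo_config.py
#   seed = 42
#   work_dir = "work_dirs/grpo"
#   judgers_config = {"math_judger": {"num_processes": 2}}
#
#   cfg = Config.fromfile("grpo_config.py")
#   cfg.seed                                      # -> 42
#   cfg.judgers_config.math_judger.num_processes  # -> 2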
|
||||
40
examples/xpuyu_usage/report_to_wandb.py
Executable file
@ -0,0 +1,40 @@
|
|||
import os
|
||||
import fire
|
||||
import json
|
||||
import wandb
|
||||
|
||||
|
||||
def main(path, project):
|
||||
name = path.split("/")[-3]
|
||||
# name = os.path.basename(path).split(".")[0]
|
||||
wandb.init(project=project, name=name)
|
||||
previous_step = 0
|
||||
log_cache = {}
|
||||
for line in open(path):
|
||||
log = json.loads(line)
|
||||
parsed_log = {}
|
||||
for key, value in log.items():
|
||||
if key != "rejected_score_mean":
|
||||
key = key.replace("rejected_score", "rejected_score/")
|
||||
if "/" in key:
|
||||
split_key = key.split("/")
|
||||
new_key = "_".join(split_key[1:]) + "/" + split_key[0]
|
||||
parsed_log[new_key] = value
|
||||
else:
|
||||
parsed_log[key] = value
|
||||
print(parsed_log)
|
||||
step = parsed_log.pop("step")
|
||||
if step != previous_step:
|
||||
wandb.log(log_cache, commit=True, step=previous_step)
|
||||
log_cache = {}
|
||||
previous_step = step
|
||||
log_cache.update(parsed_log)
|
||||
if log_cache:
|
||||
wandb.log(log_cache, commit=True, step=previous_step)
|
||||
|
||||
|
||||
wandb.finish()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
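# NOTE (illustrative, not part of the original file): the script is exposed via
# `fire`, so the per-rank JSONL written by train_grpo.py (rank0.log.jsonl) can
# be replayed to Weights & Biases with something like (paths hypothetical):
#
#   python report_to_wandb.py --path work_dirs/grpo/<timestamp>/rank0.log.jsonl --project bootcamp-rl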
|
||||
4
examples/xpuyu_usage/requirements.text
Executable file
@ -0,0 +1,4 @@
|
|||
fire
|
||||
flash-attn
|
||||
torch>=2.5.0
|
||||
xtuner[all]==0.2.0rc0
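# NOTE (illustrative, not part of the original file): these pins can be
# installed in one step; flash-attn generally needs a CUDA toolchain matching
# the installed torch build.
#
#   pip install -r examples/xpuyu_usage/requirements.text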
|
||||
839
examples/xpuyu_usage/train_grpo.py
Executable file
@ -0,0 +1,839 @@
|
|||
# Copyright (c) InternLM. All rights reserved.
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import fire
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from mmengine import mkdir_or_exist
|
||||
from mmengine.runner import set_random_seed
|
||||
from mmengine.utils import get_git_hash
|
||||
from mmengine.utils.dl_utils import collect_env
|
||||
from torch.nn import functional as F
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||
from xtuner._lite import get_device, get_logger, get_torch_device_module
|
||||
from xtuner._lite.accelerate import profile_time_and_memory, unpack_sequence
|
||||
from xtuner._lite.algorithms.sft import SftCollator
|
||||
from xtuner._lite.modelings import register_remote_code
|
||||
from xtuner._lite.parallel import (
|
||||
ParallelSampler,
|
||||
setup_parallel,
|
||||
)
|
||||
from xtuner._lite.patches import AutoPatch, FSDPConfig
|
||||
|
||||
from bootcamp_rl.datasets import (
|
||||
InferDataset,
|
||||
bootcampPromptDataset,
|
||||
PromptCollator,
|
||||
TrajectoryCollator,
|
||||
TrajectoryDataset,
|
||||
InfiniteDataLoaderIter,
|
||||
)
|
||||
from bootcamp_rl.judgers import ParallelRouter
|
||||
from bootcamp_rl.utils import Config
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
DEVICE = get_device()
|
||||
DEVICE_MODULE = get_torch_device_module()
|
||||
|
||||
|
||||
torch._dynamo.config.cache_size_limit = 16384
|
||||
|
||||
CHAT_TEMPLATE_MAP = {
|
||||
"qwen": {
|
||||
"chat_template":"{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
|
||||
"stop_words":["<|im_end|>", "<|endoftext|>"],
|
||||
},
|
||||
"internthinker": {
|
||||
"chat_template":"{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are an expert reasoner with extensive experience in mathematical and code competitions. You approach problems through systematic thinking and rigorous reasoning. Your response should reflect deep understanding and precise logical thinking, making your solution path and reasoning clear to others. Please put your thinking process within <think>...</think> tags.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are an expert reasoner with extensive experience in mathematical and code competitions. You approach problems through systematic thinking and rigorous reasoning. Your response should reflect deep understanding and precise logical thinking, making your solution path and reasoning clear to others. Please put your thinking process within <think>...</think> tags.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
|
||||
"stop_words":["<|im_end|>", "<|endoftext|>"],
|
||||
},
|
||||
"r1": {
|
||||
"chat_template":"{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' in message %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- else %}{{'<|Assistant|>' + message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- endfor %}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' not in message %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|><think>\\n'}}{% endif %}",
|
||||
"stop_words":["<|end▁of▁sentence|>"],
|
||||
},
|
||||
|
||||
}
|
||||
|
||||
|
||||
class RLParallelSampler(ParallelSampler):
|
||||
def __iter__(self):
|
||||
"""Iterate the indices."""
|
||||
# deterministically shuffle based on epoch and seed
|
||||
if self.shuffle:
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.seed + self.epoch)
|
||||
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
||||
else:
|
||||
indices = torch.arange(len(self.dataset)).tolist()
|
||||
|
||||
# add extra samples to make it evenly divisible
|
||||
if self.round_up:
|
||||
indices = (indices * int(self.total_size / len(indices) + 1))[
|
||||
: self.total_size
|
||||
]
|
||||
|
||||
# subsample
|
||||
chunk_size = len(indices) // self.world_size
|
||||
start = self.rank * chunk_size
|
||||
end = start + chunk_size
|
||||
indices = indices[start:end]
|
||||
|
||||
return iter(indices[self.step :])
|
||||
|
||||
|
||||
|
||||
class PGLoss(torch.nn.Module):
|
||||
"""Policy Gradient Loss for policy model."""
|
||||
|
||||
def __init__(self,
|
||||
clip: float = 0.2,
|
||||
loss_type: str = "per_seq"):
|
||||
super().__init__()
|
||||
self.clip = clip
|
||||
self.loss_type = loss_type
|
||||
assert self.loss_type in ["per_token", "per_seq"]
|
||||
|
||||
def forward(self, logprobs, old_logprobs, advantages, loss_factor=None):
|
||||
if self.loss_type == "per_seq":
|
||||
return self.forward_per_seq(logprobs, old_logprobs, advantages, loss_factor)
|
||||
elif self.loss_type == "per_token":
|
||||
return self.forward_per_token(logprobs, old_logprobs, advantages, loss_factor)
|
||||
|
||||
def forward_per_seq(self, logprobs, old_logprobs, advantages, loss_factor=None):
|
||||
logprobs = logprobs.sum(1)
|
||||
old_logprobs = old_logprobs.sum(1)
|
||||
logprobs_diff = logprobs - old_logprobs
|
||||
ratio = torch.exp(logprobs_diff)
|
||||
pg_losses = -advantages * ratio
|
||||
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.clip, 1.0 + self.clip)
|
||||
pg_loss_max = torch.max(pg_losses, pg_losses2)
|
||||
pg_loss = pg_loss_max.mean()
|
||||
return pg_loss
|
||||
|
||||
def forward_per_token(self, logprobs, old_logprobs, advantages, loss_factor=None):
|
||||
ratio = (logprobs - old_logprobs).exp()
|
||||
pg_loss1 = -ratio * advantages
|
||||
pg_loss2 = -ratio.clamp(1 - self.clip,
|
||||
1 + self.clip) * advantages
|
||||
pg_loss_max = torch.max(pg_loss1, pg_loss2)
|
||||
|
||||
assert loss_factor is not None
|
||||
pg_loss = torch.sum(pg_loss_max) * loss_factor
|
||||
return pg_loss
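# NOTE (illustrative, not part of the original file): both branches implement
# the clipped surrogate max(-A * r, -A * clip(r, 1 - clip, 1 + clip)) with
# r = exp(logprobs - old_logprobs); "per_seq" uses one ratio per sequence and
# averages over the batch, while "per_token" keeps token-level ratios and
# rescales by `loss_factor` (1 / total action tokens of the mini-batch,
# computed in the training loop below).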
|
||||
|
||||
|
||||
def log_format(rank, debug=False):
|
||||
|
||||
formatter = f"[XTuner][RANK {rank}]"
|
||||
formatter += "[{time:YYYY-MM-DD HH:mm:ss}][<level>{level}</level>]"
|
||||
|
||||
if debug:
|
||||
formatter += "[<cyan>{name}</cyan>:"
|
||||
formatter += "<cyan>{function}</cyan>:"
|
||||
formatter += "<cyan>{line}</cyan>]"
|
||||
|
||||
formatter += " <level>{message}</level>"
|
||||
return formatter
|
||||
|
||||
|
||||
def is_interval(step, total_steps, interval):
|
||||
return (step + 1) % interval == 0 or (step + 1) == total_steps
|
||||
|
||||
|
||||
def reduce_mean(data, group):
|
||||
data_tensor = torch.tensor(data, device=DEVICE)
|
||||
dist.all_reduce(data_tensor, op=dist.ReduceOp.AVG, group=group)
|
||||
return data_tensor.item()
|
||||
|
||||
|
||||
def train_grpo(cfg_path, **kwargs):
|
||||
args = Config.fromfile(cfg_path)
|
||||
args.update(kwargs)
|
||||
|
||||
###########################################################################
|
||||
# 1. Environment #
|
||||
###########################################################################
|
||||
register_remote_code()
|
||||
|
||||
setup_parallel()
|
||||
set_random_seed(args.seed)
|
||||
|
||||
rank = dist.get_rank()
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
|
||||
objects = [timestamp]
|
||||
dist.broadcast_object_list(objects, src=0)
|
||||
timestamp = objects[0]
|
||||
|
||||
args.work_dir = os.path.join(args.work_dir, timestamp)
|
||||
mkdir_or_exist(args.work_dir)
|
||||
|
||||
log_file = os.path.join(args.work_dir, f"rank{rank}.log")
|
||||
|
||||
# Change the log format printed in the terminal
|
||||
lvl = "DEBUG" if args.debug else "INFO"
|
||||
logger.remove()
|
||||
logger.add(sys.stderr, level=lvl, format=log_format(rank, args.debug))
|
||||
# Change the format saved in the log file
|
||||
logger.add(log_file, format=log_format(rank), backtrace=True, catch=True)
|
||||
|
||||
logger.info(args)
|
||||
if rank == 0:
|
||||
env = collect_env()
|
||||
import transformers
|
||||
import xtuner
|
||||
|
||||
env["Transformers"] = transformers.__version__
|
||||
env["XTuner"] = f"{xtuner.__version__}+{get_git_hash(digits=6)}"
|
||||
runtime_env = OrderedDict()
|
||||
runtime_env.update(env)
|
||||
runtime_env["Seed"] = args.seed
|
||||
runtime_env["World Size"] = dist.get_world_size()
|
||||
|
||||
runtime_env_info = "\n " + "\n ".join(f"{k}: {v}" for k, v in runtime_env.items())
|
||||
dash_line = "-" * 60
|
||||
logger.info("\n" + dash_line + "\nRuntime environment:" + runtime_env_info + "\n" + dash_line + "\n")
|
||||
# ------------------- Environment End ------------------------------ #
|
||||
|
||||
###########################################################################
|
||||
# 3. FSDP #
|
||||
###########################################################################
|
||||
if args.dtype == "auto":
|
||||
args.dtype = "bf16" if DEVICE_MODULE.is_bf16_supported() else "fp16"
|
||||
|
||||
if args.dtype == "fp16":
|
||||
dtype = torch.float16
|
||||
elif args.dtype == "bf16":
|
||||
if DEVICE_MODULE.is_bf16_supported():
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
raise RuntimeError("The device does not support `bf16`, " "please set `dtype` to `fp16`.")
|
||||
else:
|
||||
raise RuntimeError("`dtype` only supports `fp16`, `bf16` or `auto`, " f"but found {args.dtype}.")
|
||||
|
||||
with torch.device("meta"):
|
||||
# In order to save CPU memory and GPU memory,
|
||||
# initialize an empty complete model on all ranks first.
|
||||
# At the same time, a non-empty complete model will be loaded
|
||||
# on the CPU of rank0.
|
||||
# After the model is parallelized, the parameters of the complete
|
||||
# model on rank0 will be loaded.
|
||||
actor_model = AutoModelForCausalLM.from_pretrained(args.actor, attn_implementation="flash_attention_2", torch_dtype=dtype)
|
||||
|
||||
for module in actor_model.modules():
|
||||
for p_name, param in module.named_parameters(recurse=False):
|
||||
if param.requires_grad:
|
||||
param_fp32 = torch.nn.Parameter(param.to(dtype=torch.float32))
|
||||
setattr(module, p_name, param_fp32)
|
||||
|
||||
ref_model = AutoModelForCausalLM.from_pretrained(args.reference, attn_implementation="flash_attention_2", torch_dtype=dtype)
|
||||
|
||||
for param in ref_model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
with profile_time_and_memory("[Parallelize Actor]"):
|
||||
actor_model = AutoPatch.from_causal_lm(
|
||||
actor_model,
|
||||
fsdp_config=FSDPConfig(
|
||||
tp_size=args.tp_size,
|
||||
sp_size=args.sp_size,
|
||||
param_dtype=dtype,
|
||||
reduce_dtype=dtype,
|
||||
cpu_offload=args.cpu_offload,
|
||||
reshard_after_forward=False,
|
||||
mesh_prefix="actor",
|
||||
),
|
||||
)
|
||||
dist.barrier()
|
||||
|
||||
with profile_time_and_memory("[Parallelize Reference]"):
|
||||
ref_model = AutoPatch.from_causal_lm(
|
||||
ref_model,
|
||||
fsdp_config=FSDPConfig(
|
||||
tp_size=args.tp_size,
|
||||
sp_size=args.sp_size,
|
||||
param_dtype=dtype,
|
||||
reduce_dtype=dtype,
|
||||
cpu_offload=args.cpu_offload,
|
||||
reshard_after_forward=True,
|
||||
mesh_prefix="ref",
|
||||
),
|
||||
)
|
||||
dist.barrier()
|
||||
|
||||
# -------------------------- FSDP End ------------------------------ #
|
||||
|
||||
###########################################################################
|
||||
# 2. Dataset & Dataloader #
|
||||
###########################################################################
|
||||
actor_sp_mesh = actor_model.sequence_parallel_mesh
|
||||
actor_dp_mesh = actor_model.data_parallel_mesh
|
||||
actor_data_mesh = actor_model.data_mesh
|
||||
actor_dp_size = actor_dp_mesh.size()
|
||||
|
||||
actor_sp_size = actor_sp_mesh.size()
|
||||
|
||||
prompt_global_batch = args.gen_global_batch // args.prompt_repeat_k
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.actor, trust_remote_code=True, padding_side="right")
|
||||
|
||||
if args.chat_template is not None:
|
||||
if rank == 0:
|
||||
logger.info(f"[CHAT_TEMPLATE] {args.chat_template}")
|
||||
tokenizer.chat_template = CHAT_TEMPLATE_MAP[args.chat_template]["chat_template"]
|
||||
|
||||
stop_token_ids = []
|
||||
if args.stop_word:
|
||||
word_ids = tokenizer.encode(args.stop_word, add_special_tokens=False)
|
||||
else:
|
||||
word_ids = [tokenizer.encode(stop_word, add_special_tokens=False) for stop_word in CHAT_TEMPLATE_MAP[args.chat_template]["stop_words"]]
|
||||
# if len(word_ids) > 1:
|
||||
# raise NotImplementedError("The stop word must be a single token.")
|
||||
stop_token_ids.extend(word_ids)
|
||||
|
||||
with profile_time_and_memory("[Dataset & Dataloader]"):
|
||||
|
||||
prompt_dataset = bootcampPromptDataset(
|
||||
args.datasets,
|
||||
tokenizer,
|
||||
difficulty_balance_cfg=args.data_difficulty_balance_cfg,
|
||||
)
|
||||
if rank == 0:
|
||||
logger.info(f"[Dataset] {len(prompt_dataset)} prompts.")
|
||||
|
||||
assert is_flash_attn_2_available()
|
||||
prompt_collator = PromptCollator(pack_batch=True)
|
||||
prompt_sampler = ParallelSampler(prompt_dataset, actor_dp_mesh, prompt_global_batch, shuffle=True)
|
||||
|
||||
prompt_dataloader = DataLoader(
|
||||
prompt_dataset,
|
||||
batch_size=prompt_global_batch // actor_dp_mesh.size(),
|
||||
num_workers=args.num_workers,
|
||||
# Ensure to round up or drop last based on the `global_batch_size`,
|
||||
# if you want to replace a custom sampler.
|
||||
sampler=prompt_sampler,
|
||||
collate_fn=prompt_collator,
|
||||
persistent_workers=args.num_workers > 0,
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
logger.info(f"[Dataloader] {len(prompt_dataloader)} batches.")
|
||||
_first_batch = [prompt_dataset[i] for i in range(prompt_global_batch)]
|
||||
logger.debug(f"[Dataloader] Training Batch:\n{_first_batch}")
|
||||
|
||||
dist.barrier()
|
||||
# ------------------- Dataset & Dataloader End --------------------- #
|
||||
|
||||
# --------------------- Router Start ------------------------------- #
|
||||
judger_router = ParallelRouter(
|
||||
judgers_config=args.judgers_config,
|
||||
data_judger_mapping=args.data_judger_mapping,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
###########################################################################
|
||||
# 4. Optimizer & Scheduler #
|
||||
###########################################################################
|
||||
actor_params = [p for p in actor_model.parameters() if p.requires_grad]
|
||||
actor_optimizer = AdamW(actor_params, lr=args.actor_lr, weight_decay=args.wd)
|
||||
|
||||
|
||||
# if args.total_steps == None:
|
||||
total_steps = len(prompt_dataloader) # automatically settings
|
||||
|
||||
warmup_steps = args.warmup_steps
|
||||
lr_min = args.get("actor_min_lr", args.actor_lr)
|
||||
|
||||
if args.checkpoint_interval == -1:
|
||||
checkpoint_interval = total_steps
|
||||
elif args.checkpoint_interval < 1:
|
||||
checkpoint_interval = int(total_steps * args.checkpoint_interval)
|
||||
else:
|
||||
checkpoint_interval = int(args.checkpoint_interval)
|
||||
|
||||
def warmup_fn(x):
|
||||
return x / warmup_steps if x < warmup_steps else 1
|
||||
|
||||
warmup_scheduler = LambdaLR(actor_optimizer, warmup_fn)
|
||||
cosine_scheduler = CosineAnnealingLR(actor_optimizer, T_max=total_steps - warmup_steps, eta_min=lr_min)
|
||||
|
||||
# ---------------- Optimizer & Scheduler End ----------------------- #
|
||||
|
||||
###########################################################################
|
||||
# 5. Training #
|
||||
###########################################################################
|
||||
|
||||
policy_loss_fn = PGLoss(
|
||||
clip=args.get("pgloss_clip", 0.2),
|
||||
loss_type=args.loss_type,
|
||||
)
|
||||
|
||||
trajectory_dataset = TrajectoryDataset()
|
||||
prompt_iterator = InfiniteDataLoaderIter(prompt_dataloader)
|
||||
|
||||
start_step = 0
|
||||
cur_total_minibatch_steps = 0
|
||||
start_train_t = time.time()
|
||||
DEVICE_MODULE.empty_cache()
|
||||
DEVICE_MODULE.reset_peak_memory_stats()
|
||||
max_memory = DEVICE_MODULE.max_memory_allocated()
|
||||
logger.info("[Train] Begin Train Loop. The current GPU memory is " f"{(max_memory / 1024**3):.1f}GB")
|
||||
|
||||
for step in range(start_step, total_steps):
|
||||
|
||||
if step <= warmup_steps:
|
||||
warmup_scheduler.step()
|
||||
cur_lr = warmup_scheduler.get_last_lr()[0]
|
||||
else:
|
||||
cosine_scheduler.step()
|
||||
cur_lr = cosine_scheduler.get_last_lr()[0]
|
||||
|
||||
DEVICE_MODULE.reset_peak_memory_stats()
|
||||
|
||||
step_kl_penalty_loss = 0
|
||||
step_rl_loss = 0
|
||||
step_start_t = time.time()
|
||||
|
||||
DEVICE_MODULE.reset_peak_memory_stats()
|
||||
|
||||
data = next(prompt_iterator)
|
||||
prompt_input_ids = unpack_sequence(data["input_ids"].to(DEVICE), data["num_tokens"])
|
||||
infer_num_tokens = data["num_tokens"].to(DEVICE)
|
||||
# repeat prompt for k times
|
||||
prompt_input_ids = [p for p in prompt_input_ids for _ in range(args.prompt_repeat_k)] # AAAABBBBCCCC
|
||||
infer_num_tokens = torch.Tensor([n for n in infer_num_tokens for _ in range(args.prompt_repeat_k)])
|
||||
message_data = [m for m in data["message_data"] for _ in range(args.prompt_repeat_k)]
|
||||
metadata = [m for m in data["metadata"] for _ in range(args.prompt_repeat_k)]
|
||||
|
||||
# Stage 1, Actor Model Generation
|
||||
step_avg_new_tokens = 0
|
||||
step_gen_start_t = time.time()
|
||||
|
||||
actor_model.eval()
|
||||
# During the generation stage, sequence parallelism was not used,
|
||||
# even when the sp size is greater than 1.
|
||||
# Per sp rank processes different prompts in parallel.
|
||||
responses = actor_model.generate(
|
||||
prompt_input_ids,
|
||||
stop_token_ids,
|
||||
max_length=args.gen_max_length,
|
||||
max_batch_size=len(prompt_input_ids),
|
||||
max_prefill_batch=args.max_prefill_batch,
|
||||
max_new_tokens=args.gen_max_new,
|
||||
do_sample=args.gen_do_sample,
|
||||
top_k=args.gen_top_k,
|
||||
top_p=args.gen_top_p,
|
||||
temperature=args.temperature,
|
||||
cuda_graph=args.cuda_graph,
|
||||
)
|
||||
|
||||
# decode responses
|
||||
response_texts = [tokenizer.decode(res, skip_special_tokens=False) for res in responses]
|
||||
|
||||
actor_model.train()
|
||||
dist.barrier()
|
||||
|
||||
step_avg_new_tokens = sum([len(res) for res in responses]) / len(responses)
|
||||
step_gen_time = time.time() - step_gen_start_t
|
||||
|
||||
prompt_input_ids = [p[0].tolist() for p in prompt_input_ids]
|
||||
|
||||
# Stage 2, Infer
|
||||
step_infer_start_t = time.time()
|
||||
step_infer_consumed_tokens = 0
|
||||
|
||||
# submit to judger
|
||||
if actor_data_mesh.get_local_rank() == 0:
|
||||
submit_batch = []
|
||||
for i in range(len(message_data)):
|
||||
submit_batch.append(
|
||||
{
|
||||
"prompt_messages": message_data[i],
|
||||
"completion_messages": [{"role": "assistant", "content": response_texts[i]}],
|
||||
"metadata": metadata[i],
|
||||
}
|
||||
)
|
||||
token, indexes_for_local = judger_router.submit(submit_batch)
|
||||
|
||||
# `infer_dataset` varies at each dp rank, there is no need to
|
||||
# use the parallel sampler.
|
||||
infer_dataset = InferDataset(prompt_input_ids, responses, message_data, metadata)
|
||||
infer_dataloader = DataLoader(
|
||||
infer_dataset,
|
||||
batch_size=args.rl_micro_batch,
|
||||
num_workers=0,
|
||||
collate_fn=SftCollator(pack_batch=True),
|
||||
shuffle=False,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
policies = []
|
||||
for infer_packed_seq in infer_dataloader:
|
||||
# labels are already shifted in InferDataset
|
||||
infer_labels = infer_packed_seq["labels"].to(DEVICE)
|
||||
infer_input_ids = infer_packed_seq["input_ids"].to(DEVICE)
|
||||
infer_num_tokens = infer_packed_seq["num_tokens"].to(DEVICE)
|
||||
infer_batch_size = infer_num_tokens.numel()
|
||||
|
||||
step_infer_consumed_tokens += infer_num_tokens.sum() / actor_data_mesh.size()
|
||||
|
||||
unpacked_input_ids = unpack_sequence(infer_input_ids, infer_num_tokens, dim=1)
|
||||
unpacked_labels = unpack_sequence(infer_labels, infer_num_tokens, dim=1)
|
||||
|
||||
for i in range(infer_batch_size):
|
||||
assert unpacked_input_ids[i].numel() == infer_num_tokens[i]
|
||||
assert unpacked_labels[i].numel() == infer_num_tokens[i]
|
||||
|
||||
_policy = {
|
||||
"input_ids": unpacked_input_ids[i].flatten().tolist(),
|
||||
"labels": unpacked_labels[i].flatten().tolist(),
|
||||
"num_tokens": infer_num_tokens[i].item(),
|
||||
}
|
||||
_policy["sequence_text"] = tokenizer.decode(_policy["input_ids"], skip_special_tokens=False)
|
||||
policies.append(_policy)
|
||||
|
||||
step_infer_time = time.time() - step_infer_start_t
|
||||
|
||||
# --------------------------Get Judger Reward------------------ #
|
||||
# query results from judger
|
||||
if actor_data_mesh.get_local_rank() == 0:
|
||||
while True:
|
||||
try:
|
||||
judger_results = judger_router.query(token, timeout=3)
|
||||
logger.info(f"Query judger results: {judger_results}")
|
||||
break
|
||||
except TimeoutError as e:
|
||||
logger.info(f"Judger query timeout: {e}. Will retry")
|
||||
judger_rewards = [list(r.values())[0] for r in judger_results]
|
||||
judger_rewards = [r if r is not None else -1.0 for r in judger_rewards]
|
||||
judger_rewards = torch.tensor(judger_rewards, dtype=torch.float32).to(DEVICE)
|
||||
else:
|
||||
judger_rewards = torch.tensor([0] * len(policies), dtype=torch.float32).to(DEVICE)
|
||||
|
||||
dist.barrier()
|
||||
# broadcast judger rewards to same data mesh
|
||||
dist.all_reduce(judger_rewards, op=dist.ReduceOp.SUM, group=actor_data_mesh.get_group())
|
||||
|
||||
# reward shaping, use GRPO or RLOO to normalize rewards
|
||||
_rewards = judger_rewards.reshape(-1, args.prompt_repeat_k).T
|
||||
if args.reward_shaping_type == "rloo":
|
||||
baseline = (_rewards.sum(0) - _rewards) / (args.prompt_repeat_k - 1)
|
||||
judger_advantages = _rewards - baseline
|
||||
elif args.reward_shaping_type == "grpo":
|
||||
judger_advantages = (_rewards - _rewards.mean(0)) / (_rewards.std(0) + 1e-8)
|
||||
else:
|
||||
raise NotImplementedError(f"Reward shaping type {args.reward_shaping_type} is not implemented.")
|
||||
judger_advantages = judger_advantages.T.flatten()
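# NOTE (illustrative, not part of the original file): after the reshape and
# transpose, `_rewards` has shape (prompt_repeat_k, num_prompts), so both
# branches normalise within each prompt group: GRPO uses
# (r - mean(group)) / (std(group) + 1e-8), RLOO subtracts the leave-one-out
# mean of the other k - 1 samples of the same prompt.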
|
||||
# update policies
|
||||
assert len(judger_rewards) == len(policies)
|
||||
for i in range(len(policies)):
|
||||
policies[i]["judger_reward"] = judger_rewards[i].item()
|
||||
policies[i]["judger_advantage"] = judger_advantages[i].item()
|
||||
|
||||
step_rl_start_t = time.time()
|
||||
|
||||
_global_policies = [None] * actor_dp_size
|
||||
dist.all_gather_object(_global_policies, policies, actor_dp_mesh.get_group())
|
||||
|
||||
global_policies = []
|
||||
for _rank_policies in _global_policies:
|
||||
global_policies.extend(_rank_policies)
|
||||
|
||||
trajectory_dataset.update(global_policies)
|
||||
# ------------------------------------------------------------- #
|
||||
# --------------------------Stage 4, RL------------------------ #
|
||||
# ------------------------------------------------------------- #
|
||||
if rank == 0:
|
||||
# dump trajectory
|
||||
_buffer_dir = os.path.join(args.work_dir, "trajectories")
|
||||
mkdir_or_exist(_buffer_dir)
|
||||
_buffer_file = os.path.join(_buffer_dir, f"step.{step}.jsonl")
|
||||
trajectory_dataset.dump_jsonl(_buffer_file, tokenizer, args.debug)
|
||||
_buffer_log_file = os.path.join(_buffer_dir, f"step.{step}.log")
|
||||
trajectory_dataset.dump_log(_buffer_log_file, tokenizer, args.debug)
|
||||
|
||||
rl_global_batch = args.rl_global_batch
|
||||
|
||||
rl_loader = DataLoader(
|
||||
trajectory_dataset,
|
||||
batch_size=args.rl_micro_batch,
|
||||
num_workers=0,
|
||||
collate_fn=TrajectoryCollator(pack_batch=True),
|
||||
shuffle=False,
|
||||
sampler=RLParallelSampler(trajectory_dataset, actor_dp_mesh, rl_global_batch, shuffle=False),
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
# Count the total number of tokens used for training RL on all ranks
|
||||
# It is necessary for `per-token` loss, otherwise the number of tokens
|
||||
# for each backward is unbalanced.
|
||||
step_avg_judger_reward = sum([t["judger_reward"] for t in global_policies]) / len(global_policies)
|
||||
|
||||
# --------------------------Infer Old Policy---------------------- #
|
||||
_all_old_logprobs = []
|
||||
_all_action_tokens = []
|
||||
for packed_policy in rl_loader:
|
||||
rl_input_ids = packed_policy["input_ids"].to(DEVICE)
|
||||
rl_num_tokens = packed_policy["num_tokens"].to(DEVICE)
|
||||
assert rl_input_ids.numel() == rl_num_tokens.sum()
|
||||
# labels are already shifted in InferDataset
|
||||
rl_labels = packed_policy["labels"].to(DEVICE)
|
||||
|
||||
judger_rewards = torch.Tensor(packed_policy["judger_rewards"]).to(DEVICE) # shape: (rl_micro_batch, )
|
||||
judger_advantages = torch.Tensor(packed_policy["judger_advantages"]).to(DEVICE) # shape: (rl_micro_batch, )
|
||||
|
||||
actor_input_ids = rl_input_ids.clone()
|
||||
actor_labels = rl_labels.clone()
|
||||
actor_num_tokens = rl_num_tokens.clone().tolist()
|
||||
|
||||
actor_cu_seq_lens = torch.cumsum(torch.IntTensor([0] + actor_num_tokens), dim=0).to(DEVICE).int()
|
||||
actor_position_ids = [torch.arange(num) for num in actor_num_tokens]
|
||||
actor_position_ids = torch.cat(actor_position_ids, dim=0).to(DEVICE).unsqueeze_(0)
|
||||
|
||||
with torch.no_grad():
|
||||
packed_actor_logits = actor_model(
|
||||
input_ids=actor_input_ids,
|
||||
position_ids=actor_position_ids,
|
||||
use_cache=False,
|
||||
cu_seq_lens_q=actor_cu_seq_lens,
|
||||
cu_seq_lens_k=actor_cu_seq_lens,
|
||||
max_length_q=max(actor_num_tokens),
|
||||
max_length_k=max(actor_num_tokens),
|
||||
sequence_parallel_mesh=actor_sp_mesh,
|
||||
).logits
|
||||
|
||||
# The labels of prefill tokens and last token are -100.
|
||||
# HACK: (for sp) The -100 part takes the value of 0,
|
||||
# this part will be masked later.
|
||||
packed_logprobs = actor_model.gather_logprobs(packed_actor_logits, actor_labels.clip(0), actor_sp_mesh)
|
||||
logprobs = unpack_sequence(packed_logprobs, actor_num_tokens, dim=1)
|
||||
logprobs = [l.detach().cpu() for l in logprobs]
|
||||
unpacked_labels = unpack_sequence(rl_labels, rl_num_tokens, dim=1)
|
||||
_num_action_tokens = [(unpacked_labels[i] >= 0).sum() for i in range(len(unpacked_labels))]
|
||||
_all_action_tokens.extend(_num_action_tokens)
|
||||
_all_old_logprobs.extend(logprobs)
|
||||
|
||||
|
||||
# --------------------------Mini-batch Train Policy---------------------- #
|
||||
rl_loader_iter = iter(rl_loader)
|
||||
_sample_idx = 0
|
||||
num_mini_batch_samples = len(rl_loader) // args.rl_mini_batch_steps
|
||||
all_mini_batch_action_tokens = [sum(_all_action_tokens[i*num_mini_batch_samples:(i+1)*num_mini_batch_samples]) for i in range(args.rl_mini_batch_steps)]
|
||||
for mini_batch_step in range(args.rl_mini_batch_steps):
|
||||
step_sum_gen_entropy = 0
|
||||
step_sum_ref_kl = 0
|
||||
step_action_tokens = 0
|
||||
step_rl_consumed_tokens = 0
|
||||
step_sum_adv = 0
|
||||
|
||||
for _train_iter in range(len(rl_loader) // args.rl_mini_batch_steps // args.rl_micro_batch):
|
||||
packed_policy = next(rl_loader_iter)
|
||||
rl_input_ids = packed_policy["input_ids"].to(DEVICE)
|
||||
rl_num_tokens = packed_policy["num_tokens"].to(DEVICE)
|
||||
assert rl_input_ids.numel() == rl_num_tokens.sum()
|
||||
rl_batch_size = rl_num_tokens.numel()
|
||||
# labels are already shifted in InferDataset
|
||||
rl_labels = packed_policy["labels"].to(DEVICE)
|
||||
|
||||
judger_rewards = torch.Tensor(packed_policy["judger_rewards"]).to(DEVICE) # shape: (rl_micro_batch, )
|
||||
judger_advantages = torch.Tensor(packed_policy["judger_advantages"]).to(DEVICE) # shape: (rl_micro_batch, )
|
||||
|
||||
actor_input_ids = rl_input_ids.clone()
|
||||
actor_labels = rl_labels.clone()
|
||||
actor_num_tokens = rl_num_tokens.clone().tolist()
|
||||
|
||||
actor_cu_seq_lens = torch.cumsum(torch.IntTensor([0] + actor_num_tokens), dim=0).to(DEVICE).int()
|
||||
actor_position_ids = [torch.arange(num) for num in actor_num_tokens]
|
||||
actor_position_ids = torch.cat(actor_position_ids, dim=0).to(DEVICE).unsqueeze_(0)
|
||||
|
||||
packed_actor_logits = actor_model(
|
||||
input_ids=actor_input_ids,
|
||||
position_ids=actor_position_ids,
|
||||
use_cache=False,
|
||||
cu_seq_lens_q=actor_cu_seq_lens,
|
||||
cu_seq_lens_k=actor_cu_seq_lens,
|
||||
max_length_q=max(actor_num_tokens),
|
||||
max_length_k=max(actor_num_tokens),
|
||||
sequence_parallel_mesh=actor_sp_mesh,
|
||||
).logits
|
||||
|
||||
with torch.no_grad():
|
||||
packed_ref_logits = ref_model(
|
||||
input_ids=actor_input_ids,
|
||||
position_ids=actor_position_ids,
|
||||
use_cache=False,
|
||||
cu_seq_lens_q=actor_cu_seq_lens,
|
||||
cu_seq_lens_k=actor_cu_seq_lens,
|
||||
max_length_q=max(actor_num_tokens),
|
||||
max_length_k=max(actor_num_tokens),
|
||||
sequence_parallel_mesh=actor_sp_mesh,
|
||||
).logits
|
||||
|
||||
# The labels of prefill tokens and last token are -100.
|
||||
# HACK: (for sp) The -100 part takes the value of 0,
|
||||
# this part will be masked later.
|
||||
packed_logprobs = actor_model.gather_logprobs(packed_actor_logits, actor_labels.clip(0), actor_sp_mesh)
|
||||
logprobs = unpack_sequence(packed_logprobs, actor_num_tokens, dim=1)
|
||||
packed_ref_logprobs = ref_model.gather_logprobs(packed_ref_logits, actor_labels.clip(0), actor_sp_mesh)
|
||||
ref_logprobs = unpack_sequence(packed_ref_logprobs, actor_num_tokens, dim=1)
|
||||
# The labels of prefill tokens and last token are -100.
|
||||
# HACK: (for sp) The -100 part takes the value of 0,
|
||||
# this part will be masked later.
|
||||
unpacked_labels = unpack_sequence(rl_labels, rl_num_tokens, dim=1)
|
||||
|
||||
_losses = []
|
||||
for i in range(rl_batch_size):
|
||||
assert unpacked_labels[i].numel() == rl_num_tokens[i]
|
||||
# from the last prefill token, to the second-to-last token (excluding the eos token)
|
||||
_num_action_tokens = (unpacked_labels[i] >= 0).sum()
|
||||
|
||||
_logprobs = logprobs[i][0, -_num_action_tokens - 1 : -1]
|
||||
_ref_logprobs = ref_logprobs[i][0, -_num_action_tokens - 1 : -1]
|
||||
_old_logprobs = _all_old_logprobs[_sample_idx][0, -_num_action_tokens - 1 : -1].to(DEVICE)
|
||||
_judger_advantages = judger_advantages[i]
|
||||
|
||||
_advantages = _judger_advantages
|
||||
|
||||
_loss_factor = 1/float(all_mini_batch_action_tokens[mini_batch_step])
|
||||
_loss = policy_loss_fn(_logprobs, _old_logprobs, _advantages, loss_factor=_loss_factor)
|
||||
kl_type = args.get("kl_type", "unbias") # kl, unbias, mse
|
||||
if kl_type == "kl":
|
||||
kl = _ref_logprobs - _logprobs
|
||||
_kl_penalty_loss = (args.kl_coef * kl).sum() * _loss_factor
|
||||
elif kl_type == "unbias":
|
||||
kl = _ref_logprobs - _logprobs
|
||||
nonneg_nobias_kl = torch.exp(kl) - kl - 1
|
||||
_kl_penalty_loss = (args.kl_coef * nonneg_nobias_kl).sum() * _loss_factor
|
||||
elif kl_type == "mse":
|
||||
_kl_penalty_loss = (
|
||||
(args.kl_coef * (_ref_logprobs - _logprobs).square() / 2).sum() * _loss_factor
|
||||
)
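# NOTE (illustrative, not part of the original file): with x = ref_logprob -
# logprob, "kl" sums the raw log-ratio, "unbias" uses the non-negative
# estimator exp(x) - x - 1, and "mse" uses x**2 / 2; all three are scaled by
# kl_coef and the shared loss_factor before being added to the policy loss.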

    _loss = _loss + _kl_penalty_loss
    _losses.append(_loss)

    step_sum_gen_entropy += -_old_logprobs.sum().item()
    step_sum_ref_kl += (_old_logprobs - _ref_logprobs).sum().item()
    step_sum_adv += _judger_advantages.sum().item()
    step_action_tokens += _num_action_tokens.item()
    _sample_idx += 1

loss = sum(_losses)
loss.backward()

# for logging
step_rl_loss += loss.item()
step_rl_consumed_tokens += rl_num_tokens.sum() / actor_data_mesh.size()

step_rl_time = time.time() - step_rl_start_t
step_avg_ref_kl = step_sum_ref_kl / step_action_tokens
step_avg_gen_entropy = step_sum_gen_entropy / step_action_tokens
step_avg_adv = step_sum_adv / step_action_tokens

actor_data_group = actor_data_mesh.get_group()
step_avg_ref_kl = reduce_mean(step_avg_ref_kl, actor_data_group)
step_avg_gen_entropy = reduce_mean(step_avg_gen_entropy, actor_data_group)
step_avg_adv = reduce_mean(step_avg_adv, actor_data_group)
step_avg_new_tokens = reduce_mean(step_avg_new_tokens, actor_data_group)

actor_grad_norm = actor_model.clip_grad_norm(args.max_grad_norm)
actor_grad_norm = actor_grad_norm.item()
actor_optimizer.step()
actor_optimizer.zero_grad()

step_time = time.time() - step_start_t
eta = step_time * (total_steps * args.rl_mini_batch_steps - cur_total_minibatch_steps)
eta = timedelta(seconds=int(eta))

infer_tgs = int(step_infer_consumed_tokens / step_infer_time)
rl_tgs = int(step_rl_consumed_tokens / step_rl_time)

actor_lr = cur_lr
max_memory = DEVICE_MODULE.max_memory_allocated()
log_dict = {
    "step": cur_total_minibatch_steps + 1,
    "minibatch": mini_batch_step + 1,
    "global_step": step + 1,
    "actor_lr": actor_lr,
    "actor_grad_norm": actor_grad_norm,
    "avg_judger_reward": step_avg_judger_reward,
    "avg_adv": step_avg_adv,
    "avg_gen_entropy": step_avg_gen_entropy,
    "avg_ref_kl": step_avg_ref_kl,
    "rl_loss": step_rl_loss,
    "max_memory": max_memory / 1024**3,
    "avg_new_tokens": step_avg_new_tokens,
    "num_rl_tokens": step_rl_consumed_tokens,
    "infer_tgs": infer_tgs,
    "rl_tgs": rl_tgs,
    "gen_time": step_gen_time,
    "infer_time": step_infer_time,
    "rl_time": step_rl_time,
    "total_time": step_time,
    "eta": eta.seconds,
}
for key, value in log_dict.items():
    if isinstance(value, torch.Tensor):
        log_dict[key] = value.item()
with open(os.path.join(args.work_dir, f"rank{rank}.log.jsonl"), "a") as f:
    f.write(json.dumps(log_dict, ensure_ascii=False) + "\n")

if is_interval(cur_total_minibatch_steps, total_steps * args.rl_mini_batch_steps, args.log_interval):
    logger.info(
        f"[Train] Step {cur_total_minibatch_steps + 1} / Mini-batch {mini_batch_step + 1} "
        f"global_step: {step + 1}/{total_steps} "
        f"actor_lr: {cur_lr:.3e} "
        f"actor_grad_norm: {actor_grad_norm:.3f} "
        f"avg_judger_reward: {step_avg_judger_reward:.8f} "
        f"avg_adv: {step_avg_adv:.8f} "
        f"avg_gen_entropy: {step_avg_gen_entropy:.3f} "
        f"avg_ref_kl: {step_avg_ref_kl:.8f} "
        f"rl_loss: {step_rl_loss:.3f} "
        f"max_memory: {(max_memory / 1024**3):.1f}GB "
        f"avg_new_tokens: {int(step_avg_new_tokens)} "
        f"num_rl_tokens: {int(step_rl_consumed_tokens)} "
        f"infer_tgs: {int(infer_tgs)} "
        f"rl_tgs: {int(rl_tgs)} "
        f"gen_time: {step_gen_time:.2f}s "
        f"infer_time: {step_infer_time:.2f}s "
        f"rl_time: {step_rl_time:.2f}s "
        f"total_time: {step_time:.2f}s "
        f"eta: {eta}"
    )

if is_interval(cur_total_minibatch_steps, total_steps * args.rl_mini_batch_steps, checkpoint_interval):
    DEVICE_MODULE.empty_cache()

    num_digits = len(str(abs(total_steps)))
    work_dir = args.work_dir
    ckpt_dir = os.path.join(work_dir, f"ckpt-{cur_total_minibatch_steps+1:0{num_digits}}")
    hf_dir = os.path.join(work_dir, f"hf-{cur_total_minibatch_steps+1:0{num_digits}}")

    with profile_time_and_memory("[Checkpoint]"):
        actor_model.save_pretrained(hf_dir)
        if rank == 0:
            tokenizer.save_pretrained(hf_dir)

    dist.barrier()
cur_total_minibatch_steps += 1

train_cost_time = time.time() - start_train_t
logger.success(f"[Train] Cost {timedelta(seconds=int(train_cost_time))}")
# ------------------------ Training End ---------------------------- #


if __name__ == "__main__":
    fire.Fire(train_grpo)
76
examples/xpuyu_usage/xpuyu_data_preprocess.py
Executable file
@@ -0,0 +1,76 @@
import os
import json
import fire
import subprocess

import sys
sys.set_int_max_str_digits(128 * 1024)

def convert_jsonl(src_jsonl, tgt_jsonl):
    """Convert a single .jsonl file into the target format."""
    with open(tgt_jsonl, "w", encoding="utf-8") as writer:
        with open(src_jsonl, "r", encoding="utf-8") as reader:
            for line in reader:
                item = json.loads(line)
                new_item = {
                    "message_data": [{"role": "user", "content": item["prompt"]}],
                    "metadata": {
                        "data_source": item["data_source"],  # required: maps the data source to its judger in the config file
                        "ground_truth": item["ground_truth"],
                    }
                }
                writer.write(json.dumps(new_item, ensure_ascii=False) + '\n')
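# Example (added for illustration; field values are hypothetical):
#   input : {"prompt": "1+1=?", "data_source": "math", "ground_truth": "2"}
#   output: {"message_data": [{"role": "user", "content": "1+1=?"}],
#            "metadata": {"data_source": "math", "ground_truth": "2"}}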


def _main(src, tgt):
    """Recursively process a directory or a single file."""
    if os.path.isdir(src):
        # For a directory, create the corresponding target directory
        os.makedirs(tgt, exist_ok=True)
        for sub in os.listdir(src):
            src_path = os.path.join(src, sub)
            tgt_path = os.path.join(tgt, sub)
            _main(src_path, tgt_path)
    elif src.endswith(".jsonl"):
        # For a .jsonl file, add the xpuyu prefix and convert it
        base_name = os.path.basename(src)
        tgt_file_name = f"xpuyu_{base_name}"  # add the xpuyu prefix
        tgt_path = os.path.join(os.path.dirname(tgt), tgt_file_name)
        tmp_tgt = tgt_path + ".tmp"
        try:
            convert_jsonl(src, tmp_tgt)
            subprocess.run(f"mv {tmp_tgt} {tgt_path}", shell=True, check=True)
        except Exception as e:
            print(f"Error processing {src}: {e}")
            subprocess.run(f"rm -f {tmp_tgt}", shell=True, check=True)


def main(src, tgt=None):
    """
    Main entry point; accepts either a directory or a single file as input.
    :param src: source file or directory path
    :param tgt: target file or directory path
    """
    if not tgt and os.path.isdir(src):
        tgt = src + '_for_xpuyu'

    if not os.path.exists(src):
        raise ValueError(f"Source path does not exist: {src}")

    if os.path.isfile(src) and not src.endswith(".jsonl"):
        raise ValueError(f"Source file is not a .jsonl file: {src}")

    _main(src, tgt)
    subprocess.run(f"cat {tgt}/train/*.jsonl > {tgt}/merge_train.jsonl", shell=True, check=True)
    subprocess.run(f"shuf {tgt}/merge_train.jsonl -o {tgt}/merge_train.jsonl", shell=True, check=True)
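    # Note (added): the merge step assumes the converted output contains a `train/`
    # sub-directory of .jsonl files; `shuf -o` then shuffles the merged file in place.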


if __name__ == '__main__':
    """
    Example usage:
        python examples/xpuyu_usage/xpuyu_data_preprocess.py --src examples/bootcamp_generator_outputs/2025-03-07-16:48:28
    Converts every .jsonl file under the given source directory into xpuyu-format .jsonl files,
    preserving the directory structure in the default output directory.
    The output .jsonl files carry the `xpuyu_` prefix.
    """
    fire.Fire(main)