# Copyright (c) InternLM. All rights reserved.
import json
import os
import sys
import time
from collections import OrderedDict
from datetime import datetime, timedelta

import fire
import torch
import torch.distributed as dist
from mmengine import mkdir_or_exist
from mmengine.runner import set_random_seed
from mmengine.utils import get_git_hash
from mmengine.utils.dl_utils import collect_env
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils.import_utils import is_flash_attn_2_available

from xtuner._lite import get_device, get_logger, get_torch_device_module
from xtuner._lite.accelerate import profile_time_and_memory, unpack_sequence
from xtuner._lite.algorithms.sft import SftCollator
from xtuner._lite.modelings import register_remote_code
from xtuner._lite.parallel import (
    ParallelSampler,
    setup_parallel,
)
from xtuner._lite.patches import AutoPatch, FSDPConfig

from bootcamp_rl.datasets import (
    InferDataset,
    bootcampPromptDataset,
    PromptCollator,
    TrajectoryCollator,
    TrajectoryDataset,
    InfiniteDataLoaderIter,
)
from bootcamp_rl.judgers import ParallelRouter
from bootcamp_rl.utils import Config

logger = get_logger()
DEVICE = get_device()
DEVICE_MODULE = get_torch_device_module()
torch._dynamo.config.cache_size_limit = 16384

CHAT_TEMPLATE_MAP = {
    "qwen": {
        "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
        "stop_words": ["<|im_end|>", "<|endoftext|>"],
    },
    "internthinker": {
        "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are an expert reasoner with extensive experience in mathematical and code competitions. You approach problems through systematic thinking and rigorous reasoning. Your response should reflect deep understanding and precise logical thinking, making your solution path and reasoning clear to others. Please put your thinking process within <think> ... </think> tags.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are an expert reasoner with extensive experience in mathematical and code competitions. You approach problems through systematic thinking and rigorous reasoning. Your response should reflect deep understanding and precise logical thinking, making your solution path and reasoning clear to others. Please put your thinking process within <think> ... </think> tags.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
        "stop_words": ["<|im_end|>", "<|endoftext|>"],
    },
    "r1": {
        "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' in message %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- else %}{{'<|Assistant|>' + message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- endfor %}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' not in message %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|><think>\\n'}}{% endif %}",
        "stop_words": ["<|end▁of▁sentence|>"],
    },
}


class RLParallelSampler(ParallelSampler):
    def __iter__(self):
        """Iterate the indices."""
        # deterministically shuffle based on epoch and seed
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        if self.round_up:
            indices = (indices * int(self.total_size / len(indices) + 1))[
                : self.total_size
            ]

        # subsample the contiguous chunk that belongs to this rank
        chunk_size = len(indices) // self.world_size
        start = self.rank * chunk_size
        end = start + chunk_size
        indices = indices[start:end]

        return iter(indices[self.step :])


class PGLoss(torch.nn.Module):
    """Policy Gradient Loss for the policy model."""

    def __init__(self, clip: float = 0.2, loss_type: str = "per_seq"):
        super().__init__()
        self.clip = clip
        self.loss_type = loss_type
        assert self.loss_type in ["per_token", "per_seq"]

    def forward(self, logprobs, old_logprobs, advantages, loss_factor=None):
        if self.loss_type == "per_seq":
            return self.forward_per_seq(logprobs, old_logprobs, advantages, loss_factor)
        elif self.loss_type == "per_token":
            return self.forward_per_token(logprobs, old_logprobs, advantages, loss_factor)

    def forward_per_seq(self, logprobs, old_logprobs, advantages, loss_factor=None):
        logprobs = logprobs.sum(1)
        old_logprobs = old_logprobs.sum(1)
        logprobs_diff = logprobs - old_logprobs
        ratio = torch.exp(logprobs_diff)
        pg_losses = -advantages * ratio
        pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.clip, 1.0 + self.clip)
        pg_loss_max = torch.max(pg_losses, pg_losses2)
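        # A note on the clipped surrogate (PPO-style), matching the two terms
        # above: with ratio r = exp(sum(logp_new) - sum(logp_old)) per sequence
        # and advantage A, the loss is max(-A * r, -A * clip(r, 1 - eps, 1 + eps)),
        # i.e. -min(r * A, clip(r) * A). Taking the max of the negated terms is
        # what keeps the update conservative once r leaves the clip range.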
        pg_loss = pg_loss_max.mean()
        return pg_loss

    def forward_per_token(self, logprobs, old_logprobs, advantages, loss_factor=None):
        ratio = (logprobs - old_logprobs).exp()
        pg_loss1 = -ratio * advantages
        pg_loss2 = -ratio.clamp(1 - self.clip, 1 + self.clip) * advantages
        pg_loss_max = torch.max(pg_loss1, pg_loss2)
        assert loss_factor is not None
        pg_loss = torch.sum(pg_loss_max) * loss_factor
        return pg_loss


def log_format(rank, debug=False):
    formatter = f"[XTuner][RANK {rank}]"
    formatter += "[{time:YYYY-MM-DD HH:mm:ss}][{level}]"

    if debug:
        formatter += "[{name}:"
        formatter += "{function}:"
        formatter += "{line}]"

    formatter += " {message}"
    return formatter


def is_interval(step, total_steps, interval):
    return (step + 1) % interval == 0 or (step + 1) == total_steps


def reduce_mean(data, group):
    data_tensor = torch.tensor(data, device=DEVICE)
    dist.all_reduce(data_tensor, op=dist.ReduceOp.AVG, group=group)
    return data_tensor.item()


def train_grpo(cfg_path, **kwargs):
    args = Config.fromfile(cfg_path)
    args.update(kwargs)

    ###########################################################################
    #                            1. Environment                               #
    ###########################################################################
    register_remote_code()
    setup_parallel()
    set_random_seed(args.seed)

    rank = dist.get_rank()

    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    objects = [timestamp]
    dist.broadcast_object_list(objects, src=0)
    timestamp = objects[0]

    args.work_dir = os.path.join(args.work_dir, timestamp)
    mkdir_or_exist(args.work_dir)

    log_file = os.path.join(args.work_dir, f"rank{rank}.log")

    # Change the log format printed in the terminal.
    lvl = "DEBUG" if args.debug else "INFO"
    logger.remove()
    logger.add(sys.stderr, level=lvl, format=log_format(rank, args.debug))
    # Change the format saved in the log file.
    logger.add(log_file, format=log_format(rank), backtrace=True, catch=True)

    logger.info(args)
    if rank == 0:
        env = collect_env()

        import transformers

        import xtuner

        env["Transformers"] = transformers.__version__
        env["XTuner"] = f"{xtuner.__version__}+{get_git_hash(digits=6)}"
        runtime_env = OrderedDict()
        runtime_env.update(env)
        runtime_env["Seed"] = args.seed
        runtime_env["World Size"] = dist.get_world_size()

        runtime_env_info = "\n " + "\n ".join(f"{k}: {v}" for k, v in runtime_env.items())
        dash_line = "-" * 60
        logger.info(
            "\n" + dash_line + "\nRuntime environment:" + runtime_env_info + "\n" + dash_line + "\n"
        )
    # ------------------------  Environment End  --------------------------- #

    ###########################################################################
    #                               2. FSDP                                   #
    ###########################################################################
    if args.dtype == "auto":
        args.dtype = "bf16" if DEVICE_MODULE.is_bf16_supported() else "fp16"

    if args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "bf16":
        if DEVICE_MODULE.is_bf16_supported():
            dtype = torch.bfloat16
        else:
            raise RuntimeError("The device does not support `bf16`, please set `dtype` to `fp16`.")
    else:
        raise RuntimeError(f"`dtype` only supports `fp16`, `bf16` or `auto`, but found {args.dtype}.")

    with torch.device("meta"):
        # To save CPU and GPU memory, first initialize an empty (meta) copy of
        # the complete model on every rank; the real weights of the complete
        # model live only on the CPU of rank 0. After the model is
        # parallelized, each rank loads its shard of the rank-0 parameters.
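        # Under torch.device("meta"), `from_pretrained` creates shape-only
        # tensors with no real storage on this rank; following the comment
        # above, the actual weights are materialized when AutoPatch
        # parallelizes the model below.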
        actor_model = AutoModelForCausalLM.from_pretrained(
            args.actor,
            attn_implementation="flash_attention_2",
            torch_dtype=dtype,
        )
        for module in actor_model.modules():
            for p_name, param in module.named_parameters(recurse=False):
                if param.requires_grad:
                    # Cast trainable parameters to fp32 (master weights).
                    param_fp32 = torch.nn.Parameter(param.to(dtype=torch.float32))
                    setattr(module, p_name, param_fp32)

        ref_model = AutoModelForCausalLM.from_pretrained(
            args.reference,
            attn_implementation="flash_attention_2",
            torch_dtype=dtype,
        )
        for param in ref_model.parameters():
            param.requires_grad = False

    with profile_time_and_memory("[Parallelize Actor]"):
        actor_model = AutoPatch.from_causal_lm(
            actor_model,
            fsdp_config=FSDPConfig(
                tp_size=args.tp_size,
                sp_size=args.sp_size,
                param_dtype=dtype,
                reduce_dtype=dtype,
                cpu_offload=args.cpu_offload,
                reshard_after_forward=False,
                mesh_prefix="actor",
            ),
        )
        dist.barrier()

    with profile_time_and_memory("[Parallelize Reference]"):
        ref_model = AutoPatch.from_causal_lm(
            ref_model,
            fsdp_config=FSDPConfig(
                tp_size=args.tp_size,
                sp_size=args.sp_size,
                param_dtype=dtype,
                reduce_dtype=dtype,
                cpu_offload=args.cpu_offload,
                reshard_after_forward=True,
                mesh_prefix="ref",
            ),
        )
        dist.barrier()
    # ---------------------------  FSDP End  ------------------------------- #

    ###########################################################################
    #                       3. Dataset & Dataloader                           #
    ###########################################################################
    actor_sp_mesh = actor_model.sequence_parallel_mesh
    actor_dp_mesh = actor_model.data_parallel_mesh
    actor_data_mesh = actor_model.data_mesh
    actor_dp_size = actor_dp_mesh.size()
    actor_sp_size = actor_sp_mesh.size()

    prompt_global_batch = args.gen_global_batch // args.prompt_repeat_k

    tokenizer = AutoTokenizer.from_pretrained(
        args.actor, trust_remote_code=True, padding_side="right"
    )
    if args.chat_template is not None:
        if rank == 0:
            logger.info(f"[CHAT_TEMPLATE] {args.chat_template}")
        tokenizer.chat_template = CHAT_TEMPLATE_MAP[args.chat_template]["chat_template"]

    stop_token_ids = []
    if args.stop_word:
        word_ids = tokenizer.encode(args.stop_word, add_special_tokens=False)
    else:
        word_ids = [
            tokenizer.encode(stop_word, add_special_tokens=False)
            for stop_word in CHAT_TEMPLATE_MAP[args.chat_template]["stop_words"]
        ]
    # if len(word_ids) > 1:
    #     raise NotImplementedError("The stop word must be a single token.")
    stop_token_ids.extend(word_ids)

    with profile_time_and_memory("[Dataset & Dataloader]"):
        prompt_dataset = bootcampPromptDataset(
            args.datasets,
            tokenizer,
            difficulty_balance_cfg=args.data_difficulty_balance_cfg,
        )
        if rank == 0:
            logger.info(f"[Dataset] {len(prompt_dataset)} prompts.")

        assert is_flash_attn_2_available()
        prompt_collator = PromptCollator(pack_batch=True)
        prompt_sampler = ParallelSampler(
            prompt_dataset, actor_dp_mesh, prompt_global_batch, shuffle=True
        )
        prompt_dataloader = DataLoader(
            prompt_dataset,
            batch_size=prompt_global_batch // actor_dp_mesh.size(),
            num_workers=args.num_workers,
            # If you swap in a custom sampler, make sure it rounds up or drops
            # the last batch consistently with the `global_batch_size`.
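            # `batch_size` above is the per-dp-rank share of the prompt batch;
            # ParallelSampler shards indices across the dp mesh, so all ranks
            # together consume `prompt_global_batch` prompts per step.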
            sampler=prompt_sampler,
            collate_fn=prompt_collator,
            persistent_workers=args.num_workers > 0,
        )

        if rank == 0:
            logger.info(f"[Dataloader] {len(prompt_dataloader)} batches.")
            _first_batch = [prompt_dataset[i] for i in range(prompt_global_batch)]
            logger.debug(f"[Dataloader] Training Batch:\n{_first_batch}")
    dist.barrier()
    # --------------------  Dataset & Dataloader End  ---------------------- #

    # -------------------------  Router Start  ----------------------------- #
    judger_router = ParallelRouter(
        judgers_config=args.judgers_config,
        data_judger_mapping=args.data_judger_mapping,
        logger=logger,
    )

    ###########################################################################
    #                       4. Optimizer & Scheduler                          #
    ###########################################################################
    actor_params = [p for p in actor_model.parameters() if p.requires_grad]
    actor_optimizer = AdamW(actor_params, lr=args.actor_lr, weight_decay=args.wd)

    # if args.total_steps == None:
    total_steps = len(prompt_dataloader)  # set automatically from the dataloader length
    warmup_steps = args.warmup_steps
    lr_min = args.get("actor_min_lr", args.actor_lr)

    if args.checkpoint_interval == -1:
        checkpoint_interval = total_steps
    elif args.checkpoint_interval < 1:
        checkpoint_interval = int(total_steps * args.checkpoint_interval)
    else:
        checkpoint_interval = int(args.checkpoint_interval)

    def warmup_fn(x):
        return x / warmup_steps if x < warmup_steps else 1

    warmup_scheduler = LambdaLR(actor_optimizer, warmup_fn)
    cosine_scheduler = CosineAnnealingLR(
        actor_optimizer, T_max=total_steps - warmup_steps, eta_min=lr_min
    )
    # -------------------  Optimizer & Scheduler End  ----------------------- #

    ###########################################################################
    #                             5. Training                                 #
    ###########################################################################
    policy_loss_fn = PGLoss(
        clip=args.get("pgloss_clip", 0.2),
        loss_type=args.loss_type,
    )

    trajectory_dataset = TrajectoryDataset()
    prompt_iterator = InfiniteDataLoaderIter(prompt_dataloader)

    start_step = 0
    cur_total_minibatch_steps = 0
    start_train_t = time.time()
    DEVICE_MODULE.empty_cache()
    DEVICE_MODULE.reset_peak_memory_stats()
    max_memory = DEVICE_MODULE.max_memory_allocated()
    logger.info(
        "[Train] Begin Train Loop. The current GPU memory is "
        f"{(max_memory / 1024**3):.1f}GB"
    )

    for step in range(start_step, total_steps):
        if step <= warmup_steps:
            warmup_scheduler.step()
            cur_lr = warmup_scheduler.get_last_lr()[0]
        else:
            cosine_scheduler.step()
            cur_lr = cosine_scheduler.get_last_lr()[0]

        DEVICE_MODULE.reset_peak_memory_stats()

        step_kl_penalty_loss = 0
        step_rl_loss = 0
        step_start_t = time.time()
        DEVICE_MODULE.reset_peak_memory_stats()

        data = next(prompt_iterator)

        prompt_input_ids = unpack_sequence(data["input_ids"].to(DEVICE), data["num_tokens"])
        infer_num_tokens = data["num_tokens"].to(DEVICE)

        # Repeat each prompt k times, keeping the copies of one prompt
        # adjacent: AAAABBBBCCCC.
        prompt_input_ids = [p for p in prompt_input_ids for _ in range(args.prompt_repeat_k)]
        infer_num_tokens = torch.Tensor(
            [n for n in infer_num_tokens for _ in range(args.prompt_repeat_k)]
        )
        message_data = [m for m in data["message_data"] for _ in range(args.prompt_repeat_k)]
        metadata = [m for m in data["metadata"] for _ in range(args.prompt_repeat_k)]

        # Stage 1, Actor Model Generation
        step_avg_new_tokens = 0
        step_gen_start_t = time.time()
        actor_model.eval()
        # During the generation stage, sequence parallelism is not used even
        # when the sp size is greater than 1; each sp rank processes different
        # prompts in parallel.
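        # Worked example of the repeat layout (illustrative numbers): with
        # prompt_repeat_k=4 and prompts [A, B], the repeated order is
        # [A, A, A, A, B, B, B, B]; the reward-shaping step later depends on
        # this grouping via `reshape(-1, prompt_repeat_k)`.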
        responses = actor_model.generate(
            prompt_input_ids,
            stop_token_ids,
            max_length=args.gen_max_length,
            max_batch_size=len(prompt_input_ids),
            max_prefill_batch=args.max_prefill_batch,
            max_new_tokens=args.gen_max_new,
            do_sample=args.gen_do_sample,
            top_k=args.gen_top_k,
            top_p=args.gen_top_p,
            temperature=args.temperature,
            cuda_graph=args.cuda_graph,
        )
        # Decode responses.
        response_texts = [tokenizer.decode(res, skip_special_tokens=False) for res in responses]

        actor_model.train()
        dist.barrier()

        step_avg_new_tokens = sum([len(res) for res in responses]) / len(responses)
        step_gen_time = time.time() - step_gen_start_t

        prompt_input_ids = [p[0].tolist() for p in prompt_input_ids]

        # Stage 2, Infer
        step_infer_start_t = time.time()
        step_infer_consumed_tokens = 0

        # Submit to the judger.
        if actor_data_mesh.get_local_rank() == 0:
            submit_batch = []
            for i in range(len(message_data)):
                submit_batch.append(
                    {
                        "prompt_messages": message_data[i],
                        "completion_messages": [
                            {"role": "assistant", "content": response_texts[i]}
                        ],
                        "metadata": metadata[i],
                    }
                )
            token, indexes_for_local = judger_router.submit(submit_batch)

        # `infer_dataset` varies at each dp rank; there is no need to use the
        # parallel sampler.
        infer_dataset = InferDataset(prompt_input_ids, responses, message_data, metadata)
        infer_dataloader = DataLoader(
            infer_dataset,
            batch_size=args.rl_micro_batch,
            num_workers=0,
            collate_fn=SftCollator(pack_batch=True),
            shuffle=False,
            persistent_workers=False,
        )

        policies = []
        for infer_packed_seq in infer_dataloader:
            # Labels are already shifted in InferDataset.
            infer_labels = infer_packed_seq["labels"].to(DEVICE)
            infer_input_ids = infer_packed_seq["input_ids"].to(DEVICE)
            infer_num_tokens = infer_packed_seq["num_tokens"].to(DEVICE)
            infer_batch_size = infer_num_tokens.numel()

            step_infer_consumed_tokens += infer_num_tokens.sum() / actor_data_mesh.size()

            unpacked_input_ids = unpack_sequence(infer_input_ids, infer_num_tokens, dim=1)
            unpacked_labels = unpack_sequence(infer_labels, infer_num_tokens, dim=1)

            for i in range(infer_batch_size):
                assert unpacked_input_ids[i].numel() == infer_num_tokens[i]
                assert unpacked_labels[i].numel() == infer_num_tokens[i]
                _policy = {
                    "input_ids": unpacked_input_ids[i].flatten().tolist(),
                    "labels": unpacked_labels[i].flatten().tolist(),
                    "num_tokens": infer_num_tokens[i].item(),
                }
                _policy["sequence_text"] = tokenizer.decode(
                    _policy["input_ids"], skip_special_tokens=False
                )
                policies.append(_policy)

        step_infer_time = time.time() - step_infer_start_t

        # --------------------------Get Judger Reward------------------ #
        # Query results from the judger.
        if actor_data_mesh.get_local_rank() == 0:
            while True:
                try:
                    judger_results = judger_router.query(token, timeout=3)
                    logger.info(f"Query judger results: {judger_results}")
                    break
                except TimeoutError as e:
                    logger.info(f"Judger query timeout: {e}. Will retry")
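            # The polling loop above blocks until the judger responds. Each
            # result is a dict; its first value is taken as the reward below,
            # and a `None` reward (e.g. a failed judgement) is mapped to -1.0.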
Will retry") judger_rewards = [list(r.values())[0] for r in judger_results] judger_rewards = [r if r is not None else -1.0 for r in judger_rewards] judger_rewards = torch.tensor(judger_rewards, dtype=torch.float32).to(DEVICE) else: judger_rewards = torch.tensor([0] * len(policies), dtype=torch.float32).to(DEVICE) dist.barrier() # broadcast judger rewards to same data mesh dist.all_reduce(judger_rewards, op=dist.ReduceOp.SUM, group=actor_data_mesh.get_group()) # reward shaping, use GRPO or RLOO to normalize rewards _rewards = judger_rewards.reshape(-1, args.prompt_repeat_k).T if args.reward_shaping_type == "rloo": baseline = (_rewards.sum(0) - _rewards) / (args.prompt_repeat_k - 1) judger_advantages = _rewards - baseline elif args.reward_shaping_type == "grpo": judger_advantages = (_rewards - _rewards.mean(0)) / (_rewards.std(0) + 1e-8) else: raise NotImplementedError(f"Reward shaping type {args.reward_shaping_type} is not implemented.") judger_advantages = judger_advantages.T.flatten() # update policies assert len(judger_rewards) == len(policies) for i in range(len(policies)): policies[i]["judger_reward"] = judger_rewards[i].item() policies[i]["judger_advantage"] = judger_advantages[i].item() step_rl_start_t = time.time() _global_policies = [None] * actor_dp_size dist.all_gather_object(_global_policies, policies, actor_dp_mesh.get_group()) global_policies = [] for _rank_policies in _global_policies: global_policies.extend(_rank_policies) trajectory_dataset.update(global_policies) # ------------------------------------------------------------- # # --------------------------Stage 4, RL------------------------ # # ------------------------------------------------------------- # if rank == 0: # dump trajectory _buffer_dir = os.path.join(args.work_dir, "trajectories") mkdir_or_exist(_buffer_dir) _buffer_file = os.path.join(_buffer_dir, f"step.{step}.jsonl") trajectory_dataset.dump_jsonl(_buffer_file, tokenizer, args.debug) _buffer_log_file = os.path.join(_buffer_dir, f"step.{step}.log") trajectory_dataset.dump_log(_buffer_log_file, tokenizer, args.debug) rl_global_batch = args.rl_global_batch rl_loader = DataLoader( trajectory_dataset, batch_size=args.rl_micro_batch, num_workers=0, collate_fn=TrajectoryCollator(pack_batch=True), shuffle=False, sampler=RLParallelSampler(trajectory_dataset, actor_dp_mesh, rl_global_batch, shuffle=False), persistent_workers=False, ) # Count the total number of tokens used for training RL on all ranks # It is necessary for `per-token` loss, otherwise the number of tokens # for each backward is unbalanced. 
        step_avg_judger_reward = sum([t["judger_reward"] for t in global_policies]) / len(global_policies)

        # --------------------------Infer Old Policy---------------------- #
        _all_old_logprobs = []
        _all_action_tokens = []
        for packed_policy in rl_loader:
            rl_input_ids = packed_policy["input_ids"].to(DEVICE)
            rl_num_tokens = packed_policy["num_tokens"].to(DEVICE)
            assert rl_input_ids.numel() == rl_num_tokens.sum()
            # Labels are already shifted in InferDataset.
            rl_labels = packed_policy["labels"].to(DEVICE)
            judger_rewards = torch.Tensor(packed_policy["judger_rewards"]).to(DEVICE)  # shape: (rl_micro_batch,)
            judger_advantages = torch.Tensor(packed_policy["judger_advantages"]).to(DEVICE)  # shape: (rl_micro_batch,)

            actor_input_ids = rl_input_ids.clone()
            actor_labels = rl_labels.clone()
            actor_num_tokens = rl_num_tokens.clone().tolist()
            actor_cu_seq_lens = torch.cumsum(torch.IntTensor([0] + actor_num_tokens), dim=0).to(DEVICE).int()
            actor_position_ids = [torch.arange(num) for num in actor_num_tokens]
            actor_position_ids = torch.cat(actor_position_ids, dim=0).to(DEVICE).unsqueeze_(0)

            with torch.no_grad():
                packed_actor_logits = actor_model(
                    input_ids=actor_input_ids,
                    position_ids=actor_position_ids,
                    use_cache=False,
                    cu_seq_lens_q=actor_cu_seq_lens,
                    cu_seq_lens_k=actor_cu_seq_lens,
                    max_length_q=max(actor_num_tokens),
                    max_length_k=max(actor_num_tokens),
                    sequence_parallel_mesh=actor_sp_mesh,
                ).logits

            # The labels of prefill tokens and the last token are -100.
            # HACK: (for sp) the -100 part takes the value of 0; this part
            # will be masked later.
            packed_logprobs = actor_model.gather_logprobs(
                packed_actor_logits, actor_labels.clip(0), actor_sp_mesh
            )
            logprobs = unpack_sequence(packed_logprobs, actor_num_tokens, dim=1)
            logprobs = [l.detach().cpu() for l in logprobs]

            unpacked_labels = unpack_sequence(rl_labels, rl_num_tokens, dim=1)
            _num_action_tokens = [
                (unpacked_labels[i] >= 0).sum() for i in range(len(unpacked_labels))
            ]
            _all_action_tokens.extend(_num_action_tokens)
            _all_old_logprobs.extend(logprobs)

        # --------------------------Mini-batch Train Policy---------------------- #
        rl_loader_iter = iter(rl_loader)
        _sample_idx = 0
        num_mini_batch_samples = len(rl_loader) // args.rl_mini_batch_steps
        all_mini_batch_action_tokens = [
            sum(_all_action_tokens[i * num_mini_batch_samples : (i + 1) * num_mini_batch_samples])
            for i in range(args.rl_mini_batch_steps)
        ]
        for mini_batch_step in range(args.rl_mini_batch_steps):
            step_sum_gen_entropy = 0
            step_sum_ref_kl = 0
            step_action_tokens = 0
            step_rl_consumed_tokens = 0
            step_sum_adv = 0
            for _train_iter in range(len(rl_loader) // args.rl_mini_batch_steps // args.rl_micro_batch):
                packed_policy = next(rl_loader_iter)
                rl_input_ids = packed_policy["input_ids"].to(DEVICE)
                rl_num_tokens = packed_policy["num_tokens"].to(DEVICE)
                assert rl_input_ids.numel() == rl_num_tokens.sum()
                rl_batch_size = rl_num_tokens.numel()
                # Labels are already shifted in InferDataset.
                rl_labels = packed_policy["labels"].to(DEVICE)
                judger_rewards = torch.Tensor(packed_policy["judger_rewards"]).to(DEVICE)  # shape: (rl_micro_batch,)
                judger_advantages = torch.Tensor(packed_policy["judger_advantages"]).to(DEVICE)  # shape: (rl_micro_batch,)

                actor_input_ids = rl_input_ids.clone()
                actor_labels = rl_labels.clone()
                actor_num_tokens = rl_num_tokens.clone().tolist()
                actor_cu_seq_lens = torch.cumsum(torch.IntTensor([0] + actor_num_tokens), dim=0).to(DEVICE).int()
                actor_position_ids = [torch.arange(num) for num in actor_num_tokens]
                actor_position_ids = torch.cat(actor_position_ids, dim=0).to(DEVICE).unsqueeze_(0)

                packed_actor_logits = actor_model(
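                    # Packed (varlen) forward: the micro-batch's sequences are
                    # concatenated into a single row, `cu_seq_lens_q/k` mark
                    # the cumulative boundaries, and `position_ids` restart at
                    # 0 for each packed sequence.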
                    input_ids=actor_input_ids,
                    position_ids=actor_position_ids,
                    use_cache=False,
                    cu_seq_lens_q=actor_cu_seq_lens,
                    cu_seq_lens_k=actor_cu_seq_lens,
                    max_length_q=max(actor_num_tokens),
                    max_length_k=max(actor_num_tokens),
                    sequence_parallel_mesh=actor_sp_mesh,
                ).logits

                with torch.no_grad():
                    packed_ref_logits = ref_model(
                        input_ids=actor_input_ids,
                        position_ids=actor_position_ids,
                        use_cache=False,
                        cu_seq_lens_q=actor_cu_seq_lens,
                        cu_seq_lens_k=actor_cu_seq_lens,
                        max_length_q=max(actor_num_tokens),
                        max_length_k=max(actor_num_tokens),
                        sequence_parallel_mesh=actor_sp_mesh,
                    ).logits

                # The labels of prefill tokens and the last token are -100.
                # HACK: (for sp) the -100 part takes the value of 0; this part
                # will be masked later.
                packed_logprobs = actor_model.gather_logprobs(
                    packed_actor_logits, actor_labels.clip(0), actor_sp_mesh
                )
                logprobs = unpack_sequence(packed_logprobs, actor_num_tokens, dim=1)

                packed_ref_logprobs = ref_model.gather_logprobs(
                    packed_ref_logits, actor_labels.clip(0), actor_sp_mesh
                )
                ref_logprobs = unpack_sequence(packed_ref_logprobs, actor_num_tokens, dim=1)

                unpacked_labels = unpack_sequence(rl_labels, rl_num_tokens, dim=1)

                _losses = []
                for i in range(rl_batch_size):
                    assert unpacked_labels[i].numel() == rl_num_tokens[i]
                    # From the last prefill token to the second-to-last token
                    # (excluding the eos token).
                    _num_action_tokens = (unpacked_labels[i] >= 0).sum()
                    _logprobs = logprobs[i][0, -_num_action_tokens - 1 : -1]
                    _ref_logprobs = ref_logprobs[i][0, -_num_action_tokens - 1 : -1]
                    _old_logprobs = _all_old_logprobs[_sample_idx][0, -_num_action_tokens - 1 : -1].to(DEVICE)

                    _judger_advantages = judger_advantages[i]
                    _advantages = _judger_advantages

                    _loss_factor = 1 / float(all_mini_batch_action_tokens[mini_batch_step])
                    _loss = policy_loss_fn(
                        _logprobs, _old_logprobs, _advantages, loss_factor=_loss_factor
                    )

                    kl_type = args.get("kl_type", "unbias")  # kl, unbias, mse
                    if kl_type == "kl":
                        kl = _ref_logprobs - _logprobs
                        _kl_penalty_loss = (args.kl_coef * kl).sum() * _loss_factor
                    elif kl_type == "unbias":
                        # exp(kl) - kl - 1 is a non-negative, unbiased KL estimator.
                        kl = _ref_logprobs - _logprobs
                        nonneg_nobias_kl = torch.exp(kl) - kl - 1
                        _kl_penalty_loss = (args.kl_coef * nonneg_nobias_kl).sum() * _loss_factor
                    elif kl_type == "mse":
                        _kl_penalty_loss = (
                            (args.kl_coef * (_ref_logprobs - _logprobs).square() / 2).sum() * _loss_factor
                        )
                    else:
                        raise NotImplementedError(f"KL type {kl_type} is not implemented.")
                    _loss = _loss + _kl_penalty_loss
                    _losses.append(_loss)

                    step_sum_gen_entropy += -_old_logprobs.sum().item()
                    step_sum_ref_kl += (_old_logprobs - _ref_logprobs).sum().item()
                    step_sum_adv += _judger_advantages.sum().item()
                    step_action_tokens += _num_action_tokens.item()
                    _sample_idx += 1

                loss = sum(_losses)
                loss.backward()
                # For logging.
                step_rl_loss += loss.item()
                step_rl_consumed_tokens += rl_num_tokens.sum() / actor_data_mesh.size()

            step_rl_time = time.time() - step_rl_start_t

            step_avg_ref_kl = step_sum_ref_kl / step_action_tokens
            step_avg_gen_entropy = step_sum_gen_entropy / step_action_tokens
            step_avg_adv = step_sum_adv / step_action_tokens

            actor_data_group = actor_data_mesh.get_group()
            step_avg_ref_kl = reduce_mean(step_avg_ref_kl, actor_data_group)
            step_avg_gen_entropy = reduce_mean(step_avg_gen_entropy, actor_data_group)
            step_avg_adv = reduce_mean(step_avg_adv, actor_data_group)
            step_avg_new_tokens = reduce_mean(step_avg_new_tokens, actor_data_group)

            actor_grad_norm = actor_model.clip_grad_norm(args.max_grad_norm)
            actor_grad_norm = actor_grad_norm.item()
            actor_optimizer.step()
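            # Note on update granularity (as implemented above): gradients are
            # accumulated across all micro-batches of this mini-batch, and
            # `loss_factor` already rescales per-token terms by the mini-batch
            # action-token count, so this single optimizer step applies the
            # whole mini-batch update.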
            actor_optimizer.zero_grad()

            step_time = time.time() - step_start_t
            eta = step_time * (total_steps * args.rl_mini_batch_steps - cur_total_minibatch_steps)
            eta = timedelta(seconds=int(eta))

            infer_tgs = int(step_infer_consumed_tokens / step_infer_time)
            rl_tgs = int(step_rl_consumed_tokens / step_rl_time)
            actor_lr = cur_lr
            max_memory = DEVICE_MODULE.max_memory_allocated()
            log_dict = {
                "step": cur_total_minibatch_steps + 1,
                "minibatch": mini_batch_step + 1,
                "global_step": step + 1,
                "actor_lr": actor_lr,
                "actor_grad_norm": actor_grad_norm,
                "avg_judger_reward": step_avg_judger_reward,
                "avg_adv": step_avg_adv,
                "avg_gen_entropy": step_avg_gen_entropy,
                "avg_ref_kl": step_avg_ref_kl,
                "rl_loss": step_rl_loss,
                "max_memory": max_memory / 1024**3,
                "avg_new_tokens": step_avg_new_tokens,
                "num_rl_tokens": step_rl_consumed_tokens,
                "infer_tgs": infer_tgs,
                "rl_tgs": rl_tgs,
                "gen_time": step_gen_time,
                "infer_time": step_infer_time,
                "rl_time": step_rl_time,
                "total_time": step_time,
                # `timedelta.seconds` keeps only the seconds component; use
                # total_seconds() so ETAs longer than a day are logged correctly.
                "eta": int(eta.total_seconds()),
            }
            for key, value in log_dict.items():
                if isinstance(value, torch.Tensor):
                    log_dict[key] = value.item()
            with open(os.path.join(args.work_dir, f"rank{rank}.log.jsonl"), "a") as f:
                f.write(json.dumps(log_dict, ensure_ascii=False) + "\n")

            if is_interval(cur_total_minibatch_steps, total_steps * args.rl_mini_batch_steps, args.log_interval):
                logger.info(
                    f"[Train] Step {cur_total_minibatch_steps + 1} / Mini-batch {mini_batch_step + 1} "
                    f"global_step: {step + 1}/{total_steps} "
                    f"actor_lr: {cur_lr:.3e} "
                    f"actor_grad_norm: {actor_grad_norm:.3f} "
                    f"avg_judger_reward: {step_avg_judger_reward:.8f} "
                    f"avg_adv: {step_avg_adv:.8f} "
                    f"avg_gen_entropy: {step_avg_gen_entropy:.3f} "
                    f"avg_ref_kl: {step_avg_ref_kl:.8f} "
                    f"rl_loss: {step_rl_loss:.3f} "
                    f"max_memory: {(max_memory / 1024**3):.1f}GB "
                    f"avg_new_tokens: {int(step_avg_new_tokens)} "
                    f"num_rl_tokens: {int(step_rl_consumed_tokens)} "
                    f"infer_tgs: {int(infer_tgs)} "
                    f"rl_tgs: {int(rl_tgs)} "
                    f"gen_time: {step_gen_time:.2f}s "
                    f"infer_time: {step_infer_time:.2f}s "
                    f"rl_time: {step_rl_time:.2f}s "
                    f"total_time: {step_time:.2f}s "
                    f"eta: {eta}"
                )

            if is_interval(cur_total_minibatch_steps, total_steps * args.rl_mini_batch_steps, checkpoint_interval):
                DEVICE_MODULE.empty_cache()

                num_digits = len(str(abs(total_steps)))
                work_dir = args.work_dir
                ckpt_dir = os.path.join(work_dir, f"ckpt-{cur_total_minibatch_steps + 1:0{num_digits}}")
                hf_dir = os.path.join(work_dir, f"hf-{cur_total_minibatch_steps + 1:0{num_digits}}")

                with profile_time_and_memory("[Checkpoint]"):
                    actor_model.save_pretrained(hf_dir)
                    if rank == 0:
                        tokenizer.save_pretrained(hf_dir)
                    dist.barrier()

            cur_total_minibatch_steps += 1

    train_cost_time = time.time() - start_train_t
    logger.success(f"[Train] Cost {timedelta(seconds=int(train_cost_time))}")
    # --------------------------  Training End  ---------------------------- #


if __name__ == "__main__":
    fire.Fire(train_grpo)
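# Hypothetical launch example (script name and config path are assumptions,
# not taken from the repo):
#
#   torchrun --nproc-per-node 8 train_grpo.py configs/grpo_qwen.py \
#       --work_dir work_dirs/grpo
#
# `fire.Fire(train_grpo)` maps the first positional CLI argument to `cfg_path`
# and forwards extra `--key value` flags into `kwargs`, which override config
# fields via `args.update(kwargs)`.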