extract answer from last answer tag

Andreas Koepf 2025-01-28 16:37:19 +00:00
parent cc0312e446
commit c196d622e0
5 changed files with 31 additions and 22 deletions

View file

@@ -1 +1,2 @@
 checkpoint/
+wandb/

View file

@@ -16,6 +16,8 @@ from openrlhf.models.utils import compute_approx_kl, masked_mean
 from openrlhf.trainer import PPOTrainer
 from openrlhf.trainer.ppo_utils.experience_maker import Experience, NaiveExperienceMaker, Samples
 from openrlhf.utils import blending_datasets, get_strategy, get_tokenizer
+from openrlhf.utils.logging_utils import init_logger
 from torch.utils.data import Dataset
 from transformers.trainer import get_scheduler
@@ -23,6 +25,8 @@ import reasoning_gym
 from reasoning_gym.dataset import ProceduralDataset
+from reasoning_gym.utils import extract_answer
+logger = init_logger(__name__)
 
 DEBUG = False
@@ -176,12 +180,13 @@ class AlgorithmicRewardExperienceMaker(NaiveExperienceMaker):
         value = None
 
         # determine outcome reward
-        completions = self.tokenizer.batch_decode(sequences.cpu(), skip_special_tokens=True)
+        completions = sequences[:, -action_mask.size(1):].cpu()
+        completions = self.tokenizer.batch_decode(completions, skip_special_tokens=True)
         returns = [
+            self.dataset.score_answer(extract_answer(c, tag_name="answer"), entry=m)
             for c, m in zip(completions, metadata)
         ]
-        r = torch.tensor(returns, device=sequences.device)
+        r = torch.tensor(returns, dtype=torch.float32, device=sequences.device)
 
         kl = compute_approx_kl(
             action_log_probs,
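Note on this hunk: the new decode path first slices the sequences down to the last action_mask.size(1) positions (the generated action tokens), so the prompt is no longer decoded and scored along with the completion. Per the commit title, extract_answer is expected to return the contents of the last <answer> tag in the completion. A minimal sketch of that behavior, with extract_last_answer as a hypothetical stand-in for reasoning_gym's helper:

import re
from typing import Optional

def extract_last_answer(completion: str, tag_name: str = "answer") -> Optional[str]:
    # Keep the contents of the *last* <answer>...</answer> block so that
    # intermediate answers emitted during the reasoning trace are ignored.
    matches = re.findall(rf"<{tag_name}>(.*?)</{tag_name}>", completion, flags=re.DOTALL)
    return matches[-1].strip() if matches else None

print(extract_last_answer("<answer>41</answer> recheck... <answer>42</answer>"))  # -> 42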
@@ -197,6 +202,9 @@ class AlgorithmicRewardExperienceMaker(NaiveExperienceMaker):
             "total_length": samples.total_length,
             "num_actions": num_actions,
         }
+        logger.info(f"info={info}")
+
+        # reset model state
         self.actor.train()
         if self.critic is not None:
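The logged info dict sits next to the KL bookkeeping above; compute_approx_kl and masked_mean are imported from openrlhf.models.utils. A rough sketch of what such helpers typically compute (an assumption for illustration, not OpenRLHF's exact implementation):

import torch

def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Average only over positions where mask == 1 (the generated action tokens).
    return (x * mask).sum(dim=dim) / mask.sum(dim=dim).clamp(min=1)

def approx_kl(action_log_probs, ref_log_probs, action_mask):
    # Simple k1 estimator of KL(pi || pi_ref) from sampled action tokens.
    return masked_mean(action_log_probs - ref_log_probs, action_mask)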
@@ -667,7 +675,7 @@ if __name__ == "__main__":
     parser.add_argument("--value_head_prefix", type=str, default="score")
 
     # Custom dataset
-    parser.add_argument("--prompt_data", type=str, default=None, help="HF dataset name or path")
+    parser.add_argument("--prompt_data", type=str, default="chain_sum", help="HF dataset name or path")
     parser.add_argument(
         "--prompt_data_probs",
         type=str,
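The new default "chain_sum" is a reasoning_gym procedural dataset name rather than a Hugging Face dataset path. Assuming the standard reasoning_gym API, such a dataset is created and scored like this (size and seed are illustrative):

import reasoning_gym

data = reasoning_gym.create_dataset("chain_sum", size=4, seed=42)
for entry in data:
    # Each entry carries a question, a reference answer, and metadata.
    print(entry["question"], entry["answer"])
    # score_answer returns a float reward for a proposed answer against this entry.
    assert data.score_answer(answer=entry["answer"], entry=entry) == 1.0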
@@ -708,15 +716,10 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
-    # if args.advantage_estimator not in ["gae"]:
-    #     args.critic_pretrain = None
-    # elif args.critic_pretrain is None:
-    #     if not args.remote_rm_url:
-    #         args.critic_pretrain = args.reward_pretrain
-    #     else:
-    #         args.critic_pretrain = args.pretrain
-    args.critic_pretrain = args.pretrain  ## temp
+    if args.advantage_estimator not in ["gae"]:
+        args.critic_pretrain = None
+    elif args.critic_pretrain is None:
+        args.critic_pretrain = args.pretrain  ## temp
 
     if args.advantage_estimator == "rloo":
         assert args.n_samples_per_prompt > 1, "RLOO requires n_samples_per_prompt > 1"
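The RLOO assertion holds because each sample's baseline is the mean reward of the other samples drawn for the same prompt, which is undefined for a single sample. A leave-one-out sketch (illustrative, not OpenRLHF's exact implementation):

import torch

def rloo_advantages(rewards: torch.Tensor) -> torch.Tensor:
    # rewards: (num_prompts, n_samples_per_prompt). The baseline for each sample
    # is the mean reward of its sibling samples, hence n_samples_per_prompt > 1.
    n = rewards.size(1)
    baseline = (rewards.sum(dim=1, keepdim=True) - rewards) / (n - 1)
    return rewards - baseline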

View file

@@ -17,15 +17,18 @@ args=(
    --bf16
    --actor_learning_rate 5e-7
    --init_kl_coef 0.01
-   --prompt_data leg_counting
+   --prompt_data chain_sum # leg_counting
    --input_key question
    --apply_chat_template
    --normalize_reward
    --adam_offload
    --flash_attn
    --gradient_checkpointing
-   --max_samples 1024 # 100000
+   --max_samples 100000
    --critic_learning_rate 9e-6
-   # --use_wandb {wandb_token}
 )
+# Add wandb argument only if wandb_token is set
+if [[ -n "${wandb_token}" ]]; then
+    args+=(--use_wandb "${wandb_token}")
+fi
 deepspeed ${args[@]}