mirror of https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00

extract answer from last answer tag

This commit is contained in:
parent cc0312e446
commit c196d622e0

5 changed files with 31 additions and 22 deletions
examples/OpenRLHF/.gitignore (vendored): 1 addition

@@ -1 +1,2 @@
 checkpoint/
+wandb/

@@ -16,6 +16,8 @@ from openrlhf.models.utils import compute_approx_kl, masked_mean
 from openrlhf.trainer import PPOTrainer
 from openrlhf.trainer.ppo_utils.experience_maker import Experience, NaiveExperienceMaker, Samples
 from openrlhf.utils import blending_datasets, get_strategy, get_tokenizer
 from openrlhf.utils.logging_utils import init_logger

 from torch.utils.data import Dataset
 from transformers.trainer import get_scheduler
@@ -23,6 +25,8 @@ import reasoning_gym
 from reasoning_gym.dataset import ProceduralDataset
 from reasoning_gym.utils import extract_answer

 logger = init_logger(__name__)

+DEBUG = False
+
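The extract_answer helper imported above is not itself shown in this view; per the commit title, answers are now extracted from the last <answer> tag, so a completion that drafts several tagged answers is scored on its final one. A minimal sketch of that behavior, assuming a simple regex-based implementation (not the library's actual code):

    import re
    from typing import Optional

    def extract_answer(completion: str, tag_name: str = "answer") -> Optional[str]:
        # Collect every <answer>...</answer> span and keep only the last one,
        # so intermediate draft answers in the completion are ignored.
        matches = re.findall(
            rf"<{tag_name}>(.*?)</{tag_name}>", completion, flags=re.DOTALL
        )
        return matches[-1].strip() if matches else None
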
@@ -176,12 +180,13 @@ class AlgorithmicRewardExperienceMaker(NaiveExperienceMaker):
         value = None

         # determine outcome reward
-        completions = self.tokenizer.batch_decode(sequences.cpu(), skip_special_tokens=True)
+        completions = sequences[:, -action_mask.size(1):].cpu()
+        completions = self.tokenizer.batch_decode(completions, skip_special_tokens=True)
         returns = [
             self.dataset.score_answer(extract_answer(c, tag_name="answer"), entry=m)
             for c, m in zip(completions, metadata)
         ]
-        r = torch.tensor(returns, device=sequences.device)
+        r = torch.tensor(returns, dtype=torch.float32, device=sequences.device)

         kl = compute_approx_kl(
             action_log_probs,
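Why the two-step decode above: sequences holds prompt plus response tokens, so decoding all of sequences could let answer tags appearing in the prompt text leak into scoring. Slicing off the final action_mask.size(1) token columns first keeps only the generated tokens. A toy illustration with hypothetical shapes:

    import torch

    # Hypothetical batch of 2: prompt length 5, response length 3.
    sequences = torch.arange(16).reshape(2, 8)        # prompt + response token ids
    action_mask = torch.ones(2, 3, dtype=torch.bool)  # one entry per generated token

    completions = sequences[:, -action_mask.size(1):]  # response ids only
    print(completions.shape)  # torch.Size([2, 3])
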
@@ -197,6 +202,9 @@ class AlgorithmicRewardExperienceMaker(NaiveExperienceMaker):
             "total_length": samples.total_length,
             "num_actions": num_actions,
         }
+
+        logger.info(f"info={info}")
+
         # reset model state
         self.actor.train()
         if self.critic is not None:
@@ -667,7 +675,7 @@ if __name__ == "__main__":
     parser.add_argument("--value_head_prefix", type=str, default="score")

     # Custom dataset
-    parser.add_argument("--prompt_data", type=str, default=None, help="HF dataset name or path")
+    parser.add_argument("--prompt_data", type=str, default="chain_sum", help="HF dataset name or path")
     parser.add_argument(
         "--prompt_data_probs",
         type=str,
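Note that leg_counting and the new default chain_sum name reasoning-gym procedural datasets rather than Hugging Face datasets, so the "HF dataset name or path" help string no longer matches the default.
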
@@ -708,15 +716,10 @@ if __name__ == "__main__":

     args = parser.parse_args()

-    # if args.advantage_estimator not in ["gae"]:
-    #     args.critic_pretrain = None
-    # elif args.critic_pretrain is None:
-    #     if not args.remote_rm_url:
-    #         args.critic_pretrain = args.reward_pretrain
-    #     else:
-    #         args.critic_pretrain = args.pretrain
-
-    args.critic_pretrain = args.pretrain ## temp
+    if args.advantage_estimator not in ["gae"]:
+        args.critic_pretrain = None
+    elif args.critic_pretrain is None:
+        args.critic_pretrain = args.pretrain ## temp

     if args.advantage_estimator == "rloo":
         assert args.n_samples_per_prompt > 1, "RLOO requires n_samples_per_prompt > 1"
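For context, the re-enabled logic only keeps a critic for GAE; estimators such as RLOO compute advantages without a value model, and a missing critic checkpoint falls back to the actor checkpoint (the "## temp" note). A self-contained restatement (hypothetical helper, not in the diff):

    from argparse import Namespace

    def resolve_critic_pretrain(args: Namespace) -> Namespace:
        # Only GAE trains a value model; other estimators need no critic.
        if args.advantage_estimator not in ["gae"]:
            args.critic_pretrain = None
        elif args.critic_pretrain is None:
            # Fall back to the actor checkpoint, as in the "## temp" line.
            args.critic_pretrain = args.pretrain
        return args

    args = Namespace(advantage_estimator="rloo", critic_pretrain=None, pretrain="base-model")
    print(resolve_critic_pretrain(args).critic_pretrain)  # None: RLOO needs no critic
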

@@ -17,15 +17,18 @@ args=(
     --bf16
     --actor_learning_rate 5e-7
     --init_kl_coef 0.01
-    --prompt_data leg_counting
+    --prompt_data chain_sum # leg_counting
     --input_key question
     --apply_chat_template
     --normalize_reward
     --adam_offload
     --flash_attn
     --gradient_checkpointing
-    --max_samples 1024 # 100000
+    --max_samples 100000
     --critic_learning_rate 9e-6
-    # --use_wandb {wandb_token}
 )
+# Add wandb argument only if wandb_token is set
+if [[ -n "${wandb_token}" ]]; then
+    args+=(--use_wandb "${wandb_token}")
+fi
 deepspeed ${args[@]}
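With this change, wandb logging becomes opt-in at launch time: setting wandb_token in the environment appends --use_wandb "${wandb_token}" to the args array, while leaving it unset drops the flag entirely rather than keeping a commented-out placeholder in the argument list.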