mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
* tweak len reward * first inter-generalisation experiment config * update inter algorithmic config * default to empty config * fix typo * change config to match experiment script * long prompt fixes * algorithmic training config tweaks * imports * update algorithmic training cfgs * first logic composite config * fix dset name * tweaks * fix syllogisms dataset * rm temp print * initial algebra config * algebra cfg tweaks * add gc * add initial games cfg * rename games cfg * fix dset name * fix sokoban metadata * remove boxnet * games cfg tweak
92 lines
3 KiB
Python
92 lines
3 KiB
Python
import math
|
|
import re
|
|
|
|
|
|
class RewardRegistry:
    """Name-to-callable lookup table for secondary reward functions."""

    def __init__(self):
        # Maps the registered name to the reward callable.
        self.reward_functions = {}

    def register(self, name: str):
        """Return a decorator that stores the decorated function under ``name``.

        The decorated function is returned unchanged, so it stays directly
        callable. Registering an already-used name replaces the earlier entry.
        """

        def _store(fn):
            self.reward_functions[name] = fn
            return fn

        return _store

    def get(self, name: str):
        """Return the function registered under ``name``, or ``None`` if absent."""
        return self.reward_functions.get(name, None)

    def list_functions(self):
        """Return the registered names as a list, in registration order."""
        return list(self.reward_functions)
|
|
|
|
|
|
# Shared module-level registry; the reward functions defined below register
# themselves on it by name via ``@reward_registry.register(...)``.
reward_registry = RewardRegistry()
|
|
|
|
|
|
@reward_registry.register("cosine")
|
|
def cosine_scaled_reward(solution_str, scaling_factor, **kwargs):
    """Scale a reward by completion length along a half-cosine schedule.

    The reward moves smoothly between two bounds as ``len(solution_str)``
    grows from 0 to ``max_len`` (1000) characters:

    * correct answers go from 1.0 (empty) down to 0.95 (at the cap), so
      shorter correct completions score slightly higher;
    * incorrect answers go from 0.0 (empty) up to 0.7 (at the cap), so
      longer incorrect completions are penalised less.

    Completions longer than ``max_len`` receive the same reward as one of
    exactly ``max_len`` characters.

    Args:
        solution_str: The generated completion; only its length is used.
        scaling_factor: Multiplier applied to the final reward.
        **kwargs: ``is_correct`` (default ``False``) selects which pair of
            bounds applies. Any other keys are ignored.

    Returns:
        The length-scaled reward multiplied by ``scaling_factor``.
    """
    min_value_wrong = 0
    max_value_wrong = 0.7
    min_value_correct = 0.95
    max_value_correct = 1.0
    max_len = 1000

    is_correct = kwargs.get("is_correct", False)

    # Clamp to 1.0: cos() is periodic, so without the cap a completion longer
    # than max_len would swing back toward the short-completion reward
    # (a 2 * max_len completion would score the same as an empty one).
    progress = min(len(solution_str) / max_len, 1.0)
    cosine = math.cos(progress * math.pi)

    if is_correct:
        min_value = min_value_correct
        max_value = max_value_correct
    else:
        # Bounds are deliberately swapped for incorrect answers so that the
        # reward rises with length instead of falling.
        min_value = max_value_wrong
        max_value = min_value_wrong

    reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)
    return reward * scaling_factor
|
|
|
|
|
|
@reward_registry.register("format")
|
|
def compute_format_reward(solution_str: str, scaling_factor: float = 0.2, **kwargs) -> float:
    """Reward a completion that has one well-formed <think> and <answer> block.

    The text must begin (after optional whitespace) with a ``<think>`` block
    followed by an ``<answer>`` block, with only whitespace between them. It
    must contain exactly one closed block of each kind, and neither block may
    contain an opening ``<think>`` or ``<answer>`` tag.

    Args:
        solution_str: The generated completion to inspect.
        scaling_factor: Reward returned when the format is valid.
        **kwargs: If ``preappend_thinking_token`` is truthy, ``"<think>"`` is
            prepended to ``solution_str`` before checking.

    Returns:
        ``1.0 * scaling_factor`` for a valid format, otherwise ``0.0``.
    """
    text = solution_str
    if kwargs.get("preappend_thinking_token", False):
        text = "<think>" + text

    # Overall layout check, anchored at the start of the text only.
    layout = r"\s*<think>.*?</think>\s*<answer>.*?</answer>"
    if re.match(layout, text, re.DOTALL) is None:
        return 0.0

    # With a single capture group, findall yields the inner text of each block.
    think_bodies = re.findall(r"<think>(.*?)</think>", text, re.DOTALL)
    answer_bodies = re.findall(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if not (len(think_bodies) == 1 and len(answer_bodies) == 1):
        return 0.0

    # Neither block may contain a nested opening tag of either kind.
    opening_tags = ("<think>", "<answer>")
    for body in (think_bodies[0], answer_bodies[0]):
        if any(tag in body for tag in opening_tags):
            return 0.0

    return 1.0 * scaling_factor
|
|
|
|
|
|
@reward_registry.register("length")
|
|
def length_reward(solution_str, scaling_factor, **kwargs):
    """Reward longer completions in proportion to how far the answer is from perfect.

    Args:
        solution_str: The generated completion; only its length is used.
        scaling_factor: Multiplier applied to the final reward.
        **kwargs: Optional ``correctness_score`` (default 0.0), ``max_score``
            (default 1.0) and ``max_output_length`` (default 1024).

    Returns:
        ``(max_score - correctness_score)`` times the fraction of
        ``max_output_length`` used (capped at 1.0), times ``scaling_factor``.
    """
    score = kwargs.get("correctness_score", 0.0)
    best_score = kwargs.get("max_score", 1.0)
    length_cap = kwargs.get("max_output_length", 1024)

    # Fraction of the length budget consumed, capped at 1.0.
    fraction = len(solution_str) / length_cap
    if fraction > 1.0:
        fraction = 1.0

    # The gap to a perfect score weights the length bonus, so imperfect
    # answers are nudged toward longer output while perfect ones get nothing.
    shortfall = best_score - score
    return shortfall * fraction * scaling_factor
|