reasoning-gym/training/rewards/reward.py
Oliver Stanley 224532f12a
first inter-domain generalisation experiments (#412)
* tweak len reward

* first inter-generalisation experiment config

* update inter algorithmic config

* default to empty config

* fix typo

* change config to match experiment script

* long prompt fixes

* algorithmic training config tweaks

* imports

* update algorithmic training cfgs

* first logic composite config

* fix dset name

* tweaks

* fix syllogisms dataset

* rm temp print

* initial algebra config

* algebra cfg tweaks

* add gc

* add initial games cfg

* rename games cfg

* fix dset name

* fix sokoban metadata

* remove boxnet

* games cfg tweak
2025-04-14 21:06:40 +01:00

92 lines
3 KiB
Python

import math
import re
class RewardRegistry:
    """Name-to-callable lookup table for secondary reward functions."""

    def __init__(self):
        # Maps each registered name to the reward callable stored under it.
        self.reward_functions = {}

    def register(self, name: str):
        """Return a decorator that stores the decorated function under ``name``.

        The decorated function is handed back unchanged, so it remains
        directly callable. Registering an existing name replaces the entry.
        """

        def _store(fn):
            self.reward_functions[name] = fn
            return fn

        return _store

    def get(self, name: str):
        """Look up a reward function by name; unknown names yield ``None``."""
        try:
            return self.reward_functions[name]
        except KeyError:
            return None

    def list_functions(self):
        """Return a new list holding every registered reward function name."""
        return [registered_name for registered_name in self.reward_functions]
# Shared module-level registry; the reward functions in this module register
# themselves on it by name through the ``@reward_registry.register`` decorator.
reward_registry = RewardRegistry()
@reward_registry.register("cosine")
def cosine_scaled_reward(solution_str, scaling_factor, **kwargs):
    """Reward that varies with completion length along a half-cosine curve.

    Correct completions slide from 1.0 (empty) down to 0.95 at ``max_len``
    characters, so shorter correct answers score slightly higher. Incorrect
    completions slide from 0.0 (empty) up to 0.7 at ``max_len`` characters.
    Completions longer than ``max_len`` are scored as if they were exactly
    ``max_len`` long.

    Args:
        solution_str: The completion text; its length in characters is used.
        scaling_factor: Multiplier applied to the final reward.
        **kwargs: ``is_correct`` (default ``False``) selects which of the two
            curves is used. Other keys are ignored.

    Returns:
        The length-scaled reward multiplied by ``scaling_factor``.
    """
    min_value_wrong = 0
    max_value_wrong = 0.7
    min_value_correct = 0.95
    max_value_correct = 1.0
    max_len = 1000

    is_correct = kwargs.get("is_correct", False)

    # Clamp to the half period [0, 1]. Without the clamp the cosine wraps
    # around past max_len, e.g. a 2000-character wrong answer would score 0.0
    # and a 2000-character correct answer would regain the full 1.0.
    progress = min(len(solution_str) / max_len, 1.0)
    cosine = math.cos(progress * math.pi)

    if is_correct:
        value_at_limit = min_value_correct
        value_at_zero = max_value_correct
    else:
        # Endpoints are swapped on purpose: wrong answers earn more as they
        # get longer.
        value_at_limit = max_value_wrong
        value_at_zero = min_value_wrong

    # cosine == 1 at zero length and -1 at the limit, so the blend moves from
    # value_at_zero to value_at_limit as the completion grows.
    reward = value_at_limit + 0.5 * (value_at_zero - value_at_limit) * (1.0 + cosine)
    return reward * scaling_factor
@reward_registry.register("format")
def compute_format_reward(solution_str: str, scaling_factor: float = 0.2, **kwargs) -> float:
    """Score whether the completion is one <think> block followed by one <answer> block.

    The text must start (after optional whitespace) with a ``<think>...</think>``
    block followed by an ``<answer>...</answer>`` block. It must contain exactly
    one of each, and neither block body may contain an opening ``<think>`` or
    ``<answer>`` tag.

    Args:
        solution_str: The completion text to check.
        scaling_factor: Reward returned when the format is valid.
        **kwargs: If ``preappend_thinking_token`` is truthy, ``"<think>"`` is
            prepended to the text before checking.

    Returns:
        ``1.0 * scaling_factor`` when every check passes, otherwise ``0.0``.
    """
    text = solution_str
    if kwargs.get("preappend_thinking_token", False):
        text = "<think>" + text

    # Overall shape: think block first, answer block right after it.
    if re.match(r"\s*<think>.*?</think>\s*<answer>.*?</answer>", text, re.DOTALL) is None:
        return 0.0

    think_bodies = re.findall(r"<think>(.*?)</think>", text, re.DOTALL)
    answer_bodies = re.findall(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if not (len(think_bodies) == 1 and len(answer_bodies) == 1):
        return 0.0

    # Reject stray opening tags inside either block body.
    opening_tags = ("<think>", "<answer>")
    for body in (think_bodies[0], answer_bodies[0]):
        if any(tag in body for tag in opening_tags):
            return 0.0

    return 1.0 * scaling_factor
@reward_registry.register("length")
def length_reward(solution_str, scaling_factor, **kwargs):
    """Reward longer completions in proportion to how far the answer falls short.

    Args:
        solution_str: The completion text; its length in characters is used.
        scaling_factor: Multiplier applied to the final reward.
        **kwargs: Optional ``correctness_score`` (default 0.0), ``max_score``
            (default 1.0) and ``max_output_length`` (default 1024).

    Returns:
        ``(max_score - correctness_score)`` times the fraction of
        ``max_output_length`` used (capped at 1.0), times ``scaling_factor``.
    """
    score = kwargs.get("correctness_score", 0.0)
    best_score = kwargs.get("max_score", 1.0)
    length_limit = kwargs.get("max_output_length", 1024)

    # Fraction of the allowed length that was used, capped at 1.0.
    fraction_used = len(solution_str) / length_limit
    if fraction_used > 1.0:
        fraction_used = 1.0

    # A fully correct answer has no shortfall and therefore earns nothing here;
    # imperfect answers earn more the longer they are.
    shortfall = best_score - score
    return shortfall * fraction_used * scaling_factor