Merge branch 'main' into env/string-insertion

This commit is contained in:
Andreas Köpf 2025-02-13 13:07:29 +01:00 committed by GitHub
commit 971782308f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 956 additions and 118 deletions

View file

@ -2,6 +2,7 @@ import argparse
import asyncio import asyncio
import json import json
import os import os
import re
import time import time
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List from typing import Any, Dict, List
@ -10,6 +11,7 @@ from openai import AsyncOpenAI
from tqdm.asyncio import tqdm_asyncio from tqdm.asyncio import tqdm_asyncio
from reasoning_gym.factory import create_dataset from reasoning_gym.factory import create_dataset
from reasoning_gym.utils import SYSTEM_PROMPTS
class AsyncOpenRouterEvaluator: class AsyncOpenRouterEvaluator:
@ -25,22 +27,33 @@ class AsyncOpenRouterEvaluator:
async with self.semaphore: async with self.semaphore:
try: try:
completion = await self.client.chat.completions.create( completion = await self.client.chat.completions.create(
extra_headers=self.extra_headers, model=self.model, messages=[{"role": "user", "content": prompt}] extra_headers=self.extra_headers,
model=self.model,
messages=[
{"role": "system", "content": SYSTEM_PROMPTS["default"]},
{"role": "user", "content": prompt},
],
) )
return completion.choices[0].message.content return completion.choices[0].message.content
except Exception as e: except Exception as e:
print(f"Error calling OpenRouter API: {str(e)}") print(f"Error calling OpenRouter API: {str(e)}")
raise raise
def parse_model_response(self, response: str) -> str:
"""Gather the final answer between the <answer> and </answer> tags."""
match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
return match.group(1).strip() if match else response
async def process_single_question(self, entry: Dict, dataset) -> Dict: async def process_single_question(self, entry: Dict, dataset) -> Dict:
"""Process a single question and return the result.""" """Process a single question and return the result."""
response = await self.get_model_response(entry["question"]) response = await self.get_model_response(entry["question"])
score = dataset.score_answer(answer=response, entry=entry) answer = self.parse_model_response(response)
score = dataset.score_answer(answer=answer, entry=entry)
return { return {
"question": entry["question"], "question": entry["question"],
"expected_answer": entry["answer"], "expected_answer": entry["answer"],
"model_answer": response, "model_answer": answer,
"score": score, "score": score,
"metadata": entry["metadata"], "metadata": entry["metadata"],
} }

View file

@ -1,4 +1,5 @@
import argparse import argparse
import asyncio
import json import json
import logging import logging
import os import os
@ -6,10 +7,9 @@ from dataclasses import asdict
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List from typing import Any, Dict, List
import requests import aiohttp
from eval_config import EvalConfig from eval_config import EvalConfig
from requests.exceptions import RequestException from tenacity import AsyncRetrying, retry_if_exception_type, stop_after_attempt, wait_exponential
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
import reasoning_gym import reasoning_gym
from reasoning_gym.utils import extract_answer from reasoning_gym.utils import extract_answer
@ -30,9 +30,9 @@ class OpenRouterEvaluator:
"X-Title": os.getenv("OR_APP_NAME", "Model Evaluation"), "X-Title": os.getenv("OR_APP_NAME", "Model Evaluation"),
"Content-Type": "application/json", "Content-Type": "application/json",
} }
self.semaphore = asyncio.Semaphore(10) # Control concurrency
def save_results(self, results: List[Dict[str, Any]], dataset, dataset_name) -> Dict[str, Any]: def save_results(self, results: List[Dict[str, Any]], dataset, dataset_name) -> Dict[str, Any]:
file_name = f"{self.output_dir}/{dataset_name}.json" file_name = f"{self.output_dir}/{dataset_name}.json"
total_score = sum(r["score"] for r in results) total_score = sum(r["score"] for r in results)
@ -45,7 +45,7 @@ class OpenRouterEvaluator:
"total_examples": len(results), "total_examples": len(results),
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
"config": asdict(dataset.config), "config": asdict(dataset.config),
"results": results, # save results to allow for performance recalculation "results": results,
} }
with open(file_name, "w") as f: with open(file_name, "w") as f:
@ -53,87 +53,93 @@ class OpenRouterEvaluator:
return metrics return metrics
def prepare_messages(self, prompt: str) -> List[Dict[str, str]]: def prepare_messages(self, prompt: str) -> List[Dict[str, str]]:
messages = [ return {
{"role": self.config.developer_role, "content": self.config.developer_prompt}, "model": self.model,
{"role": "user", "content": prompt}, "messages": [
] {"role": self.config.developer_role, "content": self.config.developer_prompt},
{"role": "user", "content": prompt},
],
"provider": {"order": ["Nebius"], "allow_fallbacks": False},
}
async def get_model_response(self, session: aiohttp.ClientSession, prompt: str) -> str:
payload = { payload = {
"model": self.model, "model": self.model,
"messages": messages, "messages": [
"provider": {"order": ["Nebius"], "allow_fallbacks": False}, {"role": self.config.developer_role, "content": self.config.developer_prompt},
} # make sure only one provider is used {"role": "user", "content": prompt},
],
}
return payload async for attempt in AsyncRetrying(
stop=stop_after_attempt(20),
wait=wait_exponential(multiplier=1, min=1, max=60),
retry=retry_if_exception_type(
(aiohttp.ClientError, asyncio.TimeoutError, json.JSONDecodeError, ValueError)
),
):
with attempt:
async with session.post(self.base_url, json=payload) as response:
data = await response.json()
@retry( if not data:
retry=retry_if_exception_type(RequestException), raise ValueError("Empty response")
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=60),
)
def get_model_response(self, prompt: str) -> str:
"""Get response from the model via OpenRouter API."""
payload = self.prepare_messages(prompt) if not data.get("choices"):
try: raise ValueError("Missing choices in response")
response = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30)
response.raise_for_status()
except requests.exceptions.RequestException as e:
raise RequestException(
f"API request failed: {str(e)}", {"endpoint": self.base_url, "model": self.model}
) from e
return response.json()["choices"][0]["message"]["content"]
def evaluate_datasets(self) -> List[Dict[str, Any]]: return data["choices"][0]["message"]["content"]
"""Evaluate model on multiple datasets with their respective configurations."""
raise Exception("Failed to get valid response after retries")
async def process_entry(self, session: aiohttp.ClientSession, dataset: Any, entry: Any) -> Dict[str, Any]:
"""Process a single entry with concurrency control."""
async with self.semaphore:
response = await self.get_model_response(session, entry["question"])
model_answer = extract_answer(response)
score = dataset.score_answer(answer=model_answer, entry=entry)
print(f"Question: {entry['question']}")
return {
"question": entry["question"],
"expected_answer": str(entry["answer"]),
"model_answer": model_answer,
"score": score,
"metadata": str(entry["metadata"]),
}
async def evaluate_dataset(self, session: aiohttp.ClientSession, dataset_name: str) -> Dict[str, Any]:
"""Evaluate a single dataset asynchronously."""
self.logger.info(f"\nEvaluating dataset: {dataset_name}")
dataset = reasoning_gym.create_dataset(
dataset_name, size=self.config.dataset_size, seed=self.config.dataset_seed
)
tasks = [self.process_entry(session, dataset, entry) for entry in dataset]
results = await asyncio.gather(*tasks)
return self.save_results(results, dataset, dataset_name)
async def evaluate_datasets(self) -> List[Dict[str, Any]]:
"""Main async evaluation entry point."""
all_results = [] all_results = []
async with aiohttp.ClientSession(headers=self.headers) as session:
for dataset_name in self.config.datasets: return await asyncio.gather(*(self.evaluate_dataset(session, name) for name in self.config.datasets))
self.logger.info(f"\nEvaluating dataset: {dataset_name}")
# Create dataset with its specific configuration
dataset = reasoning_gym.create_dataset(
dataset_name, size=self.config.dataset_size, seed=self.config.dataset_seed
)
results = []
for i, entry in enumerate(dataset):
print(f"On example {i+1} of {len(dataset)}")
response = self.get_model_response(entry["question"])
model_answer = extract_answer(response)
score = dataset.score_answer(answer=model_answer, entry=entry)
result = {
"question": entry["question"],
"expected_answer": str(entry["answer"]),
"model_answer": model_answer,
"score": score,
"metadata": str(entry["metadata"]),
}
results.append(result)
metrics = self.save_results(results, dataset, dataset_name)
all_results.append({"metrics": metrics, "results": results})
return all_results
def main(): async def async_main():
parser = argparse.ArgumentParser(description="Evaluate models on reasoning datasets") parser = argparse.ArgumentParser(description="Evaluate models on reasoning datasets")
parser.add_argument("--yaml", required=True, help="Path to YAML configuration file") parser.add_argument("--yaml", required=True, help="Path to YAML configuration file")
args = parser.parse_args() args = parser.parse_args()
config = EvalConfig.from_yaml(args.yaml) config = EvalConfig.from_yaml(args.yaml)
evaluator = OpenRouterEvaluator(model=config.model, config=config)
results = await evaluator.evaluate_datasets()
output_dir = f"{config.eval_dir}/{config.category}" output_dir = f"{config.eval_dir}/{config.category}"
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
evaluator = OpenRouterEvaluator(model=config.model, config=config)
all_results = evaluator.evaluate_datasets()
with open(f"{output_dir}/summary.json", "w") as f: with open(f"{output_dir}/summary.json", "w") as f:
json.dump(all_results, f, indent=2) json.dump(results, f, indent=2)
if __name__ == "__main__": if __name__ == "__main__":
main() asyncio.run(async_main())

View file

@ -1,9 +1,8 @@
model: deepseek/deepseek-r1 model: deepseek/deepseek-r1
category: algorithmic category: algorithmic
datasets: datasets:
- base_conversion
- binary_matrix - binary_matrix
- caesar _cipher - caesar_cipher
- group_anagrams - group_anagrams
- isomorphic_strings - isomorphic_strings
- letter_counting - letter_counting

View file

@ -0,0 +1,61 @@
[
{
"dataset_name": "letter_counting",
"model": "openai/o1",
"average_score": 0.99,
"total_examples": 50,
"timestamp": "2025-02-12T10:26:39.897674",
"config": {
"min_words": 5,
"max_words": 15,
"size": 50,
"seed": 42
}
},
{
"dataset_name": "propositional_logic",
"model": "openai/o1",
"average_score": 0.010000000000000004,
"total_examples": 50,
"timestamp": "2025-02-12T10:27:45.054740",
"config": {
"size": 50,
"seed": 42
}
},
{
"dataset_name": "leg_counting",
"model": "openai/o1",
"average_score": 0.802,
"total_examples": 50,
"timestamp": "2025-02-12T10:28:06.199253",
"config": {
"min_animals": 3,
"max_animals": 8,
"size": 50,
"seed": 42
}
},
{
"dataset_name": "group_anagrams",
"model": "openai/o1",
"average_score": 0.94,
"total_examples": 50,
"timestamp": "2025-02-12T10:30:02.084562",
"config": {
"size": 50,
"seed": 42
}
},
{
"dataset_name": "spell_backward",
"model": "openai/o1",
"average_score": 0.9802000000000001,
"total_examples": 50,
"timestamp": "2025-02-12T10:30:17.839014",
"config": {
"size": 50,
"seed": 42
}
}
]

View file

@ -70,28 +70,32 @@ class PolynomialEquationsDataset(ProceduralDataset):
Returns: Returns:
A dict with: A dict with:
- question: str (e.g. "Solve the polynomial equation: 2*x^2 - 3*x + 1 = 0") - question: str (e.g. "Solve the polynomial equation: 2*x**2 - 3*x + 1 = 0")
- answer: str (the sorted list of real solutions, e.g. "[0.5, 1.0]") - answer: str (the sorted list of real solutions, e.g. "0.5, 1.0")
- metadata: dict with details (polynomial_expr, degree, etc.) - metadata: dict with details (polynomial_expr, degree, etc.)
""" """
rng = random.Random(self.seed + idx) rng = random.Random(self.seed + idx)
for _ in range(8):
# Get variable and generate polynomial equation in standard form
variable = self._get_variable(rng)
degree = rng.randint(self.config.min_degree, self.config.max_degree)
polynomial_expr = self._generate_polynomial_expr(rng, variable, degree)
polynomial_expanded = expand(polynomial_expr)
# Get variable and generate polynomial equation in standard form # Solve the polynomial = 0
variable = self._get_variable(rng) # We filter real solutions only
degree = rng.randint(self.config.min_degree, self.config.max_degree) solutions = solve(Eq(polynomial_expanded, 0), variable, dict=False)
polynomial_expr = self._generate_polynomial_expr(rng, variable, degree) real_solutions = []
polynomial_expanded = expand(polynomial_expr) for sol in solutions:
if sol.is_real:
# Evaluate symbolic solution to a floating approximation
real_solutions.append(float(sol.evalf()))
# Solve the polynomial = 0 if len(real_solutions) > 0:
# We filter real solutions only real_solutions.sort()
solutions = solve(Eq(polynomial_expanded, 0), variable, dict=False) break
real_solutions = []
for sol in solutions: answer_str = ", ".join(str(x) for x in real_solutions)
if sol.is_real:
# Evaluate symbolic solution to a floating approximation
real_solutions.append(float(sol.evalf()))
real_solutions.sort()
answer_str = str(real_solutions)
return { return {
"question": rng.choice(self._prompt_templates).format( "question": rng.choice(self._prompt_templates).format(
@ -109,7 +113,7 @@ class PolynomialEquationsDataset(ProceduralDataset):
def _get_variable(self, rng: random.Random) -> str: def _get_variable(self, rng: random.Random) -> str:
"""Get a random lowercase variable name""" """Get a random lowercase variable name"""
return rng.choice(string.ascii_lowercase) return rng.choice("abcdefghklmnopqrstuvwxyz") # remove ij to avoid confusion with complex numbers
def _generate_polynomial_expr(self, rng: random.Random, variable: Symbol, degree: int): def _generate_polynomial_expr(self, rng: random.Random, variable: Symbol, degree: int):
""" """
@ -202,6 +206,9 @@ class PolynomialEquationsDataset(ProceduralDataset):
oracle_solutions = self._parse_score_to_list(entry["answer"]) # Parse oracle solutions oracle_solutions = self._parse_score_to_list(entry["answer"]) # Parse oracle solutions
predicted_solutions = self._parse_score_to_list(answer) # Parse predicted solutions predicted_solutions = self._parse_score_to_list(answer) # Parse predicted solutions
if len(oracle_solutions) == 0 and len(predicted_solutions) == 0:
return 1.0
total_reward = 0.0 total_reward = 0.0
matched_solutions = 0 matched_solutions = 0
extra_solutions = 0 extra_solutions = 0

View file

@ -19,12 +19,14 @@ from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixDataset
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
from .number_sorting import NumberSortingConfig, NumberSortingDataset from .number_sorting import NumberSortingConfig, NumberSortingDataset
from .palindrome_generation import PalindromeConfig, PalindromeDataset from .palindrome_generation import PalindromeConfig, PalindromeDataset
from .pool_matrix import PoolMatrixConfig, PoolMatrixDataset
from .ransom_note import RansomNoteConfig, RansomNoteDataset from .ransom_note import RansomNoteConfig, RansomNoteDataset
from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset
from .string_insertion import StringInsertionConfig, StringInsertionDataset from .string_insertion import StringInsertionConfig, StringInsertionDataset
from .string_manipulation import StringManipulationConfig, StringManipulationDataset
from .word_ladder import WordLadderConfig, WordLadderDataset from .word_ladder import WordLadderConfig, WordLadderDataset
from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset
from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset
@ -69,10 +71,14 @@ __all__ = [
"ManipulateMatrixDataset", "ManipulateMatrixDataset",
"BinaryMatrixConfig", "BinaryMatrixConfig",
"BinaryMatrixDataset", "BinaryMatrixDataset",
"PoolMatrixConfig",
"PoolMatrixDataset",
"ABConfig", "ABConfig",
"ABDataset", "ABDataset",
"CountPrimesConfig", "CountPrimesConfig",
"CountPrimesDataset", "CountPrimesDataset",
"StringInsertionConfig", "StringInsertionConfig",
"StringInsertionDataset", "StringInsertionDataset",
"StringManipulationConfig",
"StringManipulationDataset",
] ]

View file

@ -0,0 +1,142 @@
"""Perform average / max pooling on a matrix"""
from copy import deepcopy
from dataclasses import dataclass
from random import Random
from typing import Dict, Optional
import numpy as np
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Your job is to perform max/average pooling on the given matrix.
The stride is equal to the kernel size, meaning there is no overlap between the pooling regions.
Example 1:
- Input: Perform max pooling on the following matrix with a kernel size of 2:
1 2 3 4
5 6 7 8
9 10 11 12
13 14 15 16
- Output:
6 8
14 16
Example 2:
- Input: Perform average pooling on the following matrix with a kernel size of 2:
1 2 3 4
5 6 7 8
9 10 11 12
13 14 15 16
- Output:
3.5 5.5
11.5 13.5
Perform {pool_type} pooling on the following matrix with a kernel size of {pool_size}:
{matrix}
"""
@dataclass
class PoolMatrixConfig:
"""Configuration for Pool Matrix dataset generation"""
min_rows: int = 2 # Minimum rows of the matrix
min_cols: int = 2 # Minimum columns of the matrix
max_rows: int = 10 # Maximum rows of the matrix
max_cols: int = 10 # Maximum columns of the matrix
max_pool_size: int = 3 # Maximum pooling size
size: int = 500 # Virtual dataset size
seed: Optional[int] = None
def validate(self):
"""Validate configuration parameters"""
assert 2 <= self.min_rows, "min_rows must be at least 2"
assert 2 <= self.min_cols, "min_cols must be at least 2"
assert self.min_rows <= self.max_rows, "max_rows must be at least min_rows"
assert self.min_cols <= self.max_cols, "max_cols must be at least min_cols"
assert 1 <= self.max_pool_size, "max_pool_size must be at least 1"
class PoolMatrixDataset(ProceduralDataset):
"""Generates Pool Matrix exercises with configurable difficulty"""
def __init__(self, config: PoolMatrixConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
def _get_matrix(self, rng: Random) -> np.ndarray:
"""Generate a random matrix"""
rows = rng.randint(self.config.min_rows, self.config.max_rows)
cols = rng.randint(self.config.min_rows, self.config.max_cols)
return np.random.randint(0, 10, (rows, cols))
def _matrix_to_str(self, matrix: np.ndarray) -> str:
"""Get a string representation of the matrix"""
return "\n".join(" ".join(str(round(x, 2)) for x in row) for row in matrix)
def _max_pool(self, matrix: np.ndarray, pool_size: int) -> np.ndarray:
"""Perform max pooling on the matrix"""
rows, cols = matrix.shape
return np.array(
[
[np.max(matrix[i : i + pool_size, j : j + pool_size]) for j in range(0, cols, pool_size)]
for i in range(0, rows, pool_size)
]
)
def _average_pool(self, matrix: np.ndarray, pool_size: int) -> np.ndarray:
"""Perform average pooling on the matrix"""
rows, cols = matrix.shape
return np.array(
[
[np.mean(matrix[i : i + pool_size, j : j + pool_size]) for j in range(0, cols, pool_size)]
for i in range(0, rows, pool_size)
]
)
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
"""Score the answer based on the metadata"""
reward = 0.0
try:
if answer is not None:
oracle_answer = np.array(entry["answer"])
answer = np.array(answer)
if oracle_answer.shape == answer.shape and np.allclose(oracle_answer, answer):
reward = 1.0
if oracle_answer.shape == answer.shape:
reward = 0.1
else:
reward = 0.01
except:
print("Error in scoring answer for Pool Matrix")
return reward
def __getitem__(self, idx: int) -> dict:
"""Generate a single Pool Matrix question"""
rng = Random(self.seed + idx)
np.random.seed(self.seed + idx)
matrix = self._get_matrix(rng)
matrix_str = self._matrix_to_str(matrix)
pool_size = rng.randint(1, self.config.max_pool_size)
pool_type = rng.choice(["average", "max"])
answer = self._average_pool(matrix, pool_size) if pool_type == "average" else self._max_pool(matrix, pool_size)
answer_str = self._matrix_to_str(answer)
return {
"question": QUESTION_TEMPLATE.format(matrix=matrix_str, pool_type=pool_type, pool_size=pool_size),
"answer": answer_str,
"metadata": {
"matrix": matrix.tolist(),
"pool_type": pool_type,
"pool_size": pool_size,
"solution": answer.tolist(),
},
}
register_dataset("pool_matrix", PoolMatrixDataset, PoolMatrixConfig)

View file

@ -0,0 +1,199 @@
"""Manipulate a string according to a set of rules
https://github.com/yongchao98/CodeSteer-v1.0/blob/main/create_dataset/create_dataset_string_deletion_and_modification.py
"""
from dataclasses import dataclass
from random import Random
from typing import Optional
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Your job is to repeatedly transform a string according to a set of rules until no further transformations can be performed, or a state is repeated.
Evaluate the following rules in order, and apply the first applicable rule to the string:
{rules}
Once you have applied a rule, repeat the process with the new string until no further transformations can be performed (i.e. the string doesn't change), or a state is repeated.
If a state is repeated, the process is terminated, and the repeated state is discarded (i.e. is not considered as the final answer) and the state before the repeated state is considered as the final answer.
Example:
- Input:
- String: abbac
- Rules:
1. If the string prefix is 'ab', replace it with 'ca'.
2. If the string prefix is 'ca', replace it with 'bb' and append 'c' to the end.
3. If the string ends with 'aa', replace it with 'cc'.
- Output: bbbacc
- Explanation:
- In the first iteration, rule 1 is applied to the string abbac, resulting in cabac
- In the second interation, rule 1 doesn't apply, but rule 2 is applied to the string cabac, resulting in bbbacc
- In the third iteration, none of the rules (1, 2, 3) apply, so the process is terminated, and the final answer is bbbacc
Transform the following string according to the above list of rules:
{string}
"""
@dataclass
class StringManipulationConfig:
"""Configuration for String Insertion dataset generation"""
min_string_length: int = 5 # Minimum string length
max_string_length: int = 20 # Maximum string length
min_num_rules: int = 3 # Minimum number of rules/transforms
max_num_rules: int = 8 # Maximum number of rules/transforms
size: int = 500 # Virtual dataset size
seed: Optional[int] = None
def validate(self):
"""Validate configuration parameters"""
assert 5 <= self.min_string_length, "Minimum string length should be at least 5"
assert self.min_string_length <= self.max_string_length, "Minimum string length should be less than maximum"
assert 3 <= self.min_num_rules, "Minimum number of rules should be at least 3"
assert self.min_num_rules <= self.max_num_rules, "Minimum number of rules should be less than maximum"
class StringManipulationDataset(ProceduralDataset):
"""Generates String Insertion exercises with configurable difficulty"""
def __init__(self, config: StringManipulationConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
self.vocabulary = ["a", "b", "c"]
self.rules = [
(
"If the string prefix is 'ab', replace it with 'ca'.",
lambda s: ("ca" + s[2:], 1) if s.startswith("ab") else (s, 0),
),
(
"If the string suffix is 'ac', replace it with 'cb'.",
lambda s: (s[:-2] + "cb", 2) if s.endswith("ac") else (s, 0),
),
(
"If the string prefix is 'bc', delete the first two characters and append 'aa' to the end.",
lambda s: (s[2:] + "aa", 3) if s.startswith("bc") else (s, 0),
),
(
"If the string suffix is 'bb', delete the last two characters.",
lambda s: (s[:-2], 4) if s.endswith("bb") else (s, 0),
),
(
"If the string prefix is 'cb', replace it with 'aa' and delete the last character.",
lambda s: ("aa" + s[2:-1], 5) if s.startswith("cb") and len(s) > 1 else (s, 0),
),
(
"If the string prefix is 'ca', replace it with 'bb' and append 'c' to the end.",
lambda s: ("bb" + s[2:] + "c", 6) if s.startswith("ca") else (s, 0),
),
(
"If the string suffix is 'cc', replace it with 'b' and prepend 'a' to the start.",
lambda s: ("a" + s[:-2] + "b", 7) if s.endswith("cc") else (s, 0),
),
(
"If the string prefix is 'aa', remove the first character.",
lambda s: (s[1:], 8) if s.startswith("aa") else (s, 0),
),
(
"If the string contains 'abc', replace the first occurrence with 'cab'.",
lambda s: (s.replace("abc", "cab", 1), 9) if "abc" in s else (s, 0),
),
(
"If the string contains 'bca', delete the first occurrence entirely.",
lambda s: (s.replace("bca", "", 1), 10) if "bca" in s else (s, 0),
),
(
"If the string ends with 'ba', replace it with 'ab'.",
lambda s: (s[:-2] + "ab", 11) if s.endswith("ba") else (s, 0),
),
(
"If the string starts with 'cc', remove the first two characters.",
lambda s: (s[2:], 12) if s.startswith("cc") else (s, 0),
),
(
"If the string contains 'acb', replace the first occurrence with its reverse ('bca').",
lambda s: (s.replace("acb", "bca", 1), 13) if "acb" in s else (s, 0),
),
(
"If the string ends with 'ca', remove the last character.",
lambda s: (s[:-1], 14) if s.endswith("ca") and len(s) > 0 else (s, 0),
),
(
"If the string starts with 'bb', remove the second character.",
lambda s: (s[0] + s[2:], 15) if s.startswith("bb") and len(s) >= 2 else (s, 0),
),
(
"If the string ends with 'aa', replace it with 'cc'.",
lambda s: (s[:-2] + "cc", 16) if s.endswith("aa") else (s, 0),
),
(
"If the string contains 'ca' (not at the start), remove the first occurrence found after the first character.",
lambda s: (s[:idx] + s[idx + 2 :], 17) if (idx := s.find("ca", 1)) != -1 else (s, 0),
),
(
"If the string contains an even number of 'b's (and at least one 'b'), append 'ab' at the end.",
lambda s: (s + "ab", 18) if (s.count("b") > 0 and s.count("b") % 2 == 0) else (s, 0),
),
(
"If the string length is greater than 15, remove the middle character.",
lambda s: (s[: len(s) // 2] + s[len(s) // 2 + 1 :], 19) if len(s) > 15 else (s, 0),
),
(
"If the string starts with 'ac', replace the first two characters with 'zz'.",
lambda s: ("zz" + s[2:], 20) if s.startswith("ac") else (s, 0),
),
]
def _apply_rule(self, string: str, selected_rules: list[tuple[str, callable]]) -> tuple[str, int]:
"""
Apply the first applicable rule from the list of selected rules.
Returns a tuple containing the modified string and the rule index (1-based) that was applied.
If no rule is applicable, returns (s, 0).
"""
for _, rule_fn in selected_rules:
new_string, op_idx = rule_fn(string)
if op_idx != 0:
return new_string, op_idx
return string, 0
def _get_all_transforms(self, string: str, selected_rules: list[tuple[str, callable]]) -> list[str]:
"""
Repeatedly apply transformation rules to a string until no further transformations can be performed,
or a state is repeated. If a state is repeated, the process is terminated, and the state is not added to the list.
Returns a list of string states from the initial string to the final state (i.e. the desired answer).
"""
states = [string]
while True:
new_string, op_idx = self._apply_rule(states[-1], selected_rules)
if op_idx == 0 or new_string in states:
break
states.append(new_string)
return states
def __getitem__(self, idx: int) -> dict:
"""Generate a single String Insertion question"""
rng = Random(self.seed + idx)
string_length = rng.randint(self.config.min_string_length, self.config.max_string_length)
string = "".join(rng.choice(self.vocabulary) for _ in range(string_length))
num_rules = rng.randint(self.config.min_num_rules, self.config.max_num_rules)
selected_rules = rng.sample(self.rules, num_rules)
rules_str = "\n".join(f"{i+1}. {rule}" for i, (rule, _) in enumerate(selected_rules))
states = self._get_all_transforms(string, selected_rules)
answer = states[-1]
return {
"question": QUESTION_TEMPLATE.format(string=string, rules=rules_str),
"answer": str(answer),
"metadata": {
"string": string,
"solution": answer,
"states": states,
"selected_rules": [rule for rule, _ in selected_rules],
},
}
register_dataset("string_manipulation", StringManipulationDataset, StringManipulationConfig)

View file

@ -119,28 +119,26 @@ class ChainSumCurriculum(BaseCurriculum):
# Define attributes # Define attributes
self._define_attributes( self._define_attributes(
( RangeAttributeDefinition(
RangeAttributeDefinition( name="num_terms",
name="num_terms", levels=[2, 3, 4, 5],
levels=[2, 3, 4, 5], default_level=0, # Start with 2 terms
default_level=0, # Start with 2 terms description="Maximum number of terms in the expression",
description="Maximum number of terms in the expression", attr_type=AttributeType.APPEND,
attr_type=AttributeType.APPEND, min_value=2, # Ensure at least 2 terms
min_value=2, # Ensure at least 2 terms lower_field_name="min_terms",
lower_field_name="min_terms", upper_field_name="max_terms",
upper_field_name="max_terms", ),
), RangeAttributeDefinition(
RangeAttributeDefinition( name="num_digits",
name="num_digits", levels=[1, 2, 4, 10],
levels=[1, 2, 4, 10], default_level=0, # Start with 1-digit numbers
default_level=0, # Start with 1-digit numbers description="Number of digits in each operand",
description="Number of digits in each operand", attr_type=AttributeType.APPEND,
attr_type=AttributeType.APPEND, min_value=1, # Ensure numbers are at least 1 digit
min_value=1, # Ensure numbers are at least 1 digit lower_field_name="min_digits",
lower_field_name="min_digits", upper_field_name="max_digits",
upper_field_name="max_digits", ),
),
)
) )

View file

@ -34,7 +34,7 @@ class BaseCurriculum:
raise KeyError(f"Attribute '{self.name}.{attr_name}' does not exist") raise KeyError(f"Attribute '{self.name}.{attr_name}' does not exist")
return self._attributes[attr_name] return self._attributes[attr_name]
def _define_attributes(self, attrs: Iterable[AttributeDefinition]) -> None: def _define_attributes(self, *attrs: tuple[AttributeDefinition, ...]) -> None:
for attr in attrs: for attr in attrs:
if attr.name in self.attributes: if attr.name in self.attributes:
raise RuntimeError(f"Attribute with name {attr.name} is already defined.") raise RuntimeError(f"Attribute with name {attr.name} is already defined.")

View file

@ -4,12 +4,15 @@ from decimal import Decimal, InvalidOperation
from fractions import Fraction from fractions import Fraction
from typing import Any, Optional, Union from typing import Any, Optional, Union
# DeepSeek Zero system prompt
SYSTEM_PROMPTS = { SYSTEM_PROMPTS = {
"DeepSeekZero": """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. "DeepSeekZero": """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think>
<answer> answer here </answer> <answer> answer here </answer>
""" """,
"default": """Given a problem, your task is to answer the question by thinking step-by-step in a clear and specific manner.
Once you have thought about the reasoning process, provide the answer in the following format:
<answer> answer here </answer>
""",
} }

View file

@ -138,3 +138,12 @@ def test_polynomial_solutions_score_answer(oracle_answer, predicted_answer, expe
actual_reward = ds.score_answer(predicted_answer, {"answer": oracle_answer}) actual_reward = ds.score_answer(predicted_answer, {"answer": oracle_answer})
assert actual_reward == pytest.approx(expected_reward, rel=1e-3) # Fuzzy comparison for floats assert actual_reward == pytest.approx(expected_reward, rel=1e-3) # Fuzzy comparison for floats
def test_polynomial_perfect_score():
    """Test that scoring an item's own answer gives a perfect score"""
    dataset = PolynomialEquationsDataset(PolynomialEquationsConfig(seed=42, size=10))
    # Every generated entry must earn the full reward for its own reference answer.
    assert all(dataset.score_answer(entry["answer"], entry) == 1.0 for entry in dataset)

138
tests/test_pool_matrix.py Normal file
View file

@ -0,0 +1,138 @@
"""Tests for Pool Matrix questions generation"""
import numpy as np
import pytest
from reasoning_gym.algorithmic.pool_matrix import PoolMatrixConfig, PoolMatrixDataset
def test_pool_matrix_config_validation():
    """Test that invalid configs raise appropriate errors"""
    # Row/column bounds must be at least 2: negative, zero and one are all rejected.
    for field in ("min_rows", "min_cols", "max_rows", "max_cols"):
        for bad_value in (-1, 0, 1):
            with pytest.raises(AssertionError):
                PoolMatrixConfig(**{field: bad_value}).validate()

    # The pooling window size must be strictly positive.
    for bad_value in (-1, 0):
        with pytest.raises(AssertionError):
            PoolMatrixConfig(max_pool_size=bad_value).validate()
def test_pool_matrix_dataset_deterministic():
    """Test that dataset generates same items with same seed"""
    config = PoolMatrixConfig(seed=42, size=10)
    dataset_a = PoolMatrixDataset(config)
    dataset_b = PoolMatrixDataset(config)
    # Two datasets built from the same seeded config must be item-for-item identical.
    assert all(dataset_a[i] == dataset_b[i] for i in range(len(dataset_a)))
def test_pool_matrix_dataset_items():
    """Test basic properties of generated items"""
    config = PoolMatrixConfig(max_rows=10, max_cols=10, max_pool_size=3, size=10, seed=42)
    dataset = PoolMatrixDataset(config)

    for idx in range(len(dataset)):
        item = dataset[idx]

        # Every entry is a dict exposing the standard question/answer/metadata keys.
        assert isinstance(item, dict)
        for key in ("question", "answer", "metadata"):
            assert key in item

        metadata = item["metadata"]
        for key in ("matrix", "pool_type", "pool_size", "solution"):
            assert key in metadata

        matrix = metadata["matrix"]
        solution = metadata["solution"]

        # The matrix and its pooled solution must respect the configured bounds.
        assert len(matrix) <= config.max_rows
        assert all(len(row) <= config.max_cols for row in matrix)
        assert len(solution) <= len(matrix)
        assert len(solution[0]) <= len(matrix[0])
        assert metadata["pool_size"] <= config.max_pool_size
        assert metadata["pool_type"] in ("average", "max")
def test_pool_matrix_dataset_iteration():
    """Test that iteration respects dataset size"""
    config = PoolMatrixConfig(size=5, seed=42)
    dataset = PoolMatrixDataset(config)
    first_pass = list(dataset)
    # Iterating must yield exactly `config.size` items, and repeat iteration
    # must reproduce the same sequence.
    assert len(first_pass) == config.size
    assert first_pass == list(dataset)
def test_pool_matrix_answer():
    """Test the pooling methods"""
    dataset = PoolMatrixDataset(PoolMatrixConfig(seed=42))

    # Fixtures shared by both pooling modes; a trailing row that does not fill
    # a complete 2x2 window is still pooled on its own.
    square_2x2 = np.array([[1, 2], [3, 4]])
    rect_3x4 = np.array(
        [
            [1, 2, 3, 4],
            [5, 6, 7, 8],
            [9, 10, 11, 12],
        ]
    )
    square_4x4 = np.array(
        [
            [1, 2, 3, 4],
            [5, 6, 7, 8],
            [9, 10, 11, 12],
            [13, 14, 15, 16],
        ]
    )

    # 1. Max pooling with a 2x2 window
    assert np.allclose(dataset._max_pool(square_2x2, 2), np.array([[4]]))
    assert np.allclose(dataset._max_pool(rect_3x4, 2), np.array([[6, 8], [10, 12]]))
    assert np.allclose(dataset._max_pool(square_4x4, 2), np.array([[6, 8], [14, 16]]))

    # 2. Average pooling with a 2x2 window
    assert np.allclose(dataset._average_pool(square_2x2, 2), np.array([[2.5]]))
    assert np.allclose(dataset._average_pool(rect_3x4, 2), np.array([[3.5, 5.5], [9.5, 11.5]]))
    assert np.allclose(dataset._average_pool(square_4x4, 2), np.array([[3.5, 5.5], [11.5, 13.5]]))

View file

@ -0,0 +1,257 @@
"""Tests for String Manipulation questions generation"""
import pytest
from reasoning_gym.algorithmic.string_manipulation import StringManipulationConfig, StringManipulationDataset
def test_string_manipulation_config_validation():
    """Test that invalid configs raise appropriate errors"""
    # Each kwargs dict below violates one documented constraint of the config.
    invalid_configs = [
        {"min_string_length": 4},  # minimum string length must be at least 5
        {"min_string_length": 10, "max_string_length": 7},  # max must not be below min
        {"min_num_rules": 2},  # minimum number of rules must be at least 3
        {"min_num_rules": 5, "max_num_rules": 3},  # max must not be below min
    ]
    for kwargs in invalid_configs:
        with pytest.raises(AssertionError):
            StringManipulationConfig(**kwargs).validate()
def test_string_manipulation_dataset_deterministic():
    """Test that dataset generates same items with same seed"""
    config = StringManipulationConfig(seed=42, size=10)
    dataset_a = StringManipulationDataset(config)
    dataset_b = StringManipulationDataset(config)
    # The same seeded config must reproduce every item exactly.
    assert all(dataset_a[i] == dataset_b[i] for i in range(len(dataset_a)))
def test_string_manipulation_dataset_items():
    """Test basic properties of generated items.

    Verifies the question/answer/metadata structure of each entry, the
    presence of all expected metadata keys, and that the recorded string,
    rule set and state trace respect the configuration bounds.
    """
    config = StringManipulationConfig(
        min_string_length=7, max_string_length=25, min_num_rules=5, max_num_rules=12, size=10, seed=42
    )
    dataset = StringManipulationDataset(config)
    for i in range(len(dataset)):
        item = dataset[i]
        # Check item structure
        assert isinstance(item, dict)
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item

        # Check metadata. "selected_rules" is asserted here (it was previously
        # commented out) because it is read unconditionally below; asserting it
        # turns a missing key into a clear failure instead of a KeyError.
        assert "string" in item["metadata"]
        assert "states" in item["metadata"]
        assert "selected_rules" in item["metadata"]
        assert "solution" in item["metadata"]

        string = item["metadata"]["string"]
        solution = item["metadata"]["solution"]
        states = item["metadata"]["states"]
        selected_rules = item["metadata"]["selected_rules"]

        # Verify dimensions
        assert config.min_string_length <= len(string) <= config.max_string_length
        assert config.min_num_rules <= len(selected_rules) <= config.max_num_rules
        assert len(states) >= 1
        # The solution is defined as the final state after applying all rules.
        assert solution == states[-1]
def test_string_manipulation_dataset_iteration():
    """Test that iteration respects dataset size"""
    config = StringManipulationConfig(size=5, seed=42)
    dataset = StringManipulationDataset(config)
    first_pass = list(dataset)
    # Iterating must yield exactly `config.size` items, and repeat iteration
    # must reproduce the same sequence.
    assert len(first_pass) == config.size
    assert first_pass == list(dataset)
def test_string_manipulation_answer():
"""Test the method for getting the answer"""
config = StringManipulationConfig(seed=42)
dataset = StringManipulationDataset(config)
rules = [
(
"If the string prefix is 'ab', replace it with 'ca'.",
lambda s: ("ca" + s[2:], 1) if s.startswith("ab") else (s, 0),
)
]
assert dataset._get_all_transforms("abbbab", rules)[-1] == "cabbab"
rules = [
(
"If the string suffix is 'ac', replace it with 'cb'.",
lambda s: (s[:-2] + "cb", 2) if s.endswith("ac") else (s, 0),
),
]
assert dataset._get_all_transforms("abbbac", rules)[-1] == "abbbcb"
rules = [
(
"If the string prefix is 'bc', delete the first two characters and append 'aa' to the end.",
lambda s: (s[2:] + "aa", 3) if s.startswith("bc") else (s, 0),
),
]
assert dataset._get_all_transforms("bcabbb", rules)[-1] == "abbbaa"
rules = [
(
"If the string suffix is 'bb', delete the last two characters.",
lambda s: (s[:-2], 4) if s.endswith("bb") else (s, 0),
),
]
assert dataset._get_all_transforms("abbbabb", rules)[-1] == "abbba"
rules = [
(
"If the string prefix is 'cb', replace it with 'aa' and delete the last character.",
lambda s: ("aa" + s[2:-1], 5) if s.startswith("cb") and len(s) > 1 else (s, 0),
)
]
assert dataset._get_all_transforms("cbabbb", rules)[-1] == "aaabb"
rules = [
(
"If the string prefix is 'ca', replace it with 'bb' and append 'c' to the end.",
lambda s: ("bb" + s[2:] + "c", 6) if s.startswith("ca") else (s, 0),
)
]
assert dataset._get_all_transforms("caabbb", rules)[-1] == "bbabbbc"
rules = [
(
"If the string suffix is 'cc', replace it with 'b' and prepend 'a' to the start.",
lambda s: ("a" + s[:-2] + "b", 7) if s.endswith("cc") else (s, 0),
)
]
assert dataset._get_all_transforms("abbbcc", rules)[-1] == "aabbbb"
rules = [
(
"If the string prefix is 'aa', remove the first character.",
lambda s: (s[1:], 8) if s.startswith("aa") else (s, 0),
)
]
assert dataset._get_all_transforms("aabbb", rules)[-1] == "abbb"
rules = [
(
"If the string contains 'abc', replace the first occurrence with 'cab'.",
lambda s: (s.replace("abc", "cab", 1), 9) if "abc" in s else (s, 0),
)
]
assert dataset._get_all_transforms("ababcb", rules)[-1] == "cababb" # 'ababcb' -> 'abcabb' -> 'cababb'
rules = [
(
"If the string contains 'bca', delete the first occurrence entirely.",
lambda s: (s.replace("bca", "", 1), 10) if "bca" in s else (s, 0),
)
]
assert dataset._get_all_transforms("abbcab", rules)[-1] == "abb"
rules = [
(
"If the string ends with 'ba', replace it with 'ab'.",
lambda s: (s[:-2] + "ab", 11) if s.endswith("ba") else (s, 0),
)
]
assert dataset._get_all_transforms("abbbba", rules)[-1] == "abbbab"
rules = [
(
"If the string starts with 'cc', remove the first two characters.",
lambda s: (s[2:], 12) if s.startswith("cc") else (s, 0),
)
]
assert dataset._get_all_transforms("ccabbb", rules)[-1] == "abbb"
rules = [
(
"If the string contains 'acb', replace the first occurrence with its reverse ('bca').",
lambda s: (s.replace("acb", "bca", 1), 13) if "acb" in s else (s, 0),
)
]
assert dataset._get_all_transforms("abacbb", rules)[-1] == "abbcab"
rules = [
(
"If the string contains 'acb', replace the first occurrence with its reverse ('bca').",
lambda s: (s.replace("acb", "bca", 1), 13) if "acb" in s else (s, 0),
)
]
assert dataset._get_all_transforms("abacbb", rules)[-1] == "abbcab"
rules = [
(
"If the string ends with 'ca', remove the last character.",
lambda s: (s[:-1], 14) if s.endswith("ca") and len(s) > 0 else (s, 0),
)
]
assert dataset._get_all_transforms("abbbca", rules)[-1] == "abbbc"
rules = [
(
"If the string starts with 'bb', remove the second character.",
lambda s: (s[0] + s[2:], 15) if s.startswith("bb") and len(s) >= 2 else (s, 0),
)
]
assert dataset._get_all_transforms("bbabcbb", rules)[-1] == "babcbb"
rules = [
(
"If the string ends with 'aa', replace it with 'cc'.",
lambda s: (s[:-2] + "cc", 16) if s.endswith("aa") else (s, 0),
)
]
assert dataset._get_all_transforms("abccbaa", rules)[-1] == "abccbcc"
rules = [
(
"If the string contains 'ca' (not at the start), remove the first occurrence found after the first character.",
lambda s: (s[:idx] + s[idx + 2 :], 17) if (idx := s.find("ca", 1)) != -1 else (s, 0),
)
]
assert dataset._get_all_transforms("abacab", rules)[-1] == "abab"
assert dataset._get_all_transforms("caabab", rules)[-1] == "caabab"
rules = [
(
"If the string contains an even number of 'b's (and at least one 'b'), append 'ab' at the end.",
lambda s: (s + "ab", 18) if (s.count("b") > 0 and s.count("b") % 2 == 0) else (s, 0),
)
]
assert dataset._get_all_transforms("abab", rules)[-1] == "ababab"
assert dataset._get_all_transforms("abbab", rules)[-1] == "abbab"
rules = [
(
"If the string length is greater than 15, remove the middle character.",
lambda s: (s[: len(s) // 2] + s[len(s) // 2 + 1 :], 19) if len(s) > 15 else (s, 0),
)
]
assert (
dataset._get_all_transforms("bccbcbbbcbbbbcccc", rules)[-1] == "bccbcbbbbbbcccc"
) # bccbcbbbcbbbbcccc -> "bccbcbbbbbbbcccc" -> "bccbcbbbbbbcccc"
rules = [
(
"If the string starts with 'ac', replace the first two characters with 'zz'.",
lambda s: ("zz" + s[2:], 20) if s.startswith("ac") else (s, 0),
)
]
assert dataset._get_all_transforms("acab", rules)[-1] == "zzab"