mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Merge branch 'main' into env/string-insertion
This commit is contained in:
commit
971782308f
14 changed files with 956 additions and 118 deletions
19
eval/eval.py
19
eval/eval.py
|
|
@ -2,6 +2,7 @@ import argparse
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
|
@ -10,6 +11,7 @@ from openai import AsyncOpenAI
|
|||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
from reasoning_gym.factory import create_dataset
|
||||
from reasoning_gym.utils import SYSTEM_PROMPTS
|
||||
|
||||
|
||||
class AsyncOpenRouterEvaluator:
|
||||
|
|
@ -25,22 +27,33 @@ class AsyncOpenRouterEvaluator:
|
|||
async with self.semaphore:
|
||||
try:
|
||||
completion = await self.client.chat.completions.create(
|
||||
extra_headers=self.extra_headers, model=self.model, messages=[{"role": "user", "content": prompt}]
|
||||
extra_headers=self.extra_headers,
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": SYSTEM_PROMPTS["default"]},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
)
|
||||
return completion.choices[0].message.content
|
||||
except Exception as e:
|
||||
print(f"Error calling OpenRouter API: {str(e)}")
|
||||
raise
|
||||
|
||||
def parse_model_response(self, response: str) -> str:
|
||||
"""Gather the final answer between the <answer> and </answer> tags."""
|
||||
match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
|
||||
return match.group(1).strip() if match else response
|
||||
|
||||
async def process_single_question(self, entry: Dict, dataset) -> Dict:
|
||||
"""Process a single question and return the result."""
|
||||
response = await self.get_model_response(entry["question"])
|
||||
score = dataset.score_answer(answer=response, entry=entry)
|
||||
answer = self.parse_model_response(response)
|
||||
score = dataset.score_answer(answer=answer, entry=entry)
|
||||
|
||||
return {
|
||||
"question": entry["question"],
|
||||
"expected_answer": entry["answer"],
|
||||
"model_answer": response,
|
||||
"model_answer": answer,
|
||||
"score": score,
|
||||
"metadata": entry["metadata"],
|
||||
}
|
||||
|
|
|
|||
116
eval/r1/eval.py
116
eval/r1/eval.py
|
|
@ -1,4 +1,5 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
|
@ -6,10 +7,9 @@ from dataclasses import asdict
|
|||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import requests
|
||||
import aiohttp
|
||||
from eval_config import EvalConfig
|
||||
from requests.exceptions import RequestException
|
||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
||||
from tenacity import AsyncRetrying, retry_if_exception_type, stop_after_attempt, wait_exponential
|
||||
|
||||
import reasoning_gym
|
||||
from reasoning_gym.utils import extract_answer
|
||||
|
|
@ -30,9 +30,9 @@ class OpenRouterEvaluator:
|
|||
"X-Title": os.getenv("OR_APP_NAME", "Model Evaluation"),
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
self.semaphore = asyncio.Semaphore(10) # Control concurrency
|
||||
|
||||
def save_results(self, results: List[Dict[str, Any]], dataset, dataset_name) -> Dict[str, Any]:
|
||||
|
||||
file_name = f"{self.output_dir}/{dataset_name}.json"
|
||||
total_score = sum(r["score"] for r in results)
|
||||
|
||||
|
|
@ -45,7 +45,7 @@ class OpenRouterEvaluator:
|
|||
"total_examples": len(results),
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"config": asdict(dataset.config),
|
||||
"results": results, # save results to allow for performance recalculation
|
||||
"results": results,
|
||||
}
|
||||
|
||||
with open(file_name, "w") as f:
|
||||
|
|
@ -53,87 +53,93 @@ class OpenRouterEvaluator:
|
|||
return metrics
|
||||
|
||||
def prepare_messages(self, prompt: str) -> List[Dict[str, str]]:
|
||||
messages = [
|
||||
return {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{"role": self.config.developer_role, "content": self.config.developer_prompt},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
],
|
||||
"provider": {"order": ["Nebius"], "allow_fallbacks": False},
|
||||
}
|
||||
|
||||
async def get_model_response(self, session: aiohttp.ClientSession, prompt: str) -> str:
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"provider": {"order": ["Nebius"], "allow_fallbacks": False},
|
||||
} # make sure only one provider is used
|
||||
"messages": [
|
||||
{"role": self.config.developer_role, "content": self.config.developer_prompt},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
}
|
||||
|
||||
return payload
|
||||
async for attempt in AsyncRetrying(
|
||||
stop=stop_after_attempt(20),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=60),
|
||||
retry=retry_if_exception_type(
|
||||
(aiohttp.ClientError, asyncio.TimeoutError, json.JSONDecodeError, ValueError)
|
||||
),
|
||||
):
|
||||
with attempt:
|
||||
async with session.post(self.base_url, json=payload) as response:
|
||||
data = await response.json()
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type(RequestException),
|
||||
stop=stop_after_attempt(5),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
)
|
||||
def get_model_response(self, prompt: str) -> str:
|
||||
"""Get response from the model via OpenRouter API."""
|
||||
if not data:
|
||||
raise ValueError("Empty response")
|
||||
|
||||
payload = self.prepare_messages(prompt)
|
||||
try:
|
||||
response = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30)
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise RequestException(
|
||||
f"API request failed: {str(e)}", {"endpoint": self.base_url, "model": self.model}
|
||||
) from e
|
||||
return response.json()["choices"][0]["message"]["content"]
|
||||
if not data.get("choices"):
|
||||
raise ValueError("Missing choices in response")
|
||||
|
||||
def evaluate_datasets(self) -> List[Dict[str, Any]]:
|
||||
"""Evaluate model on multiple datasets with their respective configurations."""
|
||||
all_results = []
|
||||
return data["choices"][0]["message"]["content"]
|
||||
|
||||
for dataset_name in self.config.datasets:
|
||||
self.logger.info(f"\nEvaluating dataset: {dataset_name}")
|
||||
raise Exception("Failed to get valid response after retries")
|
||||
|
||||
# Create dataset with its specific configuration
|
||||
dataset = reasoning_gym.create_dataset(
|
||||
dataset_name, size=self.config.dataset_size, seed=self.config.dataset_seed
|
||||
)
|
||||
results = []
|
||||
|
||||
for i, entry in enumerate(dataset):
|
||||
print(f"On example {i+1} of {len(dataset)}")
|
||||
response = self.get_model_response(entry["question"])
|
||||
async def process_entry(self, session: aiohttp.ClientSession, dataset: Any, entry: Any) -> Dict[str, Any]:
|
||||
"""Process a single entry with concurrency control."""
|
||||
async with self.semaphore:
|
||||
response = await self.get_model_response(session, entry["question"])
|
||||
model_answer = extract_answer(response)
|
||||
|
||||
score = dataset.score_answer(answer=model_answer, entry=entry)
|
||||
print(f"Question: {entry['question']}")
|
||||
|
||||
result = {
|
||||
return {
|
||||
"question": entry["question"],
|
||||
"expected_answer": str(entry["answer"]),
|
||||
"model_answer": model_answer,
|
||||
"score": score,
|
||||
"metadata": str(entry["metadata"]),
|
||||
}
|
||||
results.append(result)
|
||||
|
||||
metrics = self.save_results(results, dataset, dataset_name)
|
||||
async def evaluate_dataset(self, session: aiohttp.ClientSession, dataset_name: str) -> Dict[str, Any]:
|
||||
"""Evaluate a single dataset asynchronously."""
|
||||
self.logger.info(f"\nEvaluating dataset: {dataset_name}")
|
||||
dataset = reasoning_gym.create_dataset(
|
||||
dataset_name, size=self.config.dataset_size, seed=self.config.dataset_seed
|
||||
)
|
||||
|
||||
all_results.append({"metrics": metrics, "results": results})
|
||||
tasks = [self.process_entry(session, dataset, entry) for entry in dataset]
|
||||
results = await asyncio.gather(*tasks)
|
||||
return self.save_results(results, dataset, dataset_name)
|
||||
|
||||
return all_results
|
||||
async def evaluate_datasets(self) -> List[Dict[str, Any]]:
|
||||
"""Main async evaluation entry point."""
|
||||
all_results = []
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
return await asyncio.gather(*(self.evaluate_dataset(session, name) for name in self.config.datasets))
|
||||
|
||||
|
||||
def main():
|
||||
async def async_main():
|
||||
parser = argparse.ArgumentParser(description="Evaluate models on reasoning datasets")
|
||||
parser.add_argument("--yaml", required=True, help="Path to YAML configuration file")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
config = EvalConfig.from_yaml(args.yaml)
|
||||
evaluator = OpenRouterEvaluator(model=config.model, config=config)
|
||||
results = await evaluator.evaluate_datasets()
|
||||
|
||||
output_dir = f"{config.eval_dir}/{config.category}"
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
evaluator = OpenRouterEvaluator(model=config.model, config=config)
|
||||
all_results = evaluator.evaluate_datasets()
|
||||
|
||||
with open(f"{output_dir}/summary.json", "w") as f:
|
||||
json.dump(all_results, f, indent=2)
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
asyncio.run(async_main())
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
model: deepseek/deepseek-r1
|
||||
category: algorithmic
|
||||
datasets:
|
||||
- base_conversion
|
||||
- binary_matrix
|
||||
- caesar_cipher
|
||||
- group_anagrams
|
||||
|
|
|
|||
61
eval/results/summary_openai_o1_20250212_103017.json
Normal file
61
eval/results/summary_openai_o1_20250212_103017.json
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
[
|
||||
{
|
||||
"dataset_name": "letter_counting",
|
||||
"model": "openai/o1",
|
||||
"average_score": 0.99,
|
||||
"total_examples": 50,
|
||||
"timestamp": "2025-02-12T10:26:39.897674",
|
||||
"config": {
|
||||
"min_words": 5,
|
||||
"max_words": 15,
|
||||
"size": 50,
|
||||
"seed": 42
|
||||
}
|
||||
},
|
||||
{
|
||||
"dataset_name": "propositional_logic",
|
||||
"model": "openai/o1",
|
||||
"average_score": 0.010000000000000004,
|
||||
"total_examples": 50,
|
||||
"timestamp": "2025-02-12T10:27:45.054740",
|
||||
"config": {
|
||||
"size": 50,
|
||||
"seed": 42
|
||||
}
|
||||
},
|
||||
{
|
||||
"dataset_name": "leg_counting",
|
||||
"model": "openai/o1",
|
||||
"average_score": 0.802,
|
||||
"total_examples": 50,
|
||||
"timestamp": "2025-02-12T10:28:06.199253",
|
||||
"config": {
|
||||
"min_animals": 3,
|
||||
"max_animals": 8,
|
||||
"size": 50,
|
||||
"seed": 42
|
||||
}
|
||||
},
|
||||
{
|
||||
"dataset_name": "group_anagrams",
|
||||
"model": "openai/o1",
|
||||
"average_score": 0.94,
|
||||
"total_examples": 50,
|
||||
"timestamp": "2025-02-12T10:30:02.084562",
|
||||
"config": {
|
||||
"size": 50,
|
||||
"seed": 42
|
||||
}
|
||||
},
|
||||
{
|
||||
"dataset_name": "spell_backward",
|
||||
"model": "openai/o1",
|
||||
"average_score": 0.9802000000000001,
|
||||
"total_examples": 50,
|
||||
"timestamp": "2025-02-12T10:30:17.839014",
|
||||
"config": {
|
||||
"size": 50,
|
||||
"seed": 42
|
||||
}
|
||||
}
|
||||
]
|
||||
|
|
@ -70,12 +70,12 @@ class PolynomialEquationsDataset(ProceduralDataset):
|
|||
|
||||
Returns:
|
||||
A dict with:
|
||||
- question: str (e.g. "Solve the polynomial equation: 2*x^2 - 3*x + 1 = 0")
|
||||
- answer: str (the sorted list of real solutions, e.g. "[0.5, 1.0]")
|
||||
- question: str (e.g. "Solve the polynomial equation: 2*x**2 - 3*x + 1 = 0")
|
||||
- answer: str (the sorted list of real solutions, e.g. "0.5, 1.0")
|
||||
- metadata: dict with details (polynomial_expr, degree, etc.)
|
||||
"""
|
||||
rng = random.Random(self.seed + idx)
|
||||
|
||||
for _ in range(8):
|
||||
# Get variable and generate polynomial equation in standard form
|
||||
variable = self._get_variable(rng)
|
||||
degree = rng.randint(self.config.min_degree, self.config.max_degree)
|
||||
|
|
@ -90,8 +90,12 @@ class PolynomialEquationsDataset(ProceduralDataset):
|
|||
if sol.is_real:
|
||||
# Evaluate symbolic solution to a floating approximation
|
||||
real_solutions.append(float(sol.evalf()))
|
||||
|
||||
if len(real_solutions) > 0:
|
||||
real_solutions.sort()
|
||||
answer_str = str(real_solutions)
|
||||
break
|
||||
|
||||
answer_str = ", ".join(str(x) for x in real_solutions)
|
||||
|
||||
return {
|
||||
"question": rng.choice(self._prompt_templates).format(
|
||||
|
|
@ -109,7 +113,7 @@ class PolynomialEquationsDataset(ProceduralDataset):
|
|||
|
||||
def _get_variable(self, rng: random.Random) -> str:
|
||||
"""Get a random lowercase variable name"""
|
||||
return rng.choice(string.ascii_lowercase)
|
||||
return rng.choice("abcdefghklmnopqrstuvwxyz") # remove ij to avoid confusion with complex numbers
|
||||
|
||||
def _generate_polynomial_expr(self, rng: random.Random, variable: Symbol, degree: int):
|
||||
"""
|
||||
|
|
@ -202,6 +206,9 @@ class PolynomialEquationsDataset(ProceduralDataset):
|
|||
oracle_solutions = self._parse_score_to_list(entry["answer"]) # Parse oracle solutions
|
||||
predicted_solutions = self._parse_score_to_list(answer) # Parse predicted solutions
|
||||
|
||||
if len(oracle_solutions) == 0 and len(predicted_solutions) == 0:
|
||||
return 1.0
|
||||
|
||||
total_reward = 0.0
|
||||
matched_solutions = 0
|
||||
extra_solutions = 0
|
||||
|
|
|
|||
|
|
@ -19,12 +19,14 @@ from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixDataset
|
|||
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
|
||||
from .number_sorting import NumberSortingConfig, NumberSortingDataset
|
||||
from .palindrome_generation import PalindromeConfig, PalindromeDataset
|
||||
from .pool_matrix import PoolMatrixConfig, PoolMatrixDataset
|
||||
from .ransom_note import RansomNoteConfig, RansomNoteDataset
|
||||
from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset
|
||||
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
|
||||
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
|
||||
from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset
|
||||
from .string_insertion import StringInsertionConfig, StringInsertionDataset
|
||||
from .string_manipulation import StringManipulationConfig, StringManipulationDataset
|
||||
from .word_ladder import WordLadderConfig, WordLadderDataset
|
||||
from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset
|
||||
from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset
|
||||
|
|
@ -69,10 +71,14 @@ __all__ = [
|
|||
"ManipulateMatrixDataset",
|
||||
"BinaryMatrixConfig",
|
||||
"BinaryMatrixDataset",
|
||||
"PoolMatrixConfig",
|
||||
"PoolMatrixDataset",
|
||||
"ABConfig",
|
||||
"ABDataset",
|
||||
"CountPrimesConfig",
|
||||
"CountPrimesDataset",
|
||||
"StringInsertionConfig",
|
||||
"StringInsertionDataset",
|
||||
"StringManipulationConfig",
|
||||
"StringManipulationDataset",
|
||||
]
|
||||
|
|
|
|||
142
reasoning_gym/algorithmic/pool_matrix.py
Normal file
142
reasoning_gym/algorithmic/pool_matrix.py
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
"""Perform average / max pooling on a matrix"""
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
QUESTION_TEMPLATE = """Your job is to perform max/average pooling on the given matrix.
|
||||
The stride is equal to the kernel size, meaning there is no overlap between the pooling regions.
|
||||
|
||||
Example 1:
|
||||
- Input: Perform max pooling on the following matrix with a kernel size of 2:
|
||||
1 2 3 4
|
||||
5 6 7 8
|
||||
9 10 11 12
|
||||
13 14 15 16
|
||||
- Output:
|
||||
6 8
|
||||
14 16
|
||||
|
||||
Example 2:
|
||||
- Input: Perform average pooling on the following matrix with a kernel size of 2:
|
||||
1 2 3 4
|
||||
5 6 7 8
|
||||
9 10 11 12
|
||||
13 14 15 16
|
||||
- Output:
|
||||
3.5 5.5
|
||||
11.5 13.5
|
||||
|
||||
Perform {pool_type} pooling on the following matrix with a kernel size of {pool_size}:
|
||||
{matrix}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class PoolMatrixConfig:
|
||||
"""Configuration for Pool Matrix dataset generation"""
|
||||
|
||||
min_rows: int = 2 # Minimum rows of the matrix
|
||||
min_cols: int = 2 # Minimum columns of the matrix
|
||||
max_rows: int = 10 # Maximum rows of the matrix
|
||||
max_cols: int = 10 # Maximum columns of the matrix
|
||||
max_pool_size: int = 3 # Maximum pooling size
|
||||
|
||||
size: int = 500 # Virtual dataset size
|
||||
seed: Optional[int] = None
|
||||
|
||||
def validate(self):
|
||||
"""Validate configuration parameters"""
|
||||
assert 2 <= self.min_rows, "min_rows must be at least 2"
|
||||
assert 2 <= self.min_cols, "min_cols must be at least 2"
|
||||
assert self.min_rows <= self.max_rows, "max_rows must be at least min_rows"
|
||||
assert self.min_cols <= self.max_cols, "max_cols must be at least min_cols"
|
||||
assert 1 <= self.max_pool_size, "max_pool_size must be at least 1"
|
||||
|
||||
|
||||
class PoolMatrixDataset(ProceduralDataset):
|
||||
"""Generates Pool Matrix exercises with configurable difficulty"""
|
||||
|
||||
def __init__(self, config: PoolMatrixConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _get_matrix(self, rng: Random) -> np.ndarray:
|
||||
"""Generate a random matrix"""
|
||||
rows = rng.randint(self.config.min_rows, self.config.max_rows)
|
||||
cols = rng.randint(self.config.min_rows, self.config.max_cols)
|
||||
return np.random.randint(0, 10, (rows, cols))
|
||||
|
||||
def _matrix_to_str(self, matrix: np.ndarray) -> str:
|
||||
"""Get a string representation of the matrix"""
|
||||
return "\n".join(" ".join(str(round(x, 2)) for x in row) for row in matrix)
|
||||
|
||||
def _max_pool(self, matrix: np.ndarray, pool_size: int) -> np.ndarray:
|
||||
"""Perform max pooling on the matrix"""
|
||||
rows, cols = matrix.shape
|
||||
return np.array(
|
||||
[
|
||||
[np.max(matrix[i : i + pool_size, j : j + pool_size]) for j in range(0, cols, pool_size)]
|
||||
for i in range(0, rows, pool_size)
|
||||
]
|
||||
)
|
||||
|
||||
def _average_pool(self, matrix: np.ndarray, pool_size: int) -> np.ndarray:
|
||||
"""Perform average pooling on the matrix"""
|
||||
rows, cols = matrix.shape
|
||||
return np.array(
|
||||
[
|
||||
[np.mean(matrix[i : i + pool_size, j : j + pool_size]) for j in range(0, cols, pool_size)]
|
||||
for i in range(0, rows, pool_size)
|
||||
]
|
||||
)
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
"""Score the answer based on the metadata"""
|
||||
|
||||
reward = 0.0
|
||||
try:
|
||||
if answer is not None:
|
||||
oracle_answer = np.array(entry["answer"])
|
||||
answer = np.array(answer)
|
||||
if oracle_answer.shape == answer.shape and np.allclose(oracle_answer, answer):
|
||||
reward = 1.0
|
||||
if oracle_answer.shape == answer.shape:
|
||||
reward = 0.1
|
||||
else:
|
||||
reward = 0.01
|
||||
except:
|
||||
print("Error in scoring answer for Pool Matrix")
|
||||
return reward
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single Pool Matrix question"""
|
||||
rng = Random(self.seed + idx)
|
||||
np.random.seed(self.seed + idx)
|
||||
|
||||
matrix = self._get_matrix(rng)
|
||||
matrix_str = self._matrix_to_str(matrix)
|
||||
|
||||
pool_size = rng.randint(1, self.config.max_pool_size)
|
||||
pool_type = rng.choice(["average", "max"])
|
||||
|
||||
answer = self._average_pool(matrix, pool_size) if pool_type == "average" else self._max_pool(matrix, pool_size)
|
||||
answer_str = self._matrix_to_str(answer)
|
||||
|
||||
return {
|
||||
"question": QUESTION_TEMPLATE.format(matrix=matrix_str, pool_type=pool_type, pool_size=pool_size),
|
||||
"answer": answer_str,
|
||||
"metadata": {
|
||||
"matrix": matrix.tolist(),
|
||||
"pool_type": pool_type,
|
||||
"pool_size": pool_size,
|
||||
"solution": answer.tolist(),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
register_dataset("pool_matrix", PoolMatrixDataset, PoolMatrixConfig)
|
||||
199
reasoning_gym/algorithmic/string_manipulation.py
Normal file
199
reasoning_gym/algorithmic/string_manipulation.py
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
"""Manipulate a string according to a set of rules
|
||||
|
||||
https://github.com/yongchao98/CodeSteer-v1.0/blob/main/create_dataset/create_dataset_string_deletion_and_modification.py
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
QUESTION_TEMPLATE = """Your job is to repeatedly transform a string according to a set of rules until no further transformations can be performed, or a state is repeated.
|
||||
|
||||
Evaluate the following rules in order, and apply the first applicable rule to the string:
|
||||
{rules}
|
||||
|
||||
Once you have applied a rule, repeat the process with the new string until no further transformations can be performed (i.e. the string doesn't change), or a state is repeated.
|
||||
If a state is repeated, the process is terminated, and the repeated state is discarded (i.e. is not considered as the final answer) and the state before the repeated state is considered as the final answer.
|
||||
|
||||
Example:
|
||||
- Input:
|
||||
- String: abbac
|
||||
- Rules:
|
||||
1. If the string prefix is 'ab', replace it with 'ca'.
|
||||
2. If the string prefix is 'ca', replace it with 'bb' and append 'c' to the end.
|
||||
3. If the string ends with 'aa', replace it with 'cc'.
|
||||
- Output: bbbacc
|
||||
- Explanation:
|
||||
- In the first iteration, rule 1 is applied to the string abbac, resulting in cabac
|
||||
- In the second interation, rule 1 doesn't apply, but rule 2 is applied to the string cabac, resulting in bbbacc
|
||||
- In the third iteration, none of the rules (1, 2, 3) apply, so the process is terminated, and the final answer is bbbacc
|
||||
|
||||
Transform the following string according to the above list of rules:
|
||||
{string}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class StringManipulationConfig:
|
||||
"""Configuration for String Insertion dataset generation"""
|
||||
|
||||
min_string_length: int = 5 # Minimum string length
|
||||
max_string_length: int = 20 # Maximum string length
|
||||
min_num_rules: int = 3 # Minimum number of rules/transforms
|
||||
max_num_rules: int = 8 # Maximum number of rules/transforms
|
||||
|
||||
size: int = 500 # Virtual dataset size
|
||||
seed: Optional[int] = None
|
||||
|
||||
def validate(self):
|
||||
"""Validate configuration parameters"""
|
||||
assert 5 <= self.min_string_length, "Minimum string length should be at least 5"
|
||||
assert self.min_string_length <= self.max_string_length, "Minimum string length should be less than maximum"
|
||||
assert 3 <= self.min_num_rules, "Minimum number of rules should be at least 3"
|
||||
assert self.min_num_rules <= self.max_num_rules, "Minimum number of rules should be less than maximum"
|
||||
|
||||
|
||||
class StringManipulationDataset(ProceduralDataset):
|
||||
"""Generates String Insertion exercises with configurable difficulty"""
|
||||
|
||||
def __init__(self, config: StringManipulationConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
self.vocabulary = ["a", "b", "c"]
|
||||
self.rules = [
|
||||
(
|
||||
"If the string prefix is 'ab', replace it with 'ca'.",
|
||||
lambda s: ("ca" + s[2:], 1) if s.startswith("ab") else (s, 0),
|
||||
),
|
||||
(
|
||||
"If the string suffix is 'ac', replace it with 'cb'.",
|
||||
lambda s: (s[:-2] + "cb", 2) if s.endswith("ac") else (s, 0),
|
||||
),
|
||||
(
|
||||
"If the string prefix is 'bc', delete the first two characters and append 'aa' to the end.",
|
||||
lambda s: (s[2:] + "aa", 3) if s.startswith("bc") else (s, 0),
|
||||
),
|
||||
(
|
||||
"If the string suffix is 'bb', delete the last two characters.",
|
||||
lambda s: (s[:-2], 4) if s.endswith("bb") else (s, 0),
|
||||
),
|
||||
(
|
||||
"If the string prefix is 'cb', replace it with 'aa' and delete the last character.",
|
||||
lambda s: ("aa" + s[2:-1], 5) if s.startswith("cb") and len(s) > 1 else (s, 0),
|
||||
),
|
||||
(
|
||||
"If the string prefix is 'ca', replace it with 'bb' and append 'c' to the end.",
|
||||
lambda s: ("bb" + s[2:] + "c", 6) if s.startswith("ca") else (s, 0),
|
||||
),
|
||||
(
|
||||
"If the string suffix is 'cc', replace it with 'b' and prepend 'a' to the start.",
|
||||
lambda s: ("a" + s[:-2] + "b", 7) if s.endswith("cc") else (s, 0),
|
||||
),
|
||||
(
|
||||
"If the string prefix is 'aa', remove the first character.",
|
||||
lambda s: (s[1:], 8) if s.startswith("aa") else (s, 0),
|
||||
),
|
||||
(
|
||||
"If the string contains 'abc', replace the first occurrence with 'cab'.",
|
||||
lambda s: (s.replace("abc", "cab", 1), 9) if "abc" in s else (s, 0),
|
||||
),
|
||||
(
|
||||
"If the string contains 'bca', delete the first occurrence entirely.",
|
||||
lambda s: (s.replace("bca", "", 1), 10) if "bca" in s else (s, 0),
|
||||
),
|
||||
(
|
||||
"If the string ends with 'ba', replace it with 'ab'.",
|
||||
lambda s: (s[:-2] + "ab", 11) if s.endswith("ba") else (s, 0),
|
||||
),
|
||||
(
|
||||
"If the string starts with 'cc', remove the first two characters.",
|
||||
lambda s: (s[2:], 12) if s.startswith("cc") else (s, 0),
|
||||
),
|
||||
(
|
||||
"If the string contains 'acb', replace the first occurrence with its reverse ('bca').",
|
||||
lambda s: (s.replace("acb", "bca", 1), 13) if "acb" in s else (s, 0),
|
||||
),
|
||||
(
|
||||
"If the string ends with 'ca', remove the last character.",
|
||||
lambda s: (s[:-1], 14) if s.endswith("ca") and len(s) > 0 else (s, 0),
|
||||
),
|
||||
(
|
||||
"If the string starts with 'bb', remove the second character.",
|
||||
lambda s: (s[0] + s[2:], 15) if s.startswith("bb") and len(s) >= 2 else (s, 0),
|
||||
),
|
||||
(
|
||||
"If the string ends with 'aa', replace it with 'cc'.",
|
||||
lambda s: (s[:-2] + "cc", 16) if s.endswith("aa") else (s, 0),
|
||||
),
|
||||
(
|
||||
"If the string contains 'ca' (not at the start), remove the first occurrence found after the first character.",
|
||||
lambda s: (s[:idx] + s[idx + 2 :], 17) if (idx := s.find("ca", 1)) != -1 else (s, 0),
|
||||
),
|
||||
(
|
||||
"If the string contains an even number of 'b's (and at least one 'b'), append 'ab' at the end.",
|
||||
lambda s: (s + "ab", 18) if (s.count("b") > 0 and s.count("b") % 2 == 0) else (s, 0),
|
||||
),
|
||||
(
|
||||
"If the string length is greater than 15, remove the middle character.",
|
||||
lambda s: (s[: len(s) // 2] + s[len(s) // 2 + 1 :], 19) if len(s) > 15 else (s, 0),
|
||||
),
|
||||
(
|
||||
"If the string starts with 'ac', replace the first two characters with 'zz'.",
|
||||
lambda s: ("zz" + s[2:], 20) if s.startswith("ac") else (s, 0),
|
||||
),
|
||||
]
|
||||
|
||||
def _apply_rule(self, string: str, selected_rules: list[tuple[str, callable]]) -> tuple[str, int]:
|
||||
"""
|
||||
Apply the first applicable rule from the list of selected rules.
|
||||
Returns a tuple containing the modified string and the rule index (1-based) that was applied.
|
||||
If no rule is applicable, returns (s, 0).
|
||||
"""
|
||||
for _, rule_fn in selected_rules:
|
||||
new_string, op_idx = rule_fn(string)
|
||||
if op_idx != 0:
|
||||
return new_string, op_idx
|
||||
return string, 0
|
||||
|
||||
def _get_all_transforms(self, string: str, selected_rules: list[tuple[str, callable]]) -> list[str]:
|
||||
"""
|
||||
Repeatedly apply transformation rules to a string until no further transformations can be performed,
|
||||
or a state is repeated. If a state is repeated, the process is terminated, and the state is not added to the list.
|
||||
Returns a list of string states from the initial string to the final state (i.e. the desired answer).
|
||||
"""
|
||||
states = [string]
|
||||
while True:
|
||||
new_string, op_idx = self._apply_rule(states[-1], selected_rules)
|
||||
if op_idx == 0 or new_string in states:
|
||||
break
|
||||
states.append(new_string)
|
||||
return states
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single String Insertion question"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
string_length = rng.randint(self.config.min_string_length, self.config.max_string_length)
|
||||
string = "".join(rng.choice(self.vocabulary) for _ in range(string_length))
|
||||
|
||||
num_rules = rng.randint(self.config.min_num_rules, self.config.max_num_rules)
|
||||
selected_rules = rng.sample(self.rules, num_rules)
|
||||
rules_str = "\n".join(f"{i+1}. {rule}" for i, (rule, _) in enumerate(selected_rules))
|
||||
|
||||
states = self._get_all_transforms(string, selected_rules)
|
||||
answer = states[-1]
|
||||
|
||||
return {
|
||||
"question": QUESTION_TEMPLATE.format(string=string, rules=rules_str),
|
||||
"answer": str(answer),
|
||||
"metadata": {
|
||||
"string": string,
|
||||
"solution": answer,
|
||||
"states": states,
|
||||
"selected_rules": [rule for rule, _ in selected_rules],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
register_dataset("string_manipulation", StringManipulationDataset, StringManipulationConfig)
|
||||
|
|
@ -119,7 +119,6 @@ class ChainSumCurriculum(BaseCurriculum):
|
|||
|
||||
# Define attributes
|
||||
self._define_attributes(
|
||||
(
|
||||
RangeAttributeDefinition(
|
||||
name="num_terms",
|
||||
levels=[2, 3, 4, 5],
|
||||
|
|
@ -141,7 +140,6 @@ class ChainSumCurriculum(BaseCurriculum):
|
|||
upper_field_name="max_digits",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Register the dataset
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class BaseCurriculum:
|
|||
raise KeyError(f"Attribute '{self.name}.{attr_name}' does not exist")
|
||||
return self._attributes[attr_name]
|
||||
|
||||
def _define_attributes(self, attrs: Iterable[AttributeDefinition]) -> None:
|
||||
def _define_attributes(self, *attrs: tuple[AttributeDefinition, ...]) -> None:
|
||||
for attr in attrs:
|
||||
if attr.name in self.attributes:
|
||||
raise RuntimeError(f"Attribute with name {attr.name} is already defined.")
|
||||
|
|
|
|||
|
|
@ -4,12 +4,15 @@ from decimal import Decimal, InvalidOperation
|
|||
from fractions import Fraction
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
# DeepSeek Zero system prompt
|
||||
SYSTEM_PROMPTS = {
|
||||
"DeepSeekZero": """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
|
||||
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think>
|
||||
<answer> answer here </answer>
|
||||
"""
|
||||
""",
|
||||
"default": """Given a problem, your task is to answer the question by thinking step-by-step in a clear and specific manner.
|
||||
Once you have thought about the reasoning process, provide the answer in the following format:
|
||||
<answer> answer here </answer>
|
||||
""",
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -138,3 +138,12 @@ def test_polynomial_solutions_score_answer(oracle_answer, predicted_answer, expe
|
|||
|
||||
actual_reward = ds.score_answer(predicted_answer, {"answer": oracle_answer})
|
||||
assert actual_reward == pytest.approx(expected_reward, rel=1e-3) # Fuzzy comparison for floats
|
||||
|
||||
|
||||
def test_polynomial_perfect_score():
|
||||
"""Test that scoring an item's own answer gives a perfect score"""
|
||||
cfg = PolynomialEquationsConfig(seed=42, size=10)
|
||||
ds = PolynomialEquationsDataset(cfg)
|
||||
|
||||
for item in ds:
|
||||
assert ds.score_answer(item["answer"], item) == 1.0
|
||||
|
|
|
|||
138
tests/test_pool_matrix.py
Normal file
138
tests/test_pool_matrix.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
"""Tests for Pool Matrix questions generation"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.algorithmic.pool_matrix import PoolMatrixConfig, PoolMatrixDataset
|
||||
|
||||
|
||||
def test_pool_matrix_config_validation():
|
||||
"""Test that invalid configs raise appropriate errors"""
|
||||
|
||||
for field in ["min_rows", "min_cols", "max_rows", "max_cols"]:
|
||||
with pytest.raises(AssertionError):
|
||||
config = PoolMatrixConfig(**{field: -1}) # Negative not allowed
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = PoolMatrixConfig(**{field: 0}) # Zero not allowed
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = PoolMatrixConfig(**{field: 1}) # One not allowed
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = PoolMatrixConfig(max_pool_size=-1) # Negative not allowed
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = PoolMatrixConfig(max_pool_size=0) # Zero not allowed
|
||||
config.validate()
|
||||
|
||||
|
||||
def test_pool_matrix_dataset_deterministic():
|
||||
"""Test that dataset generates same items with same seed"""
|
||||
config = PoolMatrixConfig(seed=42, size=10)
|
||||
dataset1 = PoolMatrixDataset(config)
|
||||
dataset2 = PoolMatrixDataset(config)
|
||||
|
||||
for i in range(len(dataset1)):
|
||||
assert dataset1[i] == dataset2[i]
|
||||
|
||||
|
||||
def test_pool_matrix_dataset_items():
|
||||
"""Test basic properties of generated items"""
|
||||
config = PoolMatrixConfig(max_rows=10, max_cols=10, max_pool_size=3, size=10, seed=42)
|
||||
dataset = PoolMatrixDataset(config)
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
# Check item structure
|
||||
assert isinstance(item, dict)
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
|
||||
# Check metadata
|
||||
assert "matrix" in item["metadata"]
|
||||
assert "pool_type" in item["metadata"]
|
||||
assert "pool_size" in item["metadata"]
|
||||
assert "solution" in item["metadata"]
|
||||
|
||||
matrix = item["metadata"]["matrix"]
|
||||
pool_type = item["metadata"]["pool_type"]
|
||||
pool_size = item["metadata"]["pool_size"]
|
||||
solution = item["metadata"]["solution"]
|
||||
|
||||
# Verify dimensions
|
||||
assert len(matrix) <= config.max_rows
|
||||
assert all(len(row) <= config.max_cols for row in matrix)
|
||||
assert len(solution) <= len(matrix)
|
||||
assert len(solution[0]) <= len(matrix[0])
|
||||
assert pool_size <= config.max_pool_size
|
||||
assert pool_type in ["average", "max"]
|
||||
|
||||
|
||||
def test_pool_matrix_dataset_iteration():
|
||||
"""Test that iteration respects dataset size"""
|
||||
config = PoolMatrixConfig(size=5, seed=42)
|
||||
dataset = PoolMatrixDataset(config)
|
||||
|
||||
items = list(dataset)
|
||||
assert len(items) == config.size
|
||||
|
||||
# Test multiple iterations yield same items
|
||||
assert items == list(dataset)
|
||||
|
||||
|
||||
def test_pool_matrix_answer():
|
||||
"""Test the pooling methods"""
|
||||
config = PoolMatrixConfig(seed=42)
|
||||
dataset = PoolMatrixDataset(config)
|
||||
|
||||
# 1. Max pooling
|
||||
matrix = np.array([[1, 2], [3, 4]])
|
||||
assert np.allclose(dataset._max_pool(matrix, 2), np.array([[4]]))
|
||||
|
||||
matrix = np.array(
|
||||
[
|
||||
[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[9, 10, 11, 12],
|
||||
]
|
||||
)
|
||||
assert np.allclose(dataset._max_pool(matrix, 2), np.array([[6, 8], [10, 12]]))
|
||||
|
||||
matrix = np.array(
|
||||
[
|
||||
[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[9, 10, 11, 12],
|
||||
[13, 14, 15, 16],
|
||||
]
|
||||
)
|
||||
assert np.allclose(dataset._max_pool(matrix, 2), np.array([[6, 8], [14, 16]]))
|
||||
|
||||
# 2. Average pooling
|
||||
matrix = np.array([[1, 2], [3, 4]])
|
||||
assert np.allclose(dataset._average_pool(matrix, 2), np.array([[2.5]]))
|
||||
|
||||
matrix = np.array(
|
||||
[
|
||||
[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[9, 10, 11, 12],
|
||||
]
|
||||
)
|
||||
assert np.allclose(dataset._average_pool(matrix, 2), np.array([[3.5, 5.5], [9.5, 11.5]]))
|
||||
|
||||
matrix = np.array(
|
||||
[
|
||||
[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[9, 10, 11, 12],
|
||||
[13, 14, 15, 16],
|
||||
]
|
||||
)
|
||||
assert np.allclose(dataset._average_pool(matrix, 2), np.array([[3.5, 5.5], [11.5, 13.5]]))
|
||||
257
tests/test_string_manipulation.py
Normal file
257
tests/test_string_manipulation.py
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
"""Tests for String Manipulation questions generation"""
|
||||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.algorithmic.string_manipulation import StringManipulationConfig, StringManipulationDataset
|
||||
|
||||
|
||||
def test_string_manipulation_config_validation():
|
||||
"""Test that invalid configs raise appropriate errors"""
|
||||
with pytest.raises(AssertionError):
|
||||
config = StringManipulationConfig(min_string_length=4) # Minimum string length should be at least 5
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = StringManipulationConfig(min_string_length=10, max_string_length=7) # Max must be greater than min
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = StringManipulationConfig(min_num_rules=2) # Min number of rules should be at least 3
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = StringManipulationConfig(min_num_rules=5, max_num_rules=3) # Max must be greater than min
|
||||
config.validate()
|
||||
|
||||
|
||||
def test_string_manipulation_dataset_deterministic():
|
||||
"""Test that dataset generates same items with same seed"""
|
||||
config = StringManipulationConfig(seed=42, size=10)
|
||||
dataset1 = StringManipulationDataset(config)
|
||||
dataset2 = StringManipulationDataset(config)
|
||||
|
||||
for i in range(len(dataset1)):
|
||||
assert dataset1[i] == dataset2[i]
|
||||
|
||||
|
||||
def test_string_manipulation_dataset_items():
|
||||
"""Test basic properties of generated items"""
|
||||
config = StringManipulationConfig(
|
||||
min_string_length=7, max_string_length=25, min_num_rules=5, max_num_rules=12, size=10, seed=42
|
||||
)
|
||||
dataset = StringManipulationDataset(config)
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
# Check item structure
|
||||
assert isinstance(item, dict)
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
|
||||
# Check metadata
|
||||
assert "string" in item["metadata"]
|
||||
assert "states" in item["metadata"]
|
||||
# assert "selected_rules" in item["metadata"]
|
||||
assert "solution" in item["metadata"]
|
||||
|
||||
string = item["metadata"]["string"]
|
||||
solution = item["metadata"]["solution"]
|
||||
states = item["metadata"]["states"]
|
||||
selected_rules = item["metadata"]["selected_rules"]
|
||||
|
||||
# Verify dimensions
|
||||
assert config.min_string_length <= len(string) <= config.max_string_length
|
||||
assert config.min_num_rules <= len(selected_rules) <= config.max_num_rules
|
||||
assert len(states) >= 1
|
||||
assert solution == states[-1]
|
||||
|
||||
|
||||
def test_string_manipulation_dataset_iteration():
|
||||
"""Test that iteration respects dataset size"""
|
||||
config = StringManipulationConfig(size=5, seed=42)
|
||||
dataset = StringManipulationDataset(config)
|
||||
|
||||
items = list(dataset)
|
||||
assert len(items) == config.size
|
||||
|
||||
# Test multiple iterations yield same items
|
||||
assert items == list(dataset)
|
||||
|
||||
|
||||
def test_string_manipulation_answer():
|
||||
"""Test the method for getting the answer"""
|
||||
config = StringManipulationConfig(seed=42)
|
||||
dataset = StringManipulationDataset(config)
|
||||
|
||||
rules = [
|
||||
(
|
||||
"If the string prefix is 'ab', replace it with 'ca'.",
|
||||
lambda s: ("ca" + s[2:], 1) if s.startswith("ab") else (s, 0),
|
||||
)
|
||||
]
|
||||
assert dataset._get_all_transforms("abbbab", rules)[-1] == "cabbab"
|
||||
|
||||
rules = [
|
||||
(
|
||||
"If the string suffix is 'ac', replace it with 'cb'.",
|
||||
lambda s: (s[:-2] + "cb", 2) if s.endswith("ac") else (s, 0),
|
||||
),
|
||||
]
|
||||
assert dataset._get_all_transforms("abbbac", rules)[-1] == "abbbcb"
|
||||
|
||||
rules = [
|
||||
(
|
||||
"If the string prefix is 'bc', delete the first two characters and append 'aa' to the end.",
|
||||
lambda s: (s[2:] + "aa", 3) if s.startswith("bc") else (s, 0),
|
||||
),
|
||||
]
|
||||
assert dataset._get_all_transforms("bcabbb", rules)[-1] == "abbbaa"
|
||||
|
||||
rules = [
|
||||
(
|
||||
"If the string suffix is 'bb', delete the last two characters.",
|
||||
lambda s: (s[:-2], 4) if s.endswith("bb") else (s, 0),
|
||||
),
|
||||
]
|
||||
assert dataset._get_all_transforms("abbbabb", rules)[-1] == "abbba"
|
||||
|
||||
rules = [
|
||||
(
|
||||
"If the string prefix is 'cb', replace it with 'aa' and delete the last character.",
|
||||
lambda s: ("aa" + s[2:-1], 5) if s.startswith("cb") and len(s) > 1 else (s, 0),
|
||||
)
|
||||
]
|
||||
assert dataset._get_all_transforms("cbabbb", rules)[-1] == "aaabb"
|
||||
|
||||
rules = [
|
||||
(
|
||||
"If the string prefix is 'ca', replace it with 'bb' and append 'c' to the end.",
|
||||
lambda s: ("bb" + s[2:] + "c", 6) if s.startswith("ca") else (s, 0),
|
||||
)
|
||||
]
|
||||
assert dataset._get_all_transforms("caabbb", rules)[-1] == "bbabbbc"
|
||||
|
||||
rules = [
|
||||
(
|
||||
"If the string suffix is 'cc', replace it with 'b' and prepend 'a' to the start.",
|
||||
lambda s: ("a" + s[:-2] + "b", 7) if s.endswith("cc") else (s, 0),
|
||||
)
|
||||
]
|
||||
assert dataset._get_all_transforms("abbbcc", rules)[-1] == "aabbbb"
|
||||
|
||||
rules = [
|
||||
(
|
||||
"If the string prefix is 'aa', remove the first character.",
|
||||
lambda s: (s[1:], 8) if s.startswith("aa") else (s, 0),
|
||||
)
|
||||
]
|
||||
assert dataset._get_all_transforms("aabbb", rules)[-1] == "abbb"
|
||||
|
||||
rules = [
|
||||
(
|
||||
"If the string contains 'abc', replace the first occurrence with 'cab'.",
|
||||
lambda s: (s.replace("abc", "cab", 1), 9) if "abc" in s else (s, 0),
|
||||
)
|
||||
]
|
||||
assert dataset._get_all_transforms("ababcb", rules)[-1] == "cababb" # 'ababcb' -> 'abcabb' -> 'cababb'
|
||||
|
||||
rules = [
|
||||
(
|
||||
"If the string contains 'bca', delete the first occurrence entirely.",
|
||||
lambda s: (s.replace("bca", "", 1), 10) if "bca" in s else (s, 0),
|
||||
)
|
||||
]
|
||||
assert dataset._get_all_transforms("abbcab", rules)[-1] == "abb"
|
||||
|
||||
rules = [
|
||||
(
|
||||
"If the string ends with 'ba', replace it with 'ab'.",
|
||||
lambda s: (s[:-2] + "ab", 11) if s.endswith("ba") else (s, 0),
|
||||
)
|
||||
]
|
||||
assert dataset._get_all_transforms("abbbba", rules)[-1] == "abbbab"
|
||||
|
||||
rules = [
|
||||
(
|
||||
"If the string starts with 'cc', remove the first two characters.",
|
||||
lambda s: (s[2:], 12) if s.startswith("cc") else (s, 0),
|
||||
)
|
||||
]
|
||||
assert dataset._get_all_transforms("ccabbb", rules)[-1] == "abbb"
|
||||
|
||||
rules = [
|
||||
(
|
||||
"If the string contains 'acb', replace the first occurrence with its reverse ('bca').",
|
||||
lambda s: (s.replace("acb", "bca", 1), 13) if "acb" in s else (s, 0),
|
||||
)
|
||||
]
|
||||
assert dataset._get_all_transforms("abacbb", rules)[-1] == "abbcab"
|
||||
|
||||
rules = [
|
||||
(
|
||||
"If the string contains 'acb', replace the first occurrence with its reverse ('bca').",
|
||||
lambda s: (s.replace("acb", "bca", 1), 13) if "acb" in s else (s, 0),
|
||||
)
|
||||
]
|
||||
assert dataset._get_all_transforms("abacbb", rules)[-1] == "abbcab"
|
||||
|
||||
rules = [
|
||||
(
|
||||
"If the string ends with 'ca', remove the last character.",
|
||||
lambda s: (s[:-1], 14) if s.endswith("ca") and len(s) > 0 else (s, 0),
|
||||
)
|
||||
]
|
||||
assert dataset._get_all_transforms("abbbca", rules)[-1] == "abbbc"
|
||||
|
||||
rules = [
|
||||
(
|
||||
"If the string starts with 'bb', remove the second character.",
|
||||
lambda s: (s[0] + s[2:], 15) if s.startswith("bb") and len(s) >= 2 else (s, 0),
|
||||
)
|
||||
]
|
||||
assert dataset._get_all_transforms("bbabcbb", rules)[-1] == "babcbb"
|
||||
|
||||
rules = [
|
||||
(
|
||||
"If the string ends with 'aa', replace it with 'cc'.",
|
||||
lambda s: (s[:-2] + "cc", 16) if s.endswith("aa") else (s, 0),
|
||||
)
|
||||
]
|
||||
assert dataset._get_all_transforms("abccbaa", rules)[-1] == "abccbcc"
|
||||
|
||||
rules = [
|
||||
(
|
||||
"If the string contains 'ca' (not at the start), remove the first occurrence found after the first character.",
|
||||
lambda s: (s[:idx] + s[idx + 2 :], 17) if (idx := s.find("ca", 1)) != -1 else (s, 0),
|
||||
)
|
||||
]
|
||||
assert dataset._get_all_transforms("abacab", rules)[-1] == "abab"
|
||||
assert dataset._get_all_transforms("caabab", rules)[-1] == "caabab"
|
||||
|
||||
rules = [
|
||||
(
|
||||
"If the string contains an even number of 'b's (and at least one 'b'), append 'ab' at the end.",
|
||||
lambda s: (s + "ab", 18) if (s.count("b") > 0 and s.count("b") % 2 == 0) else (s, 0),
|
||||
)
|
||||
]
|
||||
assert dataset._get_all_transforms("abab", rules)[-1] == "ababab"
|
||||
assert dataset._get_all_transforms("abbab", rules)[-1] == "abbab"
|
||||
|
||||
rules = [
|
||||
(
|
||||
"If the string length is greater than 15, remove the middle character.",
|
||||
lambda s: (s[: len(s) // 2] + s[len(s) // 2 + 1 :], 19) if len(s) > 15 else (s, 0),
|
||||
)
|
||||
]
|
||||
assert (
|
||||
dataset._get_all_transforms("bccbcbbbcbbbbcccc", rules)[-1] == "bccbcbbbbbbcccc"
|
||||
) # bccbcbbbcbbbbcccc -> "bccbcbbbbbbbcccc" -> "bccbcbbbbbbcccc"
|
||||
|
||||
rules = [
|
||||
(
|
||||
"If the string starts with 'ac', replace the first two characters with 'zz'.",
|
||||
lambda s: ("zz" + s[2:], 20) if s.startswith("ac") else (s, 0),
|
||||
)
|
||||
]
|
||||
assert dataset._get_all_transforms("acab", rules)[-1] == "zzab"
|
||||
Loading…
Add table
Add a link
Reference in a new issue