reasoning-gym/reasoning_gym/languages/regex_puzzles.py
Ritvik19 6eb252ae32 Add 13 new procedural datasets across 7 categories
New dataset categories: combinatorics, statistics, optimization, and
formal languages. Extended existing algebra, arithmetic, probability,
logic, and graphs packages with complex_advanced, linear_algebra, limits,
number_theory, conditional_probability, set_operations, and job_scheduling.

Each dataset includes config validation, deterministic seeding, custom
scoring, curriculum support, and comprehensive unit tests (92 new tests).
2026-04-18 16:42:54 +05:30

257 lines
10 KiB
Python

import random
import re
from dataclasses import dataclass, field
from typing import Any, Optional
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
# Name used when registering this dataset and stamped into every item's
# metadata as "source_dataset".
DATASET_NAME = "regex_puzzles"
# All task kinds this module knows how to generate; RegexPuzzlesConfig.validate
# rejects any configured task type that is not listed here.
TASK_TYPES = ("string_generation", "extraction", "dfa_state", "dfa_prefix")
# (regex, plain-English description) pairs offered by the string-generation task.
# Every pattern sticks to constructs that _gen_matching_string can expand:
# bracket classes with {n}, {m,n} or + repetition, escaped literals and plain
# literal characters.
REGEX_PATTERNS = [
(r"[a-c]{2}[0-9]{3}", "two lowercase letters (a-c) followed by three digits"),
(r"[A-Z]{3}[0-9]{2}", "three uppercase letters followed by two digits"),
(r"[0-9]{2}-[0-9]{2}-[0-9]{4}", "a date in DD-MM-YYYY format (digits only)"),
(r"[a-z]+@[a-z]+\.[a-z]{2,3}", "a simple email like name@domain.com"),
(r"[01]{4}", "a 4-digit binary string"),
(r"[A-Z][a-z]{2,5}", "a capitalized word (3-6 letters)"),
(r"[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}", "an IP-address-like pattern"),
(r"#[0-9a-f]{6}", "a hex color code like #a1b2c3"),
]
@dataclass
class RegexPuzzlesConfig:
    """Configuration for the regex / DFA puzzle dataset."""

    # Inclusive bounds on the number of states in each randomly generated DFA.
    min_dfa_states: int = 3
    max_dfa_states: int = 5
    # Task kinds to sample from, and the sampling weight of each (same order).
    task_types: tuple[str, ...] = TASK_TYPES
    task_weights: list[float] = field(default_factory=lambda: [0.3, 0.25, 0.25, 0.2])
    # Base seed and dataset length; both are forwarded to the dataset base class.
    seed: Optional[int] = None
    size: int = 500

    def validate(self) -> None:
        """Check that the configuration is internally consistent.

        Raises:
            AssertionError: if any field is out of range, a task type is
                unknown, or the weights cannot be used for weighted sampling.
        """
        assert self.size > 0, "size must be positive"
        assert self.min_dfa_states >= 2, "min_dfa_states must be >= 2"
        assert self.max_dfa_states >= self.min_dfa_states, "max_dfa_states must be >= min_dfa_states"
        assert len(self.task_types) > 0, "must have at least one task type"
        unknown = [t for t in self.task_types if t not in TASK_TYPES]
        assert not unknown, f"invalid task type(s): {unknown}"
        assert len(self.task_weights) == len(self.task_types), "weights must match types"
        # random.choices requires non-negative weights with a positive total;
        # catch a bad weighting here instead of failing later at item generation.
        assert all(w >= 0 for w in self.task_weights), "task_weights must be non-negative"
        assert sum(self.task_weights) > 0, "task_weights must not all be zero"
# How far past its minimum an open-ended quantifier may repeat; keeps "+"
# at its historical 1..5 range and gives "*" and "{n,}" the same spread.
_MAX_UNBOUNDED_EXTRA = 4
_WORD_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"


def _expand_char_class(char_class: str) -> list[str]:
    """Expand the inside of a ``[...]`` class (literals and ``x-y`` ranges) into its characters."""
    chars: list[str] = []
    j = 0
    while j < len(char_class):
        if j + 2 < len(char_class) and char_class[j + 1] == "-":
            chars.extend(chr(c) for c in range(ord(char_class[j]), ord(char_class[j + 2]) + 1))
            j += 3
        else:
            chars.append(char_class[j])
            j += 1
    return chars


def _read_quantifier(pattern: str, pos: int, rng: random.Random) -> tuple[int, int]:
    """Parse an optional quantifier at ``pos`` and return ``(repetitions, next_pos)``.

    When no well-formed quantifier starts at ``pos`` the result is ``(1, pos)``
    and the random generator is left untouched.
    """
    if pos >= len(pattern):
        return 1, pos
    symbol = pattern[pos]
    if symbol == "+":
        return rng.randint(1, 1 + _MAX_UNBOUNDED_EXTRA), pos + 1
    if symbol == "*":
        return rng.randint(0, _MAX_UNBOUNDED_EXTRA), pos + 1
    if symbol == "?":
        return rng.randint(0, 1), pos + 1
    if symbol == "{":
        close = pattern.find("}", pos)
        if close == -1:
            return 1, pos
        bounds = pattern[pos + 1 : close].split(",")
        well_formed = len(bounds) <= 2 and any(bounds) and all(b == "" or b.isdecimal() for b in bounds)
        if not well_formed:
            # Not a repetition count; the brace is left to be emitted literally.
            return 1, pos
        if len(bounds) == 1:
            return int(bounds[0]), close + 1
        low = int(bounds[0]) if bounds[0] else 0
        high = int(bounds[1]) if bounds[1] else low + _MAX_UNBOUNDED_EXTRA
        return rng.randint(low, high), close + 1
    return 1, pos


def _gen_matching_string(pattern: str, rng: random.Random) -> str:
    """Generate a string matching a simple regex pattern via character-level generation.

    Supported syntax: bracket classes of literals and ``x-y`` ranges, the
    escapes ``\\d``, ``\\w``, ``\\s`` and escaped punctuation, plain literal
    characters, and the quantifiers ``{n}``, ``{m,n}``, ``{m,}``, ``{,n}``,
    ``+``, ``*`` and ``?`` after any of those atoms. Groups, alternation,
    anchors, negated classes and the ``.`` wildcard are not interpreted.

    Args:
        pattern: The regex to expand.
        rng: Source of randomness; drawn from only for repetition counts and
            character choices, never for literal characters.

    Returns:
        A string built atom by atom from ``pattern``.

    Raises:
        ValueError: if the pattern ends in a lone backslash, uses an
            unsupported alphanumeric escape, or has an unterminated ``[``.
    """
    pieces: list[str] = []
    pos = 0
    while pos < len(pattern):
        token = pattern[pos]
        if token == "[":
            close = pattern.index("]", pos)
            choices = _expand_char_class(pattern[pos + 1 : close])
            # The repetition count is drawn before the characters so that the
            # random stream is consumed in a stable, documented order.
            reps, pos = _read_quantifier(pattern, close + 1, rng)
            pieces.extend(rng.choice(choices) for _ in range(reps))
        elif token == "\\":
            if pos + 1 >= len(pattern):
                raise ValueError(f"pattern ends with a dangling backslash: {pattern!r}")
            escaped = pattern[pos + 1]
            if escaped.isalnum() and escaped not in "dws":
                raise ValueError(f"unsupported escape \\{escaped} in pattern {pattern!r}")
            reps, pos = _read_quantifier(pattern, pos + 2, rng)
            if escaped == "d":
                pieces.extend(str(rng.randint(0, 9)) for _ in range(reps))
            elif escaped == "w":
                pieces.extend(rng.choice(_WORD_CHARS) for _ in range(reps))
            elif escaped == "s":
                pieces.append(" " * reps)
            else:
                # Escaped punctuation stands for itself.
                pieces.append(escaped * reps)
        else:
            reps, pos = _read_quantifier(pattern, pos + 1, rng)
            pieces.append(token * reps)
    return "".join(pieces)
class RegexPuzzlesDataset(ProceduralDataset):
    """Procedurally generated regex and DFA puzzles.

    Each item is one of four task kinds: produce a string matching a regex,
    extract dollar amounts from a sentence, report the state a DFA ends in,
    or find the longest prefix of an input that a DFA accepts.
    """

    def __init__(self, config: RegexPuzzlesConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _make_string_generation(self, rng: random.Random) -> dict:
        """Build a "write a string matching this regex" puzzle.

        The stored answer is only one valid example; ``score_answer`` accepts
        any string that fully matches the pattern.
        """
        pattern, desc = rng.choice(REGEX_PATTERNS)
        answer = _gen_matching_string(pattern, rng)
        question = (
            f"Generate a string that matches the regex pattern `{pattern}` "
            f"(i.e., {desc}). Give only the string, nothing else."
        )
        return {"question": question, "answer": answer, "task_type": "string_generation", "pattern": pattern}

    def _make_extraction(self, rng: random.Random) -> dict:
        """Build a puzzle asking for every dollar amount embedded in a short text."""
        pattern_str = r"\$\d+\.\d{2}"
        n = rng.randint(2, 4)
        prices = [f"${rng.randint(1, 999)}.{rng.randint(10, 99):02d}" for _ in range(n)]
        words = ["The price is", "costs", "for", "and", "total", "you pay", "item at"]
        text_parts = []
        for p in prices:
            text_parts.append(rng.choice(words))
            text_parts.append(p)
        text_parts.append(rng.choice(["today", "now", "in total"]))
        text = " ".join(text_parts)
        # The oracle is derived by running the regex over the final text, so it
        # always reflects what the pattern really matches.
        matches = re.findall(pattern_str, text)
        answer = ", ".join(matches)
        question = (
            f"Extract all dollar amounts (matching the pattern $X.XX) from the following text:\n"
            f"'{text}'\n"
            f"List them separated by commas in the order they appear."
        )
        return {"question": question, "answer": answer, "task_type": "extraction"}

    def _make_dfa(self, rng: random.Random) -> tuple[dict, list, str, list, list]:
        """Create a random total DFA over the alphabet {a, b}.

        Returns:
            ``(transitions, states, start, accept, alphabet)`` where
            ``transitions`` maps ``(state, symbol)`` to the next state, the
            start state is always the first state, and ``accept`` holds
            between one and ``n // 2`` (at least one) distinct states.
        """
        n = rng.randint(self.config.min_dfa_states, self.config.max_dfa_states)
        states = [f"q{i}" for i in range(n)]
        alphabet = ["a", "b"]
        transitions = {}
        for s in states:
            for c in alphabet:
                transitions[(s, c)] = rng.choice(states)
        accept = rng.sample(states, rng.randint(1, max(1, n // 2)))
        return transitions, states, states[0], accept, alphabet

    def _run_dfa(self, transitions: dict, start: str, input_str: str) -> str:
        """Return the state reached from ``start`` after consuming ``input_str``.

        A symbol with no transition from the current state leaves the state
        unchanged.
        """
        state = start
        for c in input_str:
            state = transitions.get((state, c), state)
        return state

    @staticmethod
    def _format_transitions(transitions: dict, states: list, alphabet: list) -> str:
        """Render the transition table as ``δ(q,c)=q'`` entries in state-then-symbol order."""
        return ", ".join(f"δ({s},{c})={transitions[(s,c)]}" for s in states for c in alphabet)

    def _make_dfa_state(self, rng: random.Random) -> dict:
        """Build a puzzle asking which state a DFA is in after reading an input."""
        transitions, states, start, _accept, alphabet = self._make_dfa(rng)
        input_len = rng.randint(3, 6)
        input_str = "".join(rng.choice(alphabet) for _ in range(input_len))
        final_state = self._run_dfa(transitions, start, input_str)
        trans_str = self._format_transitions(transitions, states, alphabet)
        question = (
            f"A DFA has states {{{', '.join(states)}}}, alphabet {{a, b}}, start state {start}.\n"
            f"Transitions: {trans_str}\n"
            f"After processing the input '{input_str}', what state is the DFA in? "
            f"Give only the state name."
        )
        return {"question": question, "answer": final_state, "task_type": "dfa_state"}

    def _make_dfa_prefix(self, rng: random.Random) -> dict:
        """Build a puzzle asking for the longest input prefix the DFA accepts.

        The answer is the longest non-empty accepted prefix when one exists,
        ``EMPTY`` when only the empty prefix is accepted (the start state is
        accepting), and ``NONE`` when nothing is accepted. The sentinel avoids
        an empty-string oracle answer that a solver could not express.
        """
        transitions, states, start, accept, alphabet = self._make_dfa(rng)
        input_len = rng.randint(4, 8)
        input_str = "".join(rng.choice(alphabet) for _ in range(input_len))
        longest_prefix = ""
        state = start
        for i, c in enumerate(input_str):
            state = transitions.get((state, c), state)
            if state in accept:
                # Later hits overwrite earlier ones, leaving the longest prefix.
                longest_prefix = input_str[: i + 1]
        if longest_prefix:
            answer = longest_prefix
        elif start in accept:
            answer = "EMPTY"
        else:
            answer = "NONE"
        trans_str = self._format_transitions(transitions, states, alphabet)
        accept_str = ", ".join(accept)
        question = (
            f"A DFA has states {{{', '.join(states)}}}, alphabet {{a, b}}, "
            f"start state {start}, accept states {{{accept_str}}}.\n"
            f"Transitions: {trans_str}\n"
            f"What is the longest prefix of '{input_str}' that is accepted by this DFA? "
            f"If only the empty prefix is accepted, answer 'EMPTY'. "
            f"If no prefix is accepted, answer 'NONE'."
        )
        return {"question": question, "answer": answer, "task_type": "dfa_prefix"}

    def __getitem__(self, idx: int) -> dict:
        """Generate the puzzle at position ``idx``.

        A fresh ``random.Random`` seeded with ``self.seed + idx`` drives all
        choices, so the same seed and index reproduce the same item.

        Returns:
            A dict with ``question``, ``answer`` and ``metadata``; the metadata
            carries the task type, the DFA size bounds and, for
            string-generation items, the regex ``pattern``.
        """
        rng = random.Random(self.seed + idx)
        task_type = rng.choices(self.config.task_types, weights=self.config.task_weights, k=1)[0]
        generators = {
            "string_generation": self._make_string_generation,
            "extraction": self._make_extraction,
            "dfa_state": self._make_dfa_state,
            "dfa_prefix": self._make_dfa_prefix,
        }
        result = generators[task_type](rng)
        metadata = {
            "source_dataset": DATASET_NAME,
            "source_index": idx,
            "task_type": result["task_type"],
            "difficulty": {
                "min_dfa_states": self.config.min_dfa_states,
                "max_dfa_states": self.config.max_dfa_states,
            },
        }
        if "pattern" in result:
            metadata["pattern"] = result["pattern"]
        return {
            "question": result["question"],
            "answer": result["answer"],
            "metadata": metadata,
        }

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Score ``answer`` against a generated entry; returns 1.0 or 0.0.

        An exact match (ignoring surrounding whitespace) always scores 1.0.
        String-generation answers also score 1.0 when they fully match the
        entry's regex, and extraction answers when their comma-separated
        items equal the oracle's regardless of spacing around the commas.
        Missing or non-string answers score 0.0.
        """
        if not isinstance(answer, str):
            return 0.0
        oracle = entry["answer"]
        candidate = answer.strip()
        if candidate == oracle.strip():
            return 1.0
        task_type = entry["metadata"]["task_type"]
        if task_type == "string_generation":
            pattern = entry["metadata"]["pattern"]
            try:
                matched = re.fullmatch(pattern, candidate) is not None
            except re.error:
                # A pattern that does not compile cannot validate anything.
                matched = False
            return 1.0 if matched else 0.0
        if task_type == "extraction":
            a_parts = [x.strip() for x in answer.split(",")]
            o_parts = [x.strip() for x in oracle.split(",")]
            return 1.0 if a_parts == o_parts else 0.0
        return 0.0
class RegexPuzzlesCurriculum(BaseCurriculum):
    """Curriculum whose single attribute raises the ceiling on DFA state count."""

    def __init__(self):
        super().__init__(RegexPuzzlesCurriculum.__name__, RegexPuzzlesConfig)
        # One scalar attribute: each level sets the config's max_dfa_states.
        dfa_size_attribute = ScalarAttributeDefinition(
            name="max_dfa_states",
            field_name="max_dfa_states",
            levels=[3, 5, 7, 10],
            description="Maximum DFA states",
        )
        self._define_attributes(dfa_size_attribute)
# Register the dataset, its config and its curriculum under DATASET_NAME
# (presumably so the factory can build it by name; confirm in ..factory).
register_dataset(DATASET_NAME, RegexPuzzlesDataset, RegexPuzzlesConfig, RegexPuzzlesCurriculum)