Merge branch 'rich/decimalmath' of github.com:open-thought/reasoning-gym into rich/decimalmath

This commit is contained in:
Rich Jones 2025-02-19 03:34:57 +01:00
commit 0cd2eb50d7
62 changed files with 4012 additions and 478 deletions

View file

@ -1,45 +0,0 @@
name: Update GALLERY.md
on:
pull_request:
branches:
- main
types: [closed] # Trigger only when the PR is closed
permissions:
contents: write
jobs:
update-gallery:
runs-on: ubuntu-latest
if: github.event.pull_request.merged == true # Ensure it was merged, not just closed
steps:
- name: Check out repository code
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install dependencies
run: |
pip install -e .
- name: Run gallery script
run: |
python scripts/generate_gallery.py
- name: Commit and push changes
run: |
git config user.name 'github-actions[bot]'
git config user.email 'github-actions[bot]@users.noreply.github.com'
git add GALLERY.md
if [ -n "$(git status --porcelain)" ]; then
git commit -m "Update GALLERY.md [skip ci]"
git push
else
echo "No changes to commit."
fi

1057
GALLERY.md

File diff suppressed because it is too large Load diff

View file

@ -1,19 +1,32 @@
### env setup ### env setup
``` ```
conda create --name verl python=3.12 -y conda create --name verl python=3.11 -y
conda activate verl conda activate verl
pip install flash-attn --no-build-isolation pip install flash-attn --no-build-isolation
pip install vllm==0.7.0 ray wandb pip install ray wandb
# pip3 install vllm==0.7.0
pip3 install vllm --pre --extra-index-url https://wheels.vllm.ai/nightly
``` ```
Regarding vllm>0.7 see: [docs](https://verl.readthedocs.io/en/latest/README_vllm0.7.html)
### clone and install veRL ### clone and install veRL
tested with verl HEAD a65c9157bc0b85b64cd753de19f94e80a11bd871 tested with verl HEAD 0dfcb7f99e299940e1792a386df13c7591df351a
``` ```
git clone https://github.com/volcengine/verl.git git clone https://github.com/volcengine/verl.git
cd verl cd verl
pip install -e . pip install -e .
``` ```
Optionally log in to huggingface hub and wandb with your keys:
```
huggingface-cli login
wandb login
```

View file

@ -0,0 +1,171 @@
data:
tokenizer: null
train_files: ~/data/rlhf/gsm8k/train.parquet
val_files: ~/data/rlhf/gsm8k/test.parquet
prompt_key: prompt
max_prompt_length: 512
max_response_length: 512
train_batch_size: 1024
val_batch_size: 1312
return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
return_raw_chat: False
actor_rollout_ref:
hybrid_engine: True
model:
path: ~/models/deepseek-llm-7b-chat
external_lib: null
override_config: { }
enable_gradient_checkpointing: True
use_remove_padding: False
actor:
strategy: fsdp # This is for backward-compatibility
ppo_mini_batch_size: 256
ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
ppo_micro_batch_size_per_gpu: null
use_dynamic_bsz: False
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
clip_ratio: 0.2
entropy_coeff: 0.001
use_kl_loss: True # True for GRPO
kl_loss_coef: 0.001 # for grpo
kl_loss_type: low_var_kl # for grpo
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
optim:
lr: 1e-6
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
total_training_steps: -1 # must be override by program
fsdp_config:
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
param_offload: False
optimizer_offload: False
fsdp_size: -1
ref:
fsdp_config:
param_offload: False
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
log_prob_micro_batch_size_per_gpu: null
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
rollout:
name: vllm
temperature: 1.0
top_k: -1 # 0 for hf rollout, -1 for vllm rollout
top_p: 1
prompt_length: ${data.max_prompt_length} # not use for opensource
response_length: ${data.max_response_length}
# for vllm rollout
dtype: bfloat16 # should align with FSDP
gpu_memory_utilization: 0.5
ignore_eos: False
enforce_eager: True
free_cache_engine: True
load_format: dummy_dtensor
tensor_model_parallel_size: 2
max_num_batched_tokens: 8192
max_num_seqs: 1024
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
log_prob_micro_batch_size_per_gpu: null
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
disable_log_stats: True
enable_chunked_prefill: True # could get higher throughput
# for hf rollout
do_sample: True
# number of responses (i.e. num sample times)
n: 16 # > 1 for grpo
critic:
strategy: fsdp
optim:
lr: 1e-5
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
total_training_steps: -1 # must be override by program
model:
path: ~/models/deepseek-llm-7b-chat
tokenizer_path: ${actor_rollout_ref.model.path}
override_config: { }
external_lib: ${actor_rollout_ref.model.external_lib}
enable_gradient_checkpointing: True
use_remove_padding: False
fsdp_config:
param_offload: False
optimizer_offload: False
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
fsdp_size: -1
ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
ppo_micro_batch_size_per_gpu: null
forward_micro_batch_size: ${critic.ppo_micro_batch_size}
forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: 1 # sp size
ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
shuffle: ${actor_rollout_ref.actor.shuffle}
grad_clip: 1.0
cliprange_value: 0.5
reward_model:
enable: False
strategy: fsdp
model:
input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
path: ~/models/FsfairX-LLaMA3-RM-v0.1
external_lib: ${actor_rollout_ref.model.external_lib}
use_remove_padding: False
fsdp_config:
min_num_params: 0
param_offload: False
fsdp_size: -1
micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
micro_batch_size_per_gpu: null # set a number
max_length: null
ulysses_sequence_parallel_size: 1 # sp size
use_dynamic_bsz: ${critic.use_dynamic_bsz}
forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
algorithm:
gamma: 1.0
lam: 1.0
adv_estimator: gae
kl_penalty: kl # how to estimate kl divergence
kl_ctrl:
type: fixed
kl_coef: 0.001
trainer:
total_epochs: 30
total_training_steps: null
project_name: verl_examples
experiment_name: gsm8k
logger: [ 'console', 'wandb' ]
val_generations_to_log_to_wandb: 0
nnodes: 1
n_gpus_per_node: 8
save_freq: -1
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
resume_from_path: False
test_freq: -1
critic_warmup: 0
default_hdfs_dir: null
remove_previous_ckpt_in_save: False
del_local_ckpt_after_load: False
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}

View file

@ -45,7 +45,6 @@ actor_rollout_ref:
# transformer_layer_cls_to_wrap: None # transformer_layer_cls_to_wrap: None
min_num_params: 0 min_num_params: 0
param_offload: False param_offload: False
grad_offload: False
optimizer_offload: False optimizer_offload: False
fsdp_size: -1 fsdp_size: -1
ref: ref:
@ -104,7 +103,6 @@ critic:
use_remove_padding: False use_remove_padding: False
fsdp_config: fsdp_config:
param_offload: False param_offload: False
grad_offload: False
optimizer_offload: False optimizer_offload: False
wrap_policy: wrap_policy:
# transformer_layer_cls_to_wrap: None # transformer_layer_cls_to_wrap: None
@ -158,10 +156,16 @@ trainer:
project_name: verl_examples project_name: verl_examples
experiment_name: gsm8k experiment_name: gsm8k
logger: [ 'console', 'wandb' ] logger: [ 'console', 'wandb' ]
val_generations_to_log_to_wandb: 0
nnodes: 1 nnodes: 1
n_gpus_per_node: 8 n_gpus_per_node: 8
save_freq: -1 save_freq: -1
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
resume_from_path: False
test_freq: -1 test_freq: -1
critic_warmup: 0 critic_warmup: 0
default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name} default_hdfs_dir: null
remove_previous_ckpt_in_save: False
del_local_ckpt_after_load: False
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}

View file

@ -6,4 +6,4 @@ export ROLLOUT_TP_SIZE=2
export EXPERIMENT_NAME=chain_sum_llama export EXPERIMENT_NAME=chain_sum_llama
export VLLM_ATTENTION_BACKEND=XFORMERS export VLLM_ATTENTION_BACKEND=XFORMERS
bash ./train.sh bash ./train_grpo.sh

View file

@ -0,0 +1,39 @@
#!/bin/bash
set -x
python3 -u main_ppo_custom_reward.py \
algorithm.adv_estimator=grpo \
data.train_files=$DATA_DIR/train.parquet \
data.val_files=$DATA_DIR/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=$BASE_MODEL \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=80 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=160 \
actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=8 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=160 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
trainer.project_name='verl_chain_sum_grpo' \
trainer.experiment_name=$EXPERIMENT_NAME \
trainer.n_gpus_per_node=$N_GPUS \
trainer.nnodes=1 \
trainer.save_freq=100 \
trainer.test_freq=100 \
trainer.total_epochs=15 $@ 2>&1 | tee verl_output.log

View file

@ -25,6 +25,6 @@ trainer.n_gpus_per_node=$N_GPUS \
trainer.nnodes=1 \ trainer.nnodes=1 \
trainer.save_freq=100 \ trainer.save_freq=100 \
trainer.test_freq=100 \ trainer.test_freq=100 \
trainer.project_name=verl_chain_sum \ trainer.project_name='verl_chain_sum_ppo' \
trainer.experiment_name=$EXPERIMENT_NAME \ trainer.experiment_name=$EXPERIMENT_NAME \
trainer.total_epochs=15 2>&1 | tee verl_output.log trainer.total_epochs=15 2>&1 | tee verl_output.log

View file

@ -1,9 +1,10 @@
import random import random
import string import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
import sympy as sp import sympy as sp
from sympy.polys.monomials import itermonomials
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -18,11 +19,13 @@ class PolynomialMultiplicationConfig:
max_terms: int = 4 # Maximum number of polynomial terms max_terms: int = 4 # Maximum number of polynomial terms
min_value: int = 1 # Minimum value for coefficients min_value: int = 1 # Minimum value for coefficients
max_value: int = 100 # Maximum value for coefficients max_value: int = 100 # Maximum value for coefficients
min_degree: int = 1 # Minimum polynomial degree min_degree: int = 0 # Minimum polynomial degree
max_degree: int = 3 # Maximum polynomial degree max_degree: int = 3 # Maximum polynomial degree
min_polynomials: int = 2 # Minimum number of polynomials being multiplied min_polynomials: int = 2 # Minimum number of polynomials being multiplied
max_polynomials: int = 3 # Maximum number of polynomials being multiplied max_polynomials: int = 3 # Maximum number of polynomials being multiplied
single_variable: bool = True variables: Tuple[str] = ("x", "y", "z") # Tuple of variable names, that will be chosen randomly
allow_cross_variable_product: bool = False # Generate tasks like "Multiply (x^2+3x-1)*(y^2-5)"
allow_multivariate_polynomials: bool = False # Generate multivariate tasks like "Multiply (2x^2 + 3y)*(5x^2+3x-1)"
operators: Tuple[str, ...] = ( operators: Tuple[str, ...] = (
"+", "+",
"-", "-",
@ -38,12 +41,17 @@ class PolynomialMultiplicationConfig:
assert self.min_value > 0, "min_value must be positive." assert self.min_value > 0, "min_value must be positive."
assert self.max_value >= self.min_value, "max_value must be >= min_value." assert self.max_value >= self.min_value, "max_value must be >= min_value."
assert self.min_degree >= 1, "min_degree must be >= 1." assert self.min_degree >= 0, "min_degree must be >= 0."
assert self.max_degree >= self.min_degree, "max_degree must be >= min_degree." assert self.max_degree >= self.min_degree, "max_degree must be >= min_degree."
assert self.min_polynomials >= 2, "min_polynomials must be >= 2." assert self.min_polynomials >= 2, "min_polynomials must be >= 2."
assert self.max_polynomials >= self.min_polynomials, "max_polynomials must be >= min_polynomials." assert self.max_polynomials >= self.min_polynomials, "max_polynomials must be >= min_polynomials."
assert len(self.variables) > 0, "The variable tuple is empty."
assert not (
self.allow_multivariate_polynomials and not self.allow_cross_variable_product
), "Multivariate polynomials require cross product."
allowed_ops = {"+", "-"} allowed_ops = {"+", "-"}
assert len(self.operators) > 0, "operators tuple cannot be empty." assert len(self.operators) > 0, "operators tuple cannot be empty."
assert all(op in allowed_ops for op in self.operators), "Invalid operator found. Must be a subset of {+, -}." assert all(op in allowed_ops for op in self.operators), "Invalid operator found. Must be a subset of {+, -}."
@ -76,13 +84,24 @@ In addition, When doing calculation, Use the following instructions together wit
A dict with: A dict with:
- question: str (e.g. "Multiply polynomials: (8x^3 + x + 2)*(x - 3)") - question: str (e.g. "Multiply polynomials: (8x^3 + x + 2)*(x - 3)")
- answer: str (Product, e.g. "8x^4 - 24x^3 + x^2 - x - 6") - answer: str (Product, e.g. "8x^4 - 24x^3 + x^2 - x - 6")
- metadata: dict with details (polynomial_expr, single_variable) - metadata: dict with details (polynomial_expr, result, variables)
""" """
rng = random.Random(self.seed + idx)
number_polynomials = rng.randint(self.config.min_polynomials, self.config.max_polynomials)
polynomials = [self._generate_polynomial_expr(rng) for i in range(number_polynomials)]
polynomial_expr = sp.prod(polynomials) rng = random.Random(self.seed + idx)
"""
Three Monomial States:
- allow_multivariate_polynomials == 1: list of multivariate monomials (e.g "xy" --> [x, y, xy, x**2, y**2])
- allow_cross_variable_product == 1: None. Will generate a unique list of single variable monomials for each term
- allow_cross_variable_product == 0: A shared list of monomials for each term (e.g "x" --> [x, x**2, 1])
"""
monomials = self._get_monomials(rng) if self.config.allow_cross_variable_product else None
monomials = None if self.config.allow_cross_variable_product else self._get_monomials(rng)
number_polynomials = rng.randint(self.config.min_polynomials, self.config.max_polynomials)
polynomial_terms = [self._generate_polynomial(rng, monomials) for _ in range(number_polynomials)]
polynomial_expr = sp.prod(polynomial_terms)
product = sp.expand(polynomial_expr) product = sp.expand(polynomial_expr)
question = rng.choice(self._prompt_templates).format(polynomial_expr=polynomial_expr) + self.added_instruction question = rng.choice(self._prompt_templates).format(polynomial_expr=polynomial_expr) + self.added_instruction
@ -91,54 +110,39 @@ In addition, When doing calculation, Use the following instructions together wit
"answer": product, "answer": product,
"metadata": { "metadata": {
"polynomial_expr": str(polynomial_expr), "polynomial_expr": str(polynomial_expr),
"single_variable": self.config.single_variable,
"result": str(product), "result": str(product),
"variables": list(product.free_symbols),
}, },
} }
def _get_variable(self, rng: random.Random) -> str: def _get_monomials(self, rng: random.Random) -> str:
"""Get a random lowercase variable name""" """Get a list of monomials"""
if self.config.single_variable: if self.config.allow_multivariate_polynomials:
return "x" sym = sp.symbols(self.config.variables)
return rng.choice(string.ascii_lowercase) else:
sym = [sp.symbols(rng.choice(self.config.variables))]
def _generate_polynomial_expr(self, rng: random.Random): monomials = list(itermonomials(sym, self.config.max_degree, self.config.min_degree))
""" return monomials
Randomly generate a polynomial expression of 'degree'.
We'll use the config parameters:
- min_terms, max_terms: how many total terms to combine
- min_value, max_value: range for coefficients
- operators: to decide sign flips or direct addition
Args:
rng: Random number generator
Returns:
Polynomial string
"""
variable = self._get_variable(rng)
degree = rng.randint(self.config.min_degree, self.config.max_degree)
x = sp.Symbol(variable)
def _generate_polynomial(self, rng: random.Random, monomials: Optional[list]):
"""Generates a random polynomial, returns expression."""
# Choose the number of terms and their respective degrees # Choose the number of terms and their respective degrees
monomials = monomials if monomials else self._get_monomials(rng)
num_terms = rng.randint(self.config.min_terms, self.config.max_terms) num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
# Keep track of exponents, exponents can repeat or skip but we force the highest exponent
chosen_exponents = [degree]
# Fill the rest randomly in [0, degree]
for _ in range(num_terms - 1):
exp = rng.randint(0, degree)
chosen_exponents.append(exp)
# Now build the polynomial expression: sum_{term}( coeff * x^exponent ), with optional sign
polynomial_expr = 0 polynomial_expr = 0
for exp in chosen_exponents: for _ in range(num_terms):
# Pick a nonzero random coefficient between min_value and max_value.
coeff = rng.randint(self.config.min_value, self.config.max_value) coeff = rng.randint(self.config.min_value, self.config.max_value)
# Pick a random monomial
var = rng.choice(monomials)
# If '-' in operators, we can randomly flip the sign # If '-' in operators, we can randomly flip the sign
if "-" in self.config.operators and rng.random() < 0.5: if "-" in self.config.operators and rng.random() < 0.5:
coeff = -coeff coeff = -coeff
term_expr = coeff * (x**exp)
polynomial_expr += term_expr polynomial_expr += coeff * var
return polynomial_expr return polynomial_expr
@ -151,7 +155,7 @@ In addition, When doing calculation, Use the following instructions together wit
target_poly = sp.parse_expr(metadata["result"]) target_poly = sp.parse_expr(metadata["result"])
# Check if the difference simplifies to zero (i.e. they are equivalent). # Check if the difference simplifies to zero (i.e. they are equivalent).
if sp.simplify(predicted_poly - target_poly) == 0: if predicted_poly == target_poly:
reward = 1.0 reward = 1.0
elif answer.strip(): elif answer.strip():
reward = 0.05 reward = 0.05

View file

@ -11,6 +11,7 @@ from .base_conversion import BaseConversionConfig, BaseConversionDataset
from .binary_matrix import BinaryMatrixConfig, BinaryMatrixDataset from .binary_matrix import BinaryMatrixConfig, BinaryMatrixDataset
from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset
from .count_primes import CountPrimesConfig, CountPrimesDataset from .count_primes import CountPrimesConfig, CountPrimesDataset
from .cryptarithm import CryptarithmConfig, CryptarithmDataset
from .game_of_life import GameOfLifeConfig, GameOfLifeDataset from .game_of_life import GameOfLifeConfig, GameOfLifeDataset
from .graph_color import GraphColorConfig, GraphColorDataset from .graph_color import GraphColorConfig, GraphColorDataset
from .group_anagrams import GroupAnagramsConfig, GroupAnagramsDataset from .group_anagrams import GroupAnagramsConfig, GroupAnagramsDataset
@ -21,6 +22,7 @@ from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixDataset
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
from .number_sorting import NumberSortingConfig, NumberSortingDataset from .number_sorting import NumberSortingConfig, NumberSortingDataset
from .palindrome_generation import PalindromeConfig, PalindromeDataset from .palindrome_generation import PalindromeConfig, PalindromeDataset
from .palindrome_partitioning import PalindromePartitioningConfig, PalindromePartitioningDataset
from .pool_matrix import PoolMatrixConfig, PoolMatrixDataset from .pool_matrix import PoolMatrixConfig, PoolMatrixDataset
from .ransom_note import RansomNoteConfig, RansomNoteDataset from .ransom_note import RansomNoteConfig, RansomNoteDataset
from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset
@ -42,6 +44,8 @@ __all__ = [
"BaseConversionDataset", "BaseConversionDataset",
"CaesarCipherConfig", "CaesarCipherConfig",
"CaesarCipherDataset", "CaesarCipherDataset",
"CryptarithmConfig",
"CryptarithmDataset",
"GameOfLifeConfig", "GameOfLifeConfig",
"GameOfLifeDataset", "GameOfLifeDataset",
"LetterCountingConfig", "LetterCountingConfig",
@ -65,6 +69,8 @@ __all__ = [
"PalindromeDataset", "PalindromeDataset",
"GroupAnagramsConfig", "GroupAnagramsConfig",
"GroupAnagramsDataset", "GroupAnagramsDataset",
"PalindromePartitioningConfig",
"PalindromePartitioningDataset",
"SpiralMatrixConfig", "SpiralMatrixConfig",
"SpiralMatrixDataset", "SpiralMatrixDataset",
"RansomNoteConfig", "RansomNoteConfig",

View file

@ -6,6 +6,26 @@ from typing import Optional, Tuple
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Your task is to convert a number between two different bases.
If the target base is > 10, use lowercase letters a-z for digits above 9.
Example:
- Input: Convert the base-9 number 440 to base-5
- Output: 2420
- Explanation
- First, we convert the base-9 number 440 to base-10: 4 * 9**2 + 4 * 9**1 + 0 * 9**0 = 324 + 36 + 0 = 360
- Next, we convert the base-10 number 360 to base-5:
- 360 // 5 = 72 remainder 0
- 72 // 5 = 14 remainder 2
- 14 // 5 = 2 remainder 4
- 2 // 5 = 0 remainder 2
- Reading the remainders in reverse order gives us the base-5 number 2 4 2 0
- Hence, the final answer is 2420
Now, convert the {source_name} number {source_repr} to {target_name}
"""
@dataclass @dataclass
class BaseConversionConfig: class BaseConversionConfig:
@ -90,11 +110,10 @@ class BaseConversionDataset(ProceduralDataset):
source_name = self._format_base_name(source_base) source_name = self._format_base_name(source_base)
target_name = self._format_base_name(target_base) target_name = self._format_base_name(target_base)
# Add hint for bases > 10 about using lowercase letters
hint = " (use lowercase letters a-z for digits above 9)" if target_base > 10 else ""
return { return {
"question": f"Convert the {source_name} number {source_repr} to {target_name}{hint}", "question": QUESTION_TEMPLATE.format(
source_name=source_name, source_repr=source_repr, target_name=target_name
),
"answer": target_repr, "answer": target_repr,
"metadata": { "metadata": {
"decimal_value": value, "decimal_value": value,

View file

@ -7,23 +7,28 @@ https://leetcode.com/problems/01-matrix/description/
from collections import deque from collections import deque
from dataclasses import dataclass from dataclasses import dataclass
from random import Random from random import Random
from typing import Optional from typing import Dict, Optional
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Given a square matrix, your job is to find the taxicab distance of the nearest 0 for each cell. QUESTION_TEMPLATE = """Given a square matrix, your job is to find the taxicab (Manhattan) distance of the nearest 0 for each cell.
Example: Example:
- Input: Find the distance to the nearest 0 for each cell in the matrix below:
Input: Find the distance to the nearest 0 for each cell in the matrix below:
0 0 0 0 0 0
0 1 0 0 1 0
1 1 1 1 1 1
- Output:
Output:
0 0 0 0 0 0
0 1 0 0 1 0
1 2 1 1 2 1
- Explanation
- Each cell with a 0 has a distance of 0 to itself.
- The cell at (1, 1) has a distance of 1 to the nearest 0 (any of the three 0's at (1, 0), (0, 1), (1, 2)).
- The cell at (2, 0) has a distance of 1 to the nearest 0 (the 0 at (1, 0)).
- The cell at (2, 1) has a distance of 2 to the nearest 0 (any of the two 0's at (1, 0), (1, 2))
- The cell at (2, 2) has a distance of 1 to the nearest 0 (the 0 at (1, 2)).
- Hence, the final answer is the matrix is the output shown above, where each cell contains the distance to the nearest 0, in the same format as the input matrix.
Find the distance to the nearest 0 for each cell in the matrix below: Find the distance to the nearest 0 for each cell in the matrix below:
{matrix} {matrix}
@ -34,6 +39,7 @@ Find the distance to the nearest 0 for each cell in the matrix below:
class BinaryMatrixConfig: class BinaryMatrixConfig:
"""Configuration for Binary Matrix dataset generation""" """Configuration for Binary Matrix dataset generation"""
min_n: int = 3 # Minimum size of the matrix
max_n: int = 10 # Maximum size of the matrix max_n: int = 10 # Maximum size of the matrix
p_zero: float = 0.25 # Probability of a cell being 0 p_zero: float = 0.25 # Probability of a cell being 0
@ -42,7 +48,8 @@ class BinaryMatrixConfig:
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
assert 1 <= self.max_n, "max_n must be at least 1" assert 1 <= self.min_n, "min_n must be at least 1"
assert self.min_n <= self.max_n, "min_n must be less than or equal to max_n"
assert 0 < self.p_zero <= 1, "p_zero must be between 0 and 1" assert 0 < self.p_zero <= 1, "p_zero must be between 0 and 1"
@ -54,7 +61,7 @@ class BinaryMatrixDataset(ProceduralDataset):
def _get_binary_matrix(self, rng: Random) -> list[list[int]]: def _get_binary_matrix(self, rng: Random) -> list[list[int]]:
"""Generate a random binary matrix""" """Generate a random binary matrix"""
n = rng.randint(1, self.config.max_n) n = rng.randint(self.config.min_n, self.config.max_n)
# Ensure at least one 0 in the matrix, so that a solution exists # Ensure at least one 0 in the matrix, so that a solution exists
numbers = [0] + [0 if rng.random() < self.config.p_zero else 1 for _ in range(n**2 - 1)] numbers = [0] + [0 if rng.random() < self.config.p_zero else 1 for _ in range(n**2 - 1)]
rng.shuffle(numbers) rng.shuffle(numbers)
@ -105,6 +112,22 @@ class BinaryMatrixDataset(ProceduralDataset):
"""Get a string representation of the matrix""" """Get a string representation of the matrix"""
return "\n".join(" ".join(str(x) for x in row) for row in matrix) return "\n".join(" ".join(str(x) for x in row) for row in matrix)
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["answer"]
if answer is not None:
if answer == oracle_answer:
return 1.0
else:
try:
# check if answer is python list of lists
answer = self._matrix_to_str(eval(answer))
if answer == oracle_answer:
return 0.5
except Exception as e:
return 0.01
return 0.0
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
"""Generate a single Binary Matrix question""" """Generate a single Binary Matrix question"""
rng = Random(self.seed + idx) rng = Random(self.seed + idx)

View file

@ -0,0 +1,206 @@
"""
Cryptarithm puzzle generator (numbers -> letters approach).
Generates puzzles such that:
WORD1
+ WORD2
[+ WORD3]
---------
RESULT
where each letter corresponds to exactly one digit (0..9).
No leading letter can be zero (unless allow_leading_zero=True).
"""
from dataclasses import dataclass
from random import Random
from typing import Optional
from ..factory import ProceduralDataset, register_dataset
EXAMPLE_CASE = """
BASE
+ BALL
------
GAMES
Answer (one possible solution):
B=7, A=8, S=2, E=9, L=1, G=1, M=0
Summation: 7829 + 7811 = 15640 (the puzzle might produce a different arrangement, but the principle is the same)."""
@dataclass
class CryptarithmConfig:
"""Configuration for Cryptarithm dataset generation."""
min_words: int = 2 # Minimum number of addends
max_words: int = 3 # Maximum number of addends
allow_leading_zero: bool = False
include_example: bool = True
seed: Optional[int] = None
size: int = 500 # Number of puzzle instances to generate
def validate(self):
"""Validate configuration parameters."""
assert 2 <= self.min_words <= self.max_words, "min_words must be <= max_words, both >= 2."
assert self.size > 0, "Dataset size must be positive."
class CryptarithmDataset(ProceduralDataset):
    """
    Generates cryptarithm puzzles by:
    1) Randomly choosing integers for each "addend" (with no leading zero if not allowed),
    2) Summing them,
    3) Mapping distinct digits (0..9) to letters (A..Z),
    4) Formatting the puzzle text.

    This approach guarantees sum correctness and avoids repeated failures:
    real numbers are generated first, then disguised as letters, so no
    constraint solving is ever needed.
    """

    def __init__(self, config: CryptarithmConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """Return the puzzle at index `idx` (deterministic for a fixed seed)."""
        # Per-item RNG derived from the dataset seed keeps items reproducible.
        rng = Random(self.seed + idx)
        return self._create_single_puzzle(rng)

    def _create_single_puzzle(self, rng: Random) -> dict:
        """
        Creates one puzzle with N addends (min_words..max_words) plus a result.
        Ensures total distinct digits <= 10; retries (via recursion with the
        same rng) whenever a draw violates that or the leading-zero rule.
        """
        # 1) Pick how many addends
        n_words = rng.randint(self.config.min_words, self.config.max_words)

        # 2) For each addend, pick a random length (3..5) and then pick a random integer with that many digits.
        #    If leading zero is disallowed, the first digit is from 1..9.
        word_lengths = [rng.randint(3, 5) for _ in range(n_words)]
        words_numbers = []
        for length in word_lengths:
            if self.config.allow_leading_zero:
                # e.g. random integer in [0, 10^length - 1], then zero-pad to length
                num = rng.randint(0, 10**length - 1)
            else:
                # leading digit is from 1..9, rest are from 0..9
                # e.g. random integer in [10^(length-1), 10^length - 1]
                num = rng.randint(10 ** (length - 1), 10**length - 1)
            words_numbers.append(num)

        # 3) Compute the sum
        total_sum = sum(words_numbers)
        # The sum can have up to (max_length+1) digits, which is normal in cryptarithms.

        # 4) Gather all digits from the addends and the sum
        digits_in_use = set()

        def collect_digits(num: int):
            # Digits are kept as single-character strings so they can be used
            # directly as keys of the digit -> letter mapping below.
            return set(str(num))

        for wn in words_numbers:
            digits_in_use.update(collect_digits(wn))
        digits_in_use.update(collect_digits(total_sum))

        # If we exceed 10 distinct digits, try again (pick new random numbers).
        # In practice, we can loop until success. But for demonstration, let's do a simple re-pick approach.
        # We'll do a while loop up to some attempts:
        if len(digits_in_use) > 10:
            # Just do a recursion call to pick new numbers, ignoring current picks
            return self._create_single_puzzle(rng)

        # 5) Map each digit to a letter
        # If no leading zero is allowed, the leading digit of each addend + result must not map to '0'.
        # Actually, we are generating real numeric values, so there's no scenario of leading "0" for
        # the addends we enforced (except if allow_leading_zero is True).
        # For the puzzle's perspective, we simply create a random assignment from {digits_in_use} -> letters.
        # Then the solver has to figure it out. They don't see the digits, only letters.
        digits_in_use_list = sorted(list(digits_in_use))  # e.g. ['0', '1', '3', '9']
        rng.shuffle(digits_in_use_list)  # shuffle so mapping is random
        letters_pool = [chr(i) for i in range(ord("A"), ord("Z") + 1)]
        rng.shuffle(letters_pool)
        chosen_letters = letters_pool[: len(digits_in_use_list)]

        # digit -> letter mapping (injective: each digit gets a unique letter)
        digit_to_letter = {}
        for d, letter in zip(digits_in_use_list, chosen_letters):
            digit_to_letter[d] = letter

        # If leading-zero is not allowed, we must ensure that the first digit of each addend and the sum
        # does not map to the letter that is assigned to digit '0'. If we see a conflict, we can just re-pick
        # or we can try to swap letters. The simplest is to re-pick for demonstration.
        # NOTE(review): because the mapping is injective, these checks can only
        # trigger if a first digit IS '0', which generation above already
        # prevents when allow_leading_zero is False — kept as a safety net.
        if not self.config.allow_leading_zero and "0" in digit_to_letter:
            zero_letter = digit_to_letter["0"]
            # Check the first digit of each addend and of the sum
            for wn in words_numbers:
                first_digit = str(wn)[0]
                if digit_to_letter.get(first_digit) == zero_letter:
                    # Conflict => re-generate puzzle
                    return self._create_single_puzzle(rng)
            sum_first_digit = str(total_sum)[0]
            if digit_to_letter.get(sum_first_digit) == zero_letter:
                return self._create_single_puzzle(rng)

        # Now we have a stable digit->letter mapping. Let's create the letter->digit mapping for the answer.
        letter_to_digit = {v: int(k) for k, v in digit_to_letter.items()}

        # 6) Convert each integer to its letter representation
        def int_to_letter_str(num: int) -> str:
            return "".join(digit_to_letter[d] for d in str(num))

        words_letters = [int_to_letter_str(num) for num in words_numbers]
        result_letters = int_to_letter_str(total_sum)

        # 7) Create the puzzle text
        # We'll do the typical vertical format, with a plus sign before the last addend, dashes, then result
        puzzle_lines = []
        max_width = max(len(w) for w in words_letters + [result_letters])
        for i, wl in enumerate(words_letters):
            if i < len(words_letters) - 1:
                # Right align with spaces, +2 for the "  " prefix
                puzzle_lines.append(f"{wl:>{max_width+2}}")
            else:
                # Right align with spaces, +2 for the "+ " prefix
                puzzle_lines.append(f"+ {wl:>{max_width}}")
        # The line of dashes should match the longest line
        puzzle_lines.append("-" * (max_width + 2))
        # Right align the result
        puzzle_lines.append(f"{result_letters:>{max_width+2}}")
        puzzle_text = "\n".join(puzzle_lines)

        question_str = (
            "Solve this cryptarithm:\n\n"
            f"{puzzle_text}\n\n"
            "Each letter stands for a unique digit (0-9). "
            + (
                "Leading letters may be zero.\n"
                if self.config.allow_leading_zero
                else "No leading letter can be zero.\n"
            )
            + "Provide a mapping from letters to digits that satisfies the equation.\n"
        )
        if self.config.include_example:
            question_str += "Here's an example:\n" + EXAMPLE_CASE

        # 8) Create a human-readable answer, e.g. "A=1,B=0,C=9,..."
        sorted_letter_keys = sorted(letter_to_digit.keys())
        answer_str = ",".join(f"{letter}={letter_to_digit[letter]}" for letter in sorted_letter_keys)

        # 9) Return the final puzzle item
        return {
            "question": question_str,
            "answer": answer_str,
            "metadata": {
                "letters": list(letter_to_digit.keys()),
                "word_values": words_numbers,
                "sum_number": total_sum,
                "words_letters": words_letters,
                "result_letters": result_letters,
                "digit_to_letter": digit_to_letter,
                "letter_to_digit": letter_to_digit,
            },
        }
# Expose the dataset under the "cryptarithm" key in the global dataset registry.
register_dataset("cryptarithm", CryptarithmDataset, CryptarithmConfig)

View file

@ -32,7 +32,7 @@ class GameOfLifeDataset(ProceduralDataset):
def __init__(self, config: GameOfLifeConfig): def __init__(self, config: GameOfLifeConfig):
self._prompt_templates = [ self._prompt_templates = [
"What will this Game of Life board look like after {simulation_steps} steps of simulation? Reply as array of array representing rows in the grid from top to bottom in JSON format. (An empty 3x3 grid would look like this: [[0,0,0],[0,0,0],[0,0,0]])\n\n{board}." "What will this Game of Life board look like after {simulation_steps} steps of simulation? Reply as array of array representing rows in the grid from top to bottom in JSON format. Let your answer(array of array be on a single line). (An empty 3x3 grid would look like this: [[0,0,0],[0,0,0],[0,0,0]])\n\n{board}."
] ]
super().__init__(config=config, seed=config.seed, size=config.size) super().__init__(config=config, seed=config.seed, size=config.size)

View file

@ -200,7 +200,7 @@ Vertices: {puzzle["vertices"]}
Edges: {edges} Edges: {edges}
Possible colors: {puzzle["color_options"]} Possible colors: {puzzle["color_options"]}
Return your solution as a JSON map of verteces to colors. (For example: {{0: 1, 1: 2, 2: 3}}) Return your solution as a JSON map of vertices to colors. (For example: {{0: 1, 1: 2, 2: 3}})
""" """
return { return {

View file

@ -90,7 +90,7 @@ class GroupAnagramsDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
"""Score a single Group Anagrams question""" """Score a single Group Anagrams question"""
reward = 0 reward = 0.0
if answer is not None: if answer is not None:
try: try:
answer = json.loads(answer) answer = json.loads(answer)
@ -98,11 +98,11 @@ class GroupAnagramsDataset(ProceduralDataset):
answer_str = json.dumps(self._sort_nested_list(answer)) answer_str = json.dumps(self._sort_nested_list(answer))
oracle_str = json.dumps(self._sort_nested_list(oracle)) oracle_str = json.dumps(self._sort_nested_list(oracle))
if answer_str == oracle_str: if answer_str == oracle_str:
reward = 1 reward = 1.0
else: else:
reward = 0.01 reward = 0.01
except Exception: except Exception:
reward = 0 reward = 0.0
return reward return reward
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:

View file

@ -9,6 +9,30 @@ from reasoning_gym.data import read_data_file
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Your task is to unsramble words in a sentence.
For each word in a sentence, the letter may have been randomly shuffled. Your task is to unscramble the words.
The order of the words in the sentence is preserved. Moreover, the style of the sentence is preserved (i.e. punctuation, capitalization, new lines, etc.).
Example:
- Input: Unscramble these words: raendgmeins yWh nya hilcd anc od hatt
- Output: meanderings Why any child can do that
- Explanation
- We unscramble each of the words independently.
- raendgmeins -> meanderings
- yWh -> Why
- nya -> any
- hilcd -> child
- anc -> can
- od -> do
- hatt -> that
- The final answer is: meanderings Why any child can do that
- Notice that the order of the words is preserved, no new words / symbols (e.g. new lines) are added.
Now, unscramble these words: {words}
"""
@dataclass @dataclass
class LetterJumbleConfig: class LetterJumbleConfig:
@ -89,7 +113,7 @@ class LetterJumbleDataset(ProceduralDataset):
scrambled_words = [self._scramble_word(word, corruption_level, rng) for word in selected_words] scrambled_words = [self._scramble_word(word, corruption_level, rng) for word in selected_words]
return { return {
"question": f"Unscramble these words: {' '.join(scrambled_words)}", "question": QUESTION_TEMPLATE.format(words=" ".join(scrambled_words)),
"answer": " ".join(selected_words), "answer": " ".join(selected_words),
"metadata": { "metadata": {
"num_words": num_words, "num_words": num_words,
@ -112,14 +136,16 @@ class LetterJumbleDataset(ProceduralDataset):
float: The computed score between 0.0 and 1.0. float: The computed score between 0.0 and 1.0.
""" """
if answer == None: oracle_answer = entry["answer"].strip()
return 0.0 if answer:
answer = answer.strip()
s_answer = answer.strip().lower() if answer == oracle_answer:
if not s_answer == entry["answer"].strip().lower(): return 1.0
return 0.01 elif answer.lower() == oracle_answer.lower():
else: return 0.5
return 1.0 else:
return 0.01
return 0.0
register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig) register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig)

View file

@ -5,6 +5,23 @@ from typing import Any, Dict, Optional
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPALTE = """Your task is, given a list of letters, to form a valid palindrome.
A palindrome is a phrase that reads the same forwards and backwards.
If there are multiple possible answers, only respond with one of them. You must use all the letters provided.
Example:
- Input: Form a valid palindrome using the following letters: a, a, b
- Output: aba
- Explanation:
- The phrase aba reads the same forwards and backwards.
- The output answer is a valid palindrome using all the letters provided.
- The answer is a string, rather than a list of characters.
Now, form a valid palindrome using the following letters: {letters}
"""
@dataclass @dataclass
class PalindromeConfig: class PalindromeConfig:
@ -51,16 +68,8 @@ class PalindromeDataset(ProceduralDataset):
letters = self._generate_palindrome_letters(rng, length) letters = self._generate_palindrome_letters(rng, length)
scrambled_letters = rng.sample(letters, len(letters)) # Scramble the order scrambled_letters = rng.sample(letters, len(letters)) # Scramble the order
palindrome = self._assemble_palindrome(letters) palindrome = self._assemble_palindrome(letters)
question_str = (
"Rearrange these letters to form a palindrome. A palindrome is a word, phrase, or sequence that reads the same forward and backward. If there are multiple answers, only respond with one of them.\n\n"
"For example, if the letters are: a, a, b — a valid palindrome is: aba.\n\n"
f"Your letters: {', '.join(scrambled_letters)}\n\n"
"What palindrome can you form from these letters?"
)
return { return {
"question": question_str, "question": QUESTION_TEMPALTE.format(letters=", ".join(scrambled_letters)),
"answer": palindrome, "answer": palindrome,
"metadata": { "metadata": {
"letters": scrambled_letters, "letters": scrambled_letters,

View file

@ -0,0 +1,144 @@
"""Given a string, return all possible partitions of the string such that each substring is a palindrome.
A popular Leetcode problem:
https://leetcode.com/problems/palindrome-partitioning/description/
"""
import json
import string
from dataclasses import dataclass
from random import Random
from typing import Dict, Optional
from ..factory import ProceduralDataset, register_dataset
# Prompt template for each generated item; "{string}" is filled with the
# input string produced by PalindromePartitioningDataset._get_string.
QUESTION_TEMPLATE = """Given a string, partition it such that every substring is a palindrome.
A palindrome is a word that reads the same backward as forward.
You may return all possible palindrome partitioning in any order.
Example:
- Input: Partition the following string into palindromes: aab
- Output: [["a","a","b"],["aa","b"]]
- Explanation:
- One way to partition the string is "a" | "a" | "b", where each substring is a palindrome.
- Another way to partition the string is "aa" | "b", where again each substring is a palindrome.
- Therefore, the final result is a list of the two palindrome partitions.
Partition the following string into palindromes: {string}
"""
@dataclass
class PalindromePartitioningConfig:
    """Settings that control palindrome-partitioning puzzle generation."""

    min_string_len: int = 5  # shortest input string to generate
    max_string_len: int = 15  # longest input string to generate
    # Upper bound on each palindromic chunk used to build the input string.
    # NOTE(review): "palindome" is a typo, but the name is public API
    # (read by the dataset) and cannot be renamed without breaking callers.
    max_substring_palindome_len: int = 5
    size: int = 500  # Virtual dataset size
    seed: Optional[int] = None  # base RNG seed; None → nondeterministic

    def validate(self):
        """Assert that the configured length bounds are mutually consistent.

        Raises:
            AssertionError: if any bound is below 1 or the min/max ordering
                constraints are violated.
        """
        assert self.min_string_len >= 1, "Minimum string length must be at least 1"
        assert self.max_string_len >= self.min_string_len, "Minimum string length must be less than or equal to maximum"
        assert self.max_substring_palindome_len >= 1, "Maximum substring palindrome length must be at least 1"
        assert (
            self.max_string_len >= self.max_substring_palindome_len
        ), "Maximum substring palindrome length must be less than or equal to maximum string length"
class PalindromePartitioningDataset(ProceduralDataset):
    """Generates Palindrome Partitioning exercises with configurable difficulty.

    Each item asks for every way to split a generated string into palindromic
    substrings; answers are compared order-insensitively in score_answer.
    """

    def __init__(self, config: PalindromePartitioningConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _sort_list(self, lst: list[list[str]]) -> list[list[str]]:
        """Sort the list of palindrome partitions by each partition's first piece."""
        # An empty partition sorts via the "" key (only arises for empty input).
        return sorted(lst, key=lambda x: x[0] if x else "")

    def to_set_of_tuples(self, list_of_lists: list[list[str]]) -> set[tuple[str]]:
        """Convert a list of lists to a set of tuples (for order-insensitive comparison)."""
        return {tuple(lst) for lst in list_of_lists}

    def _palindrome_partitioning(self, string: str) -> list[list[str]]:
        """Return all possible palindrome partitions of a string, sorted.

        Classic memoized backtracking: dp caches palindrome checks for
        substrings string[i..j]; _partition enumerates every split point.
        """
        if not string:
            return []
        # (i, j) -> whether string[i..j] (inclusive) is a palindrome.
        dp = {}

        def is_palindrome(i, j) -> bool:
            if i >= j:
                return True
            if (i, j) in dp:
                return dp[(i, j)]
            dp[(i, j)] = string[i] == string[j] and is_palindrome(i + 1, j - 1)
            return dp[(i, j)]

        res, temp = [], []

        def _partition(idx) -> None:
            # Base case: whole string consumed -> record a copy of the current
            # partition. No explicit return is needed: the loop below is empty
            # when idx >= len(string).
            if idx >= len(string):
                res.append(temp[:])
            for i in range(idx, len(string)):
                if is_palindrome(idx, i):
                    temp.append(string[idx : i + 1])
                    _partition(i + 1)
                    temp.pop()

        _partition(0)
        return self._sort_list(res)

    def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
        """Score a single Palindrome Partitioning question.

        Returns 1.0 for an exact (order-insensitive) match against the stored
        solution, 0.01 for valid JSON that is wrong, 0.0 for missing or
        unparsable input.
        """
        if answer is not None:
            try:
                # Compare as sets of tuples so partition order does not matter.
                answer = self.to_set_of_tuples(json.loads(answer))
                oracle = self.to_set_of_tuples(entry["metadata"]["solution"])
                if answer == oracle:
                    return 1.0
                return 0.01
            except Exception:
                return 0.0
        return 0.0

    def _generate_palindrome_letters(self, rng: Random, length: int) -> list[str]:
        """Generate a list of lowercase letters that reads as a palindrome."""
        half_length = length // 2
        letters = rng.choices(string.ascii_lowercase, k=half_length)
        if length % 2 == 1:
            # Odd length: mirror the half around an extra middle letter.
            middle_letter = rng.choice(string.ascii_lowercase)
            return letters + [middle_letter] + letters[::-1]
        return letters + letters[::-1]

    def _get_string(self, rng: Random) -> str:
        """Generate a random string by concatenating random palindromic chunks."""
        size = rng.randint(self.config.min_string_len, self.config.max_string_len)
        output = ""
        while len(output) < size:
            # Chunk length is capped so the string never overshoots `size`.
            palindrome_len = rng.randint(1, min(self.config.max_substring_palindome_len, size - len(output)))
            substring = "".join(self._generate_palindrome_letters(rng, palindrome_len))
            output += substring
        return output

    def __getitem__(self, idx: int) -> dict:
        """Generate a single Palindrome Partitioning question."""
        # Per-item RNG keyed off the dataset seed for reproducibility.
        rng = Random(self.seed + idx)
        # NOTE: this local shadows the `string` module; harmless in this scope.
        string = self._get_string(rng)
        answer = self._palindrome_partitioning(string)
        answer_str = json.dumps(answer)
        return {
            "question": QUESTION_TEMPLATE.format(string=string),
            "answer": answer_str,
            "metadata": {"string": string, "solution": answer},
        }
# Expose the dataset under the "palindrome_partitioning" key in the global registry.
register_dataset("palindrome_partitioning", PalindromePartitioningDataset, PalindromePartitioningConfig)

View file

@ -3,7 +3,7 @@
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from random import Random from random import Random
from typing import Optional from typing import Any, Dict, Optional
from ..data import read_data_file from ..data import read_data_file
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -92,5 +92,26 @@ class SentenceReorderingDataset(ProceduralDataset):
"metadata": {"word_count": word_count}, "metadata": {"word_count": word_count},
} }
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
reward = 0.0
expected_answer = entry["answer"]
if answer is not None:
try:
if expected_answer == answer:
return 1.0
goal_words = expected_answer.split()
answer_words = answer.split()
if len(goal_words) == len(answer_words):
credit = [
1 if goal_word.lower() == answer_word.lower() else 0
for goal_word, answer_word in zip(goal_words, answer_words)
]
reward = sum(credit) / len(credit)
else:
reward = 0.05
except:
reward = 0.01
return reward
register_dataset("sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig) register_dataset("sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig)

View file

@ -3,7 +3,7 @@
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from random import Random from random import Random
from typing import Optional from typing import Any, Dict, Optional
from ..data import read_data_file from ..data import read_data_file
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -49,5 +49,18 @@ class SpellBackwardDataset(ProceduralDataset):
"metadata": {"word": word, "word_len": len(word)}, "metadata": {"word": word, "word_len": len(word)},
} }
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
reward = 0.0
expected_answer = entry["answer"]
if answer is not None:
try:
if expected_answer.lower() == answer.lower():
reward = 1.0
else:
reward = 0.05
except:
reward = 0.01
return reward
register_dataset("spell_backward", SpellBackwardDataset, SpellBackwardConfig) register_dataset("spell_backward", SpellBackwardDataset, SpellBackwardConfig)

View file

@ -6,20 +6,25 @@ https://leetcode.com/problems/spiral-matrix/description/
from dataclasses import dataclass from dataclasses import dataclass
from random import Random from random import Random
from typing import Optional from typing import Dict, Optional
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Given a matrix, your job is to generate a list of elements in spiral order, starting from the top-left element. QUESTION_TEMPLATE = """Given a matrix, your job is to generate a list of elements in spiral order, starting from the top-left element.
Example: Example:
- Input: For the matrix below, what is the list of elements in spiral order?
Input:
1 2 3 1 2 3
4 5 6 4 5 6
7 8 9 7 8 9
- Output: 1 2 3 6 9 8 7 4 5
Output: 1 2 3 6 9 8 7 4 5 - Explanation:
- We start from the top-left element (1) and move right until we reach the end of the row: 1 2 3
- Then, we move down until we reach the last column: 1 2 3 6 9
- Next, we move left until we reach the first column: 1 2 3 6 9 8 7
- Then, we move up until we reach the second row (i.e. one below the previously traversed row): 1 2 3 6 9 8 7 4
- Finally, we move right until we reach the second to last column: 1 2 3 6 9 8 7 4 5
- The output format is a space-separated list of elements in spiral order (as opposed to a python list)
For the matrix below, what is the list of elements in spiral order? For the matrix below, what is the list of elements in spiral order?
{matrix} {matrix}
@ -37,7 +42,7 @@ class SpiralMatrixConfig:
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
assert 1 <= self.max_n, "max_n must be at least 1" assert 2 <= self.max_n, "max_n must be at least 2"
class SpiralMatrixDataset(ProceduralDataset): class SpiralMatrixDataset(ProceduralDataset):
@ -48,7 +53,7 @@ class SpiralMatrixDataset(ProceduralDataset):
def _get_matrix(self, rng: Random) -> list[list[int]]: def _get_matrix(self, rng: Random) -> list[list[int]]:
"""Generate a random matrix""" """Generate a random matrix"""
n = rng.randint(1, self.config.max_n) n = rng.randint(2, self.config.max_n)
numbers = [rng.randint(0, 9) for _ in range(n**2)] numbers = [rng.randint(0, 9) for _ in range(n**2)]
rng.shuffle(numbers) rng.shuffle(numbers)
matrix = [numbers[i * n : (i + 1) * n] for i in range(n)] matrix = [numbers[i * n : (i + 1) * n] for i in range(n)]
@ -111,5 +116,28 @@ class SpiralMatrixDataset(ProceduralDataset):
"metadata": {"matrix": matrix, "solution": answer}, "metadata": {"matrix": matrix, "solution": answer},
} }
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["answer"].strip()
if answer is not None and len(answer) > 0:
answer = answer.strip()
# Exact match
if answer == oracle_answer:
return 1.0
# Try to see if the model's answer is a python list
try:
answer = " ".join(str(item) for item in eval(answer))
if answer == oracle_answer:
return 0.5
else:
return 0.01
except Exception as e:
return 0.01
return 0.0
register_dataset("spiral_matrix", SpiralMatrixDataset, SpiralMatrixConfig) register_dataset("spiral_matrix", SpiralMatrixDataset, SpiralMatrixConfig)

View file

@ -5,7 +5,7 @@ https://github.com/yongchao98/CodeSteer-v1.0/blob/main/create_dataset/create_dat
from dataclasses import dataclass from dataclasses import dataclass
from random import Random from random import Random
from typing import Optional from typing import Dict, Optional
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -26,6 +26,7 @@ Example
- First, we insert A after ABCD. - First, we insert A after ABCD.
- Even though with the newly inserted 'A' we can obtain the substring BCD[A], we can't use it to insert another character. - Even though with the newly inserted 'A' we can obtain the substring BCD[A], we can't use it to insert another character.
- Lastly, we insert D after DEAB. - Lastly, we insert D after DEAB.
- Therefore, the final answer is DDABCDAEEDEABD (represented as a string, instead of a list of characters).
Given the following string, provide the answer after inserting the characters according to the pattern: {string} Given the following string, provide the answer after inserting the characters according to the pattern: {string}
""" """
@ -79,12 +80,28 @@ class StringInsertionDataset(ProceduralDataset):
i += 1 i += 1
return "".join(output) return "".join(output)
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["answer"]
if answer is not None:
if answer == oracle_answer:
return 1.0
else:
try:
# check if answer is python list of characters
answer = "".join(eval(answer))
if answer == oracle_answer:
return 0.5
except Exception as e:
return 0.01
return 0.0
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
"""Generate a single String Insertion question""" """Generate a single String Insertion question"""
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
string_length = rng.randint(self.config.min_string_length, self.config.max_string_length) string_length = rng.randint(self.config.min_string_length, self.config.max_string_length)
string = [rng.choice(self.vocabulary) for _ in range(string_length)] string = "".join(rng.choice(self.vocabulary) for _ in range(string_length))
answer = self._get_answer(string) answer = self._get_answer(string)

View file

@ -8,6 +8,10 @@ from typing import Dict, List, Optional, Set, Tuple
from ..data import get_data_file_path from ..data import get_data_file_path
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Transform the word ladder '{start}' to '{end}' by changing one letter at a time.
Provide your answer as a comma-separated sequence of uppercase letters without spaces.
Each step must be a valid English word."""
@dataclass @dataclass
class WordLadderConfig: class WordLadderConfig:
@ -211,7 +215,7 @@ class WordLadderDataset(ProceduralDataset):
raise IndexError(f"Dataset exhausted at index {idx}. {str(e)}") raise IndexError(f"Dataset exhausted at index {idx}. {str(e)}")
return { return {
"question": f"Transform the word ladder '{start}' to '{end}' by changing one letter at a time.", "question": QUESTION_TEMPLATE.format(start=start, end=end),
"answer": ",".join(path), "answer": ",".join(path),
"metadata": {"start_word": start, "end_word": end, "word_length": length, "chain_length": len(path)}, "metadata": {"start_word": start, "end_word": end, "word_length": length, "chain_length": len(path)},
} }

View file

@ -19,6 +19,23 @@ class TextTransformation(StrEnum):
RANDOMCASE = "randomcase" RANDOMCASE = "randomcase"
QUESTION_TEMPLATE = """Your task is to sort words in ascending or descending order using ASCII/Unicode ordering.
Example:
- Input: Sort these words in ascending order (using ASCII/Unicode ordering) and return them as a comma-separated list: freely, idea, indemnify, last, END, solving
- Output: END, freely, idea, indemnify, last, solving
- Explanation:
- Uppercase letters come before lowercase letters, hence why "END" comes first.
- "freely" comes before "idea" because "f" comes before "i".
- "idea" comes before "indemnify" because even though they both start with "i", "d" comes before "n".
- "indemnify" comes before "last" because "i" comes before "l".
- "last" comes before "solving" because "l" comes before "s".
- Finally, the output is provided as a comma separated list of the sorted words.
Now, sort these words in {direction} order (using ASCII/Unicode ordering) and return them as a comma-separated list: {words}
"""
@dataclass @dataclass
class WordSortingConfig: class WordSortingConfig:
"""Configuration for word sorting task generation""" """Configuration for word sorting task generation"""
@ -94,7 +111,7 @@ class WordSortingDataset(ProceduralDataset):
answer = asc_words if is_ascending else desc_words answer = asc_words if is_ascending else desc_words
return { return {
"question": f"Sort these words in {direction} order (using ASCII/Unicode ordering) and return them as a comma-separated list:\n{', '.join(transformed_words)}", "question": QUESTION_TEMPLATE.format(direction=direction, words=", ".join(transformed_words)),
"answer": ", ".join(answer), "answer": ", ".join(answer),
"metadata": { "metadata": {
"original_words": original_words, "original_words": original_words,
@ -106,26 +123,17 @@ class WordSortingDataset(ProceduralDataset):
} }
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
"""Determine if the solution provided solves this task. oracle_answer = entry["metadata"]["sorted_words"]
if answer is not None and len(answer) > 0:
parsed_answer = [word.strip() for word in re.split(r",\s*", answer)]
if parsed_answer == oracle_answer:
return 1.0
elif sorted(parsed_answer) == oracle_answer:
return 0.2
else:
return 0.01
The function awards 1.0 for a correct answer. return 0.0
Args:
answer (Optional[str]): The user's answer.
entry (Dict[str, any]): The original dataset entry containing the correct answer.
Returns:
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
return 0.0
s_answer = answer.strip().replace(" ", "")
if not s_answer == entry["answer"].strip().replace(" ", ""):
return 0.01
else:
return 1.0
register_dataset("word_sorting", WordSortingDataset, WordSortingConfig) register_dataset("word_sorting", WordSortingDataset, WordSortingConfig)

View file

@ -27,6 +27,7 @@ class ArcAgiConfig:
default_factory=lambda: ["horizontal", "vertical", "diagonal", "counterdiagonal"] default_factory=lambda: ["horizontal", "vertical", "diagonal", "counterdiagonal"]
) # empty list for no mirrors ) # empty list for no mirrors
use_color_permutation: bool = True use_color_permutation: bool = True
shuffle_example_order: bool = True # whether to shuffle the order of example board pairs for each riddle
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 size: int = 500
@ -87,8 +88,8 @@ def cmap(board: Board, colors: list[int]) -> Board:
return [[colors[c] for c in row] for row in board] return [[colors[c] for c in row] for row in board]
ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270] # ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270]
MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror] # MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror]
class ArcAgiDataset(ProceduralDataset): class ArcAgiDataset(ProceduralDataset):
@ -156,6 +157,9 @@ class ArcAgiDataset(ProceduralDataset):
for p in train: for p in train:
augmented_train.append({"input": augment(p["input"]), "output": augment(p["output"])}) augmented_train.append({"input": augment(p["input"]), "output": augment(p["output"])})
if self.config.shuffle_example_order:
rng.shuffle(augmented_train)
examples = [ examples = [
format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts) format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts)
for i, p in enumerate(augmented_train) for i, p in enumerate(augmented_train)

View file

@ -64,6 +64,9 @@ class BasicArithmeticDataset(ProceduralDataset):
def __init__(self, config: BasicArithmeticDatasetConfig): def __init__(self, config: BasicArithmeticDatasetConfig):
super().__init__(config=config, seed=config.seed, size=config.size) super().__init__(config=config, seed=config.seed, size=config.size)
self.added_instruction = (
" Ensure to report the answer as an integer. Do not add commas to the integer answers reported."
)
def __getitem__(self, idx: int) -> dict[str, Any]: def __getitem__(self, idx: int) -> dict[str, Any]:
"""Generate a single arithmetic task """Generate a single arithmetic task
@ -88,7 +91,7 @@ class BasicArithmeticDataset(ProceduralDataset):
else: else:
expression, result = self._generate_simple_task(rng, num_terms, num_digits) expression, result = self._generate_simple_task(rng, num_terms, num_digits)
question = self._format_question(rng, expression) question = self._format_question(rng, expression) + self.added_instruction
return { return {
"question": question, "question": question,
@ -223,12 +226,14 @@ class BasicArithmeticDataset(ProceduralDataset):
return expression, result return expression, result
def _format_question(self, rng: Random, expression: str) -> str: def _format_question(self, rng: Random, expression: str) -> str:
"""Format the expression according to config style""" """Format the the question with the arithmetic expression"""
if self.config.format_style == "simple": if self.config.format_style == "simple":
return f"{expression} =" return f"Calculate {expression}."
else: else:
templates = ["What is {0}?", "Calculate {0}", "Solve {0}", "Evaluate the expression: {0}"] templates = ["What is {0}?", "Solve {0}.", "Compute {0}.", "Evaluate: {0}."]
return rng.choice(templates).format(expression) template = rng.choice(templates)
return template.format(expression)
# Register the dataset # Register the dataset

View file

@ -63,7 +63,7 @@ class ChainSumDataset(ProceduralDataset):
expression, result = self._generate_task(rng, num_terms, min_value, max_value) expression, result = self._generate_task(rng, num_terms, min_value, max_value)
return { return {
"question": f"{expression} =", "question": f"State the final answer to the following arithmetic problem: {expression} =",
"answer": str(result), "answer": str(result),
"metadata": { "metadata": {
"difficulty": { "difficulty": {

View file

@ -1,12 +1,15 @@
"""Fraction simplification task generator""" """Fraction simplification task generator"""
import re
from dataclasses import dataclass from dataclasses import dataclass
from math import gcd from math import gcd
from random import Random from random import Random
from typing import Optional, Sequence, Tuple from typing import Any, Dict, Optional, Sequence, Tuple
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = "Simplify the fraction {question_fraction} to its lowest terms. Give only the simplified fraction as your final answer."
@dataclass @dataclass
class FractionSimplificationConfig: class FractionSimplificationConfig:
@ -107,7 +110,7 @@ class FractionSimplificationDataset(ProceduralDataset):
answer_fraction = self._format_fraction(simple_num, simple_den, style) answer_fraction = self._format_fraction(simple_num, simple_den, style)
return { return {
"question": f"Simplify the fraction {question_fraction} to its lowest terms", "question": QUESTION_TEMPLATE.format(question_fraction=question_fraction),
"answer": answer_fraction, "answer": answer_fraction,
"metadata": { "metadata": {
"numerator": num, "numerator": num,
@ -119,5 +122,34 @@ class FractionSimplificationDataset(ProceduralDataset):
}, },
} }
def _extract_fraction(self, answer: Optional[str]):
try:
cleaned = answer.strip().strip("$").strip()
latex_match = re.match(r"\\(?:frac|dfrac)\s*{\s*(\d+)\s*}\s*{\s*(\d+)\s*}", cleaned, re.IGNORECASE)
if latex_match:
return int(latex_match.group(1)), int(latex_match.group(2))
if "/" in cleaned:
numerator, denominator = map(str.strip, cleaned.split("/", 1))
return int(numerator), int(denominator)
except:
return None
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]):
reward = 0.0
metadata = entry["metadata"]
try:
numerator, denominator = self._extract_fraction(answer)
if numerator == metadata["simplified_numerator"] and denominator == metadata["simplified_denominator"]:
reward = 1.0
elif numerator == metadata["numerator"] or denominator == metadata["denominator"]:
reward = 0.1
elif len(answer.strip()) > 0:
reward = 0.05
else:
reward = 0.01
except:
reward = 0.01
return reward
register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig) register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig)

View file

@ -57,7 +57,7 @@ class GCDDataset(ProceduralDataset):
numbers_str = ", ".join(str(n) for n in numbers) numbers_str = ", ".join(str(n) for n in numbers)
return { return {
"question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}", "question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}. Give only the GCD as your final answer.",
"answer": str(result), "answer": str(result),
"metadata": {"numbers": numbers, "result": result}, "metadata": {"numbers": numbers, "result": result},
} }

View file

@ -148,7 +148,9 @@ class GSMSymbolicDataset(ProceduralDataset):
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
generator_idx = self.task_indices[idx] generator_idx = self.task_indices[idx]
generator = self.generators[generator_idx] generator = self.generators[generator_idx]
return generator(rng, self.config.difficulty) example = generator(rng, self.config.difficulty)
example["question"] += " Give only the result as your final answer."
return example
register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig) register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig)

View file

@ -54,14 +54,29 @@ ANIMALS = {
"woodlouse": 14, "woodlouse": 14,
} }
QUESTION_TEMPLATE = """Your task is to count how many legs there are in total when given a list of animals.
Example:
- Input: How many legs are there in total if you have 1 duck, 2 deers, 1 spider, 3 cows?
- Output: 30
- Explanation:
- Ducks have 2 legs each, so 1 duck has 2 legs.
- Deers have 4 legs each, so 2 deers have 8 legs.
- Spiders have 8 legs each, so 1 spider has 8 legs.
- Cows have 4 legs each, so 3 cows have 12 legs.
- Therefore, the total number of legs is 2 + 8 + 8 + 12 = 30
Now, how many legs are there in total if you have {animals}?
"""
@dataclass @dataclass
class LegCountingConfig: class LegCountingConfig:
"""Configuration for leg counting task generation""" """Configuration for leg counting task generation"""
min_animals: int = 2 # Minimum number of animals in problem min_animals: int = 3 # Minimum number of animals in problem
max_animals: int = 5 # Maximum number of animals max_animals: int = 10 # Maximum number of animals
max_instances: int = 3 # Maximum instances of each animal max_instances: int = 15 # Maximum instances of each animal
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
@ -106,10 +121,8 @@ class LegCountingDataset(ProceduralDataset):
for animal, count in animals.items(): for animal, count in animals.items():
animal_list.append(f"{count} {animal}{'s' if count > 1 else ''}") animal_list.append(f"{count} {animal}{'s' if count > 1 else ''}")
question = "How many legs are there in total if you have " + ", ".join(animal_list) + "?"
return { return {
"question": question, "question": QUESTION_TEMPLATE.format(animals=", ".join(animal_list)),
"answer": str(total_legs), "answer": str(total_legs),
"metadata": { "metadata": {
"difficulty": { "difficulty": {

View file

@ -7,7 +7,24 @@ from typing import Dict, Optional
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Compute {base}^{exponent}""" QUESTION_TEMPLATE = """Your task is to compute an exponentiation of a number.
Example:
- Input: Compute 2^3
- Output: 8
- Explanation:
- 2^3 = 2 * 2 * 2 = 8
- Therefore, the final answer is 8
Example:
- Input: Compute 412.5^3
- Output: 70189453.125
- Explanation:
- 412.5^3 = 412.5 * 412.5 * 412.5 = 70189453.125
- Therefore, the final answer is 70189453.125
Compute {base}^{exponent}
"""
@dataclass @dataclass
@ -32,28 +49,31 @@ class PowerFunctionDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available.""" """Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["answer"] oracle_answer = entry["answer"]
reward = 0.0
if answer is not None: if answer is not None:
difference = abs(float(answer) - float(oracle_answer)) try:
if difference < 1e-6: answer = round(float(answer), 4)
reward = 1.0 oracle_answer = round(float(oracle_answer), 4)
elif difference < 1e-1: difference = abs(float(answer) - float(oracle_answer))
reward = 0.5 if difference < 1e-4:
else: return 1.0
reward = 0.01 elif difference < 1e-1:
return 0.5
return reward else:
return 0.01
except Exception as e:
return 0.01
return 0.0
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
"""Generate a single Power Function question""" """Generate a single Power Function question"""
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
base = rng.uniform(self.config.min_base, self.config.max_base) base = round(rng.uniform(self.config.min_base, self.config.max_base), 4)
exponent = rng.randint(self.config.min_exponent, self.config.max_exponent) exponent = rng.randint(self.config.min_exponent, self.config.max_exponent)
answer = pow(base, exponent) answer = pow(base, exponent)
return { return {
"question": f"Compute {base}^{exponent}", "question": QUESTION_TEMPLATE.format(base=base, exponent=exponent),
"answer": str(answer), "answer": str(answer),
"metadata": {"base": base, "exponent": exponent, "solution": answer}, "metadata": {"base": base, "exponent": exponent, "solution": answer},
} }

View file

@ -57,7 +57,7 @@ class ProductsDataset(ProceduralDataset):
expression, result = self._generate_task(rng, num_terms, min_value, max_value) expression, result = self._generate_task(rng, num_terms, min_value, max_value)
return { return {
"question": f"{expression} =", "question": f"Solve the following multiplication: {expression}. Give only the result as your final answer.",
"answer": str(result), "answer": str(result),
"metadata": { "metadata": {
"difficulty": { "difficulty": {

View file

@ -28,7 +28,8 @@ class BFDataset(ProceduralDataset):
def __init__(self, config: BFConfig): def __init__(self, config: BFConfig):
self._prompt_templates = [ self._prompt_templates = [
"This is a BF (Brainf*ck) computer program. What is the output? Reply only with the program output, ex: 42. \n\n{bf_program}", "This is a BF (Brainf*ck) computer program. What is the output?\n\n{bf_program}\n\nRespond only with the exact output of the program.",
"Consider the following BF (Brainf*ck) code. What would it output?\n\n{bf_program}\n\nProvide only the exact output of the code.",
] ]
super().__init__(config=config, seed=config.seed, size=config.size) super().__init__(config=config, seed=config.seed, size=config.size)
@ -123,6 +124,13 @@ int main() {{
if answer == None: if answer == None:
return 0.0 return 0.0
if answer != entry["answer"]: if answer != entry["answer"]:
if entry["answer"] in answer.splitlines():
# We can be quite confident that the correct answer was given
# It was likely just given alongside an explanation
return max(0.9 * len(answer) / len(entry["answer"]), 0.1)
if entry["answer"] in answer:
# Since answers are English words, some risk of the response coincidentally containing the answer
return max(0.5 * len(answer) / len(entry["answer"]), 0.1)
return 0.01 return 0.01
else: else:
return 1.0 # Yay return 1.0 # Yay

View file

@ -4,6 +4,42 @@ from typing import Dict, Optional
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Your task is to count how many rectangles are present in an ASCII grid.
Single rectangles are outlined with a '#', overlapping rectangles (max 2) are shown with ''.
Example:
- Input: How many rectangles are in the grid below?
####
# #
####
#########
# █##
# █ #
########█ #
# #
###
- Output: 3
- Explanation:
- The first rectangle is the 3x4 rectangle in the top right.
- The other two rectangles are overlapping in the bottom left corner.
- Therefore, the final answer is 3.
Now, it's your turn. How many rectangles do you see in the grid below?
{puzzle}
"""
def draw_rectangles_with_overlap(n, width, height, rng): def draw_rectangles_with_overlap(n, width, height, rng):
# Create a grid that holds a count of how many times a cell is drawn. # Create a grid that holds a count of how many times a cell is drawn.
@ -103,12 +139,10 @@ class RectangleCountDataset(ProceduralDataset):
target = rng.randint(1, self.config.max_rectangles) target = rng.randint(1, self.config.max_rectangles)
puzzle, answer = draw_rectangles_with_overlap(target, self.config.width, self.config.height, rng) puzzle, answer = draw_rectangles_with_overlap(target, self.config.width, self.config.height, rng)
puzz = f"How many rectangles do you see? Single rectangles are outlined with a '#', overlapping rectangles (max 2) are shown with ''. \n\n {puzzle}"
return { return {
"question": puzz, "question": QUESTION_TEMPLATE.format(puzzle=puzzle),
"answer": str(answer), "answer": str(answer),
"metadata": {}, "metadata": {"puzzle": puzzle, "solution": answer},
} }
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:

View file

@ -53,9 +53,10 @@ class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available.""" """Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["answer"] oracle_answer = entry["answer"].strip()
reward = 0.0 reward = 0.0
if answer is not None: if answer is not None and len(answer) > 0:
answer = answer.strip()
if answer == oracle_answer: if answer == oracle_answer:
reward = 1.0 reward = 1.0
elif oracle_answer in answer: elif oracle_answer in answer:

View file

@ -7,6 +7,7 @@ Game tasks for training reasoning capabilities:
""" """
from .countdown import CountdownConfig, CountdownDataset from .countdown import CountdownConfig, CountdownDataset
from .futoshiki import FutoshikiConfig, FutoshikiDataset
from .knight_swap import KnightSwapConfig, KnightSwapDataset from .knight_swap import KnightSwapConfig, KnightSwapDataset
from .maze import MazeConfig, MazeDataset from .maze import MazeConfig, MazeDataset
from .mini_sudoku import MiniSudokuConfig, MiniSudokuDataset from .mini_sudoku import MiniSudokuConfig, MiniSudokuDataset
@ -20,6 +21,8 @@ from .tsumego import TsumegoConfig, TsumegoDataset
__all__ = [ __all__ = [
"CountdownConfig", "CountdownConfig",
"CountdownDataset", "CountdownDataset",
"FutoshikiConfig",
"FutoshikiDataset",
"MiniSudokuConfig", "MiniSudokuConfig",
"MiniSudokuDataset", "MiniSudokuDataset",
"SudokuConfig", "SudokuConfig",

View file

@ -8,6 +8,15 @@ from sympy.parsing.sympy_parser import parse_expr
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
QUESTION_FORMAT_TEMPLATE = """{question}
Final answer format instructions:
1. Provide your solution as a arithmetic expression (no '=' sign).
2. Do not include the target number in the expression.
3. Use '*' for multiplication.
4. Use '/' for division.
5. Do not include any other text or formatting.
"""
@dataclass @dataclass
class CountdownConfig: class CountdownConfig:
@ -67,8 +76,11 @@ class CountdownDataset(ProceduralDataset):
numbers_str = ", ".join(map(str, numbers)) numbers_str = ", ".join(map(str, numbers))
question = rng.choice(self._prompt_templates)
question = question.format(numbers=numbers_str, target=target)
return { return {
"question": rng.choice(self._prompt_templates).format(numbers=numbers_str, target=target), "question": QUESTION_FORMAT_TEMPLATE.format(question=question),
"answer": expression, "answer": expression,
"metadata": { "metadata": {
"numbers": numbers, "numbers": numbers,

View file

@ -0,0 +1,656 @@
"""Futoshiki puzzle generator"""
import copy
import itertools
from dataclasses import dataclass
from random import Random
from typing import Any, Dict, List, Optional, Set, Tuple
from ..factory import ProceduralDataset, register_dataset
@dataclass
class FutoshikiConfig:
"""Configuration for Futoshiki puzzle generation"""
board_size: int = 4 # Board will be NxN where N is this value
difficulty: int = 1 # Possible values: 0, 1, 2, 3
seed: Optional[int] = None
size: int = 500 # Virtual dataset size
def validate(self):
"""Validate configuration parameters"""
assert 4 <= self.board_size <= 9, "board_size must be between 4 and 9"
assert 0 <= self.difficulty <= 3, "difficulty must be between 0 and 3"
class FutoshikiDataset(ProceduralDataset):
"""Generates Futoshiki puzzles with configurable board size and difficulty"""
def __init__(self, config: FutoshikiConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
def __len__(self) -> int:
return self.config.size
def __iter__(self):
self._current_idx = 0
return self
def __next__(self):
if self._current_idx >= self.config.size:
raise StopIteration
item = self[self._current_idx]
self._current_idx += 1
return item
def __getitem__(self, idx: int) -> dict:
"""
Generate a single Futoshiki puzzle with blanks, represented by 0s, and constraints.
Clues are pre-filled numbers in the grid.
Constraints are adjacent cell pairs which may have '<' or '>' relations.
Difficulty in [0..3] affects number of clues and constraints.
"""
rng = Random(self.seed + idx)
# Generate random "solved" Futoshiki grid
solution = self._generate_random_solution(self.config.board_size, rng)
# Add random adjacency constraints consistent with generated solved grid
constraints = self._generate_random_constraints(solution, self.config.difficulty, rng)
# Starting with full solution, remove clues to desired difficulty
puzzle = self._remove_clues(copy.deepcopy(solution), constraints, rng)
# Format as strings
puzzle_str = self._puzzle_to_string(puzzle, constraints)
solution_str = self._puzzle_to_string(solution, constraints)
question = (
f"Solve the following {self.config.board_size}x{self.config.board_size} Futoshiki puzzle:\n\n"
f"{puzzle_str}\n\n"
"Ensure your answer follows the same format as the puzzle above, just replace blanks (_) with the correct value for the cell.\n"
"Use < and > for horizontal constraints. Use \u2227 and \u2228 for vertical constraints.\n"
f"Remember, in Futoshiki each row and column must contain each number from 1 to {self.config.board_size} exactly once."
)
return {
"question": question,
"answer": solution_str,
"metadata": {
"puzzle": puzzle,
"constraints": constraints,
"solution": solution,
"board_size": self.config.board_size,
"difficulty": self.config.difficulty,
},
}
def _puzzle_to_string(
self, puzzle_grid: List[List[int]], constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str]
) -> str:
"""
Formats a Futoshiki puzzle grid as a string with constraints.
Constraints are represented as '<', '>', '\u2227', or '\u2228' between adjacent cells.
"""
n = len(puzzle_grid)
def cell_str(val: int) -> str:
return str(val) if val != 0 else "_"
# Helper to look up constraints between two adjacent cells
# Ensures the first tuple is always the “lesser” in row-major order
# If order is reversed in the dict, invert the constraint
def get_constraint(r1, c1, r2, c2) -> Optional[str]:
if (r1, c1) == (r2, c2):
return None
if (r1, c1) < (r2, c2):
key = ((r1, c1), (r2, c2))
sign = constraints.get(key)
if sign == ">": # first is bigger
if r1 == r2: # horizontal
return ">"
else: # vertical
return "\u2228"
elif sign == "<": # first is smaller
if r1 == r2: # horizontal
return "<"
else:
return "\u2227"
else:
# reversed order in the dictionary -> invert the sign
key = ((r2, c2), (r1, c1))
sign = constraints.get(key)
if sign == ">":
if r1 == r2:
return "<"
else:
return "\u2227"
elif sign == "<":
if r1 == r2:
return ">"
else:
return "\u2228"
return None
lines = []
for r in range(n):
# Build the row string with horizontal constraints
row_cells = []
for c in range(n):
row_cells.append(cell_str(puzzle_grid[r][c]))
if c < n - 1:
hc = get_constraint(r, c, r, c + 1)
row_cells.append(hc if hc else " ")
lines.append(" ".join(row_cells))
# If not the last row, build the line of vertical constraints
if r < n - 1:
vert_cells = []
for c in range(n):
vc = get_constraint(r, c, r + 1, c)
if vc:
vert_cells.append(vc)
else:
vert_cells.append(" ")
# Space out columns so vertical symbols line up under the correct spot
if c < n - 1:
vert_cells.append(" ")
lines.append(" ".join(vert_cells))
return "\n".join(lines)
def _solve_logical(
self,
grid: List[List[int]],
constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str],
) -> Tuple[List[List[int]], List[List[Set[int]]]]:
"""
Apply logical rules to progress solution.
Returns current state if logical rules can't progress further.
Logical rules are implemented based on the descriptions here: https://futoshiki.uk/
"""
size, working_grid = len(grid), copy.deepcopy(grid)
# Starting point all numbers are candidates for all unfilled squares
candidates: List[List[Set[int]]] = [
[set(range(1, len(grid) + 1)) if grid[r][c] == 0 else {grid[r][c]} for c in range(len(grid))]
for r in range(len(grid))
]
# Any cells > another cannot be 1, and any cells < another cannot be `size`
# This is separated from the repeated function below to avoid redundant checks
for ((r1, c1), (_, _)), rel in constraints.items():
if rel == ">":
candidates[r1][c1].discard(1)
elif rel == "<":
candidates[r1][c1].discard(size)
def _update_grid():
"""Update solution wherever a cell's candidates set is reduced to a single element."""
for r in range(len(working_grid)):
for c in range(len(working_grid)):
if working_grid[r][c] == 0 and len(candidates[r][c]) == 1:
working_grid[r][c] = next(iter(candidates[r][c]))
def _try_solve_logical() -> bool:
progress = False
# Eliminate candidates based on numbers already placed
# If a number is placed in a cell, it cannot be a candidate for any other cell in the same row or column
for r in range(len(working_grid)):
for c in range(len(working_grid)):
if working_grid[r][c] == 0:
continue
for cc in range(len(working_grid)):
if cc != c and working_grid[r][c] in candidates[r][cc]:
candidates[r][cc].discard(working_grid[r][c])
progress = True
for rr in range(len(working_grid)):
if rr != r and working_grid[r][c] in candidates[rr][c]:
candidates[rr][c].discard(working_grid[r][c])
progress = True
_update_grid()
# Eliminate candidates based on constraints
# Based on currently filled values, eliminate candidates that violate constraints
def _eliminate_by_constraint(rc_less: Tuple[int, int], rc_greater: Tuple[int, int]) -> bool:
r_less, c_less = rc_less
r_greater, c_greater = rc_greater
progress = False
if working_grid[r_less][c_less] != 0:
# greater must only have candidates > less
for v in candidates[r_greater][c_greater].copy():
if v <= working_grid[r_less][c_less] and v in candidates[r_greater][c_greater]:
candidates[r_greater][c_greater].discard(v)
progress = True
if working_grid[r_greater][c_greater] != 0:
# less must only have candidates < greater
for v in candidates[r_less][c_less].copy():
if v >= working_grid[r_greater][c_greater] and v in candidates[r_less][c_less]:
candidates[r_less][c_less].discard(v)
progress = True
return progress
for ((r1, c1), (r2, c2)), rel in constraints.items():
v1, v2 = working_grid[r1][c1], working_grid[r2][c2]
if v1 != 0 and v2 != 0: # both already filled, skip
continue
if rel == "<":
progress |= _eliminate_by_constraint((r1, c1), (r2, c2))
elif rel == ">":
progress |= _eliminate_by_constraint((r2, c2), (r1, c1))
_update_grid()
# Seek "hidden singles" - cells where a candidate is unique in the row or column
for r in range(len(working_grid)):
for c in range(len(working_grid)):
if working_grid[r][c] != 0:
continue
for v in candidates[r][c]:
if sum(v in candidates[r][cc] for cc in range(len(working_grid))) == 1:
candidates[r][c] = {v} # candidate unique in row
break
if sum(v in candidates[rr][c] for rr in range(len(working_grid))) == 1:
candidates[r][c] = {v} # candidate unique in column
break
_update_grid()
# Seek "naked pairs" if same pair of candidates twice in a row or col, with nothing else in those two cells
# Remove them from other cells in row/col
for r in range(len(working_grid)):
for c in range(len(working_grid)):
if working_grid[r][c] != 0 or len(candidates[r][c]) != 2:
continue
for cc in range(len(working_grid)):
if cc != c and candidates[r][cc] == candidates[r][c]:
for ccc in range(len(working_grid)):
if ccc != c and ccc != cc and candidates[r][c].intersection(candidates[r][ccc]):
candidates[r][ccc] -= candidates[r][c]
progress = True
for rr in range(len(working_grid)):
if rr != r and candidates[rr][c] == candidates[r][c]:
for rrr in range(len(working_grid)):
if rrr != r and rrr != rr and candidates[r][c].intersection(candidates[rrr][c]):
candidates[rrr][c] -= candidates[r][c]
progress = True
_update_grid()
# Seek "hidden pairs" - same pair of candidates occurs in two cells in a line, but nowhere else in the line
# alongside other candidates (hence hidden). All other candidates can be removed from those two cells
for r in range(len(working_grid)):
for c in range(len(working_grid)):
if working_grid[r][c] != 0:
continue
for cc in range(c + 1, len(working_grid)):
if working_grid[r][cc] != 0:
continue
# Check if r, c shares a candidate pair with r, cc (maybe subset, not exact candidate set match)
r_c_pairs = itertools.permutations(candidates[r][c], 2)
r_cc_pairs = itertools.permutations(candidates[r][cc], 2)
for pair in r_c_pairs:
if pair not in r_cc_pairs:
continue
otherwise_unique = True
# If this pair occurs elsewhere in the row, skip
for ccc in range(len(working_grid)):
if ccc in (c, cc):
continue
if pair in itertools.permutations(candidates[r][ccc], 2):
otherwise_unique = False
break
if not otherwise_unique:
continue
# Found a hidden pair, remove all other candidates from these two cells
candidates[r][c] = set(pair)
candidates[r][cc] = set(pair)
for rr in range(r + 1, len(working_grid)):
if working_grid[rr][c] != 0:
continue
# Check if r, c shares a candidate pair with rr, c (maybe subset, not exact candidate set match)
r_c_pairs = itertools.permutations(candidates[r][c], 2)
rr_c_pairs = itertools.permutations(candidates[rr][c], 2)
for pair in r_c_pairs:
if pair not in rr_c_pairs:
continue
otherwise_unique = True
# If this pair occurs elsewhere in the col, skip
for rrr in range(len(working_grid)):
if rrr in (r, rr):
continue
if pair in itertools.permutations(candidates[rrr][c], 2):
otherwise_unique = False
break
if not otherwise_unique:
continue
# Found a hidden pair, remove all other candidates from these two cells
candidates[r][c] = set(pair)
candidates[rr][c] = set(pair)
_update_grid()
# Seek X-wings by rows
for v in range(1, size + 1):
# If candidate is in the same 2 positions in 2 rows, and nowhere else in those rows
# Delete from the 2 intersecting cols
# Find rows which have exactly 2 instances of the value in their candidate sets
rows_with_v = [r for r in range(size) if sum(v in candidates[r][c] for c in range(size)) == 2]
if len(rows_with_v) < 2:
continue
# Check whether the 2 columns with the value are the same in the 2 rows
cols_with_v_per_row = [set() for _ in range(len(rows_with_v))]
for i, r in enumerate(rows_with_v):
for c in range(size):
if v in candidates[r][c]:
cols_with_v_per_row[i].add(c)
# Check if there are a pair of tows with the same cols (there may be more than 2 rows)
for i in range(len(rows_with_v)):
for j in range(i + 1, len(rows_with_v)):
if cols_with_v_per_row[i] == cols_with_v_per_row[j]:
# Found an X-wing, remove candidate from the 2 cols
for c in cols_with_v_per_row[i]:
for rr in range(size):
if rr not in (rows_with_v[i], rows_with_v[j]) and v in candidates[rr][c]:
candidates[rr][c].discard(v)
progress = True
# Seek X-wings by cols
for v in range(1, size + 1):
# If candidate is in the same 2 positions in 2 cols, and nowhere else in those cols
# Delete from the 2 intersecting rows
# Find cols which have exactly 2 instances of the value in their candidate sets
cols_with_v = [c for c in range(size) if sum(v in candidates[r][c] for r in range(size)) == 2]
if len(cols_with_v) < 2:
continue
# Check whether the 2 rows with the value are the same in the 2 cols
rows_with_v_per_col = [set() for _ in range(len(cols_with_v))]
for i, c in enumerate(cols_with_v):
for r in range(size):
if v in candidates[r][c]:
rows_with_v_per_col[i].add(r)
# Check if there are a pair of cols with the same rows (there may be more than 2 cols)
for i in range(len(cols_with_v)):
for j in range(i + 1, len(cols_with_v)):
if rows_with_v_per_col[i] == rows_with_v_per_col[j]:
# Found an X-wing, remove candidate from the 2 rows
for r in rows_with_v_per_col[i]:
for cc in range(size):
if cc not in (cols_with_v[i], cols_with_v[j]) and v in candidates[r][cc]:
candidates[r][cc].discard(v)
progress = True
_update_grid()
return progress
while _try_solve_logical():
continue
return working_grid, candidates
def _solve(
self,
grid: List[List[int]],
constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str],
) -> List[List[int]] | None:
"""
Backtracking Futoshiki solver. Used to verify generated puzzles.
Applies logical rules first then backtracks to fill gaps.
Return solved grid, or None if unsolvable.
"""
grid, candidates = self._solve_logical(grid, constraints)
size = len(grid)
empty_cell = None
# Find first empty cell
for r in range(size):
for c in range(size):
if grid[r][c] == 0:
empty_cell = (r, c)
break
if empty_cell:
break
# If no empty cell, solution is complete
if not empty_cell:
return copy.deepcopy(grid)
r, c = empty_cell
for val in range(1, size + 1):
if val not in candidates[r][c]:
continue
if self._is_valid(grid, r, c, val, constraints):
grid[r][c] = val
solution = self._solve(grid, constraints)
if solution is not None:
grid[r][c] = 0
return solution
grid[r][c] = 0
return None
def _is_valid(
self,
grid: List[List[int]],
row: int,
col: int,
val: int,
constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str],
) -> bool:
"""Check row, col, and inequality constraints for placing val in grid[row][col]."""
size = len(grid)
# Row or column conflict?
if val in grid[row]:
return False
for r in range(size):
if grid[r][col] == val:
return False
# Temporarily place the val and check constraints
original_val = grid[row][col]
grid[row][col] = val
# Check all constraints involving this cell
for ((r1, c1), (r2, c2)), rel in constraints.items():
v1 = grid[r1][c1]
v2 = grid[r2][c2]
# If either is 0, skip
if v1 == 0 or v2 == 0:
continue
# If relation is '<', v1 < v2 must hold
if rel == "<":
if not (v1 < v2):
grid[row][col] = original_val
return False
elif rel == ">":
if not (v1 > v2):
grid[row][col] = original_val
return False
grid[row][col] = original_val
return True
def _generate_random_solution(self, size: int, rng: Random) -> List[List[int]]:
"""
Generates a random valid completed Futoshiki solution with numbers 1..size.
Ensures each row and column has unique numbers.
"""
# Fill row by row with a random permutation, ensuring no column conflicts. Use backtracking
grid = [[0] * size for _ in range(size)]
def backtrack(r):
if r == size:
return True
nums = list(range(1, size + 1))
rng.shuffle(nums)
for permutation in itertools.permutations(nums):
# Place row if columns are valid
valid = True
for c in range(size):
if any(grid[rr][c] == permutation[c] for rr in range(r)):
valid = False
break
if valid:
grid[r] = list(permutation)
if backtrack(r + 1):
return True
return False
if backtrack(0):
return grid
raise ValueError("Could not generate a random solution.")
def _generate_random_constraints(
self, solution: List[List[int]], difficulty: int, rng: Random
) -> Dict[Tuple[Tuple[int, int], Tuple[int, int]], str]:
"""
Randomly add inequality constraints that match the solution.
We only add constraints for adjacent cells (horizontal or vertical).
Probability of adding a constraint can scale with difficulty.
"""
size = len(solution)
constraints = {}
# For each pair of adjacent cells, we might add a constraint
# P(adding a constraint) increases with difficulty on an arbitrary scale
base_prob = 0.03 + 0.07 * difficulty
for r in range(size):
for c in range(size):
# Horizontal neighbor
if c < size - 1:
if rng.random() < base_prob:
if solution[r][c] < solution[r][c + 1]:
constraints[((r, c), (r, c + 1))] = "<"
else:
constraints[((r, c), (r, c + 1))] = ">"
# Vertical neighbor
if r < size - 1:
if rng.random() < base_prob:
if solution[r][c] < solution[r + 1][c]:
constraints[((r, c), (r + 1, c))] = "<"
else:
constraints[((r, c), (r + 1, c))] = ">"
return constraints
def count_solutions(self, grid, constraints, limit=2) -> int:
size = len(grid)
count = 0
def backtrack():
nonlocal count
# Early exit if limit reached
if count >= limit:
return
# Find the next empty cell
for r in range(size):
for c in range(size):
if grid[r][c] == 0:
for val in range(1, size + 1):
if self._is_valid(grid, r, c, val, constraints):
grid[r][c] = val
backtrack()
grid[r][c] = 0
return
count += 1
backtrack()
return count
def _remove_clues(
    self,
    grid: List[List[int]],
    constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str],
    rng: Random,
) -> List[List[int]]:
    """
    Remove clues from a full solution to try to maintain a unique-solution puzzle.
    We remove in random order until we reach our target, or can't without losing uniqueness.

    Mutates `grid` in place (removed clues become 0) and also returns it.
    """
    size = len(grid)
    # Target: keep roughly 10% of cells filled.
    fill_fraction = 0.1
    target_filled = int(fill_fraction * (size * size))
    coords = [(r, c) for r in range(size) for c in range(size)]
    rng.shuffle(coords)

    def _count_filled_cells(g) -> int:
        # Number of non-zero (clue) cells currently in the grid.
        return sum(g[r][c] != 0 for r in range(size) for c in range(size))

    def _try_remove() -> None:
        # One greedy pass over the shuffled coordinates: blank a cell and keep
        # the removal only if the puzzle remains solvable AND unique.
        for r, c in coords:
            if _count_filled_cells(grid) <= target_filled:
                break  # Removal target hit
            saved = grid[r][c]
            if saved == 0:
                continue
            # Try remove
            grid[r][c] = 0
            # Check if unsolvable
            sol = self._solve(copy.deepcopy(grid), constraints)
            if sol is None:
                grid[r][c] = saved
                continue
            # Check if not unique
            if self.count_solutions(copy.deepcopy(grid), constraints, limit=2) > 1:
                grid[r][c] = saved

    _try_remove()
    # Second pass if we aren't close to target
    if _count_filled_cells(grid) > 2 * target_filled:
        rng.shuffle(coords)
        _try_remove()
    return grid
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
    """
    Score a Futoshiki answer against the oracle solution.

    Returns 1.0 for an exact match (ignoring per-line trailing whitespace).
    Otherwise gives partial credit: the fraction of correctly-placed digits,
    scaled by 0.9 for not using the standard format, with a further penalty
    proportional to any extra answer length. Empty/None answers score 0.0.
    """
    if not answer:
        return 0.0
    oracle_answer = entry["answer"]
    metadata = entry["metadata"]
    solution: list[list[int]] = metadata["solution"]
    board_size: int = len(solution[0])
    # 1. match answer without trailing whitespaces
    answer_stripped = "\n".join(l.rstrip() for l in answer.split("\n"))
    oracle_answer_stripped = "\n".join(l.rstrip() for l in oracle_answer.split("\n"))
    if answer_stripped == oracle_answer_stripped:
        reward = 1.0
    else:
        # 2. accept answers with correct numeric sequence (ignoring non-numeric characters)
        row = 0
        num_matching = 0
        for ln in answer.split("\n"):
            # Bug fix: stop once every solution row has been consumed; the
            # previous version indexed solution[row] past the end and raised
            # IndexError when the answer contained extra numeric lines.
            if row >= len(solution):
                break
            numbers = [int(c) for c in ln if c.isnumeric()]
            if len(numbers) != board_size:  # consistent with MiniSudoku scorer
                continue  # ignore lines without numbers
            for a, b in zip(solution[row], numbers):
                if a == b:
                    num_matching += 1
            row += 1
        reward = num_matching / (board_size * board_size)
        reward *= 0.9  # penalty for not using standard format
        if len(answer) > len(oracle_answer):
            reward *= len(oracle_answer) / len(answer)  # penalty for additional length
    return reward
# Expose the Futoshiki task under the "futoshiki" id in the dataset registry.
register_dataset("futoshiki", FutoshikiDataset, FutoshikiConfig)

View file

@ -95,7 +95,8 @@ class MazeDataset(ProceduralDataset):
+ "\n```" + "\n```"
+ "\nLegend: " + "\nLegend: "
+ f"'{self.wall_char}' = Wall, '{self.path_char}' = Passage\n\n" + f"'{self.wall_char}' = Wall, '{self.path_char}' = Passage\n\n"
+ "What is the minimum number of steps to reach the goal?" + "What is the minimum number of steps to reach the goal?\n"
+ "Give only the number of steps as your final answer, no other text or formatting."
) )
return { return {

View file

@ -2,7 +2,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from random import Random from random import Random
from typing import List, Optional, Tuple from typing import Any, List, Optional, Tuple
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -141,11 +141,55 @@ class MiniSudokuDataset(ProceduralDataset):
puzzle_str = self._board_to_string(puzzle) puzzle_str = self._board_to_string(puzzle)
solution_str = self._board_to_string(solved_board) solution_str = self._board_to_string(solved_board)
question = (
"In 4x4 Mini Sudoku:\n"
"- Each row must contain each number from 1-4 exactly once\n"
"- Each column must contain each number 1-4 exactly once\n"
"- Each 2x2 subgrid must contain each number 1-4 exactly once\n"
f"Solve this 4x4 Mini Sudoku puzzle:\n{puzzle_str}\n"
"Format your response as the puzzle above, with spaces separating each number within a row, and newlines separating rows.\n"
)
return { return {
"question": f"Solve this 4x4 Mini Sudoku puzzle:\n{puzzle_str}", "question": question,
"answer": solution_str, "answer": solution_str,
"metadata": {"puzzle": puzzle, "solution": solved_board, "num_empty": num_empty}, "metadata": {"puzzle": puzzle, "solution": solved_board, "num_empty": num_empty},
} }
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
if not answer:
return 0.0
oracle_answer = entry["answer"]
metadata = entry["metadata"]
solution: list[list[int]] = metadata["solution"]
board_size: int = len(solution[0])
# 1. match answer without trailing whitespaces
answer_stripped = "\n".join(l.rstrip() for l in answer.split("\n"))
oracle_answer_stripped = "\n".join(l.rstrip() for l in oracle_answer.split("\n"))
if answer_stripped == oracle_answer_stripped:
reward = 1.0
else:
# 2. accept answers with correct numeric sequence (ignoring non-numeric characters)
row = 0
num_matching = 0
for ln in answer.split("\n"):
numbers = [int(c) for c in ln if c.isnumeric()]
if len(numbers) != board_size:
continue # ignore lines without numbers
for a, b in zip(solution[row], numbers):
if a == b:
num_matching += 1
row += 1
reward = num_matching / (board_size * board_size)
reward *= 0.9 # penalty for not using standard format
if len(answer) > len(oracle_answer):
reward *= len(oracle_answer) / len(answer) # penalty for additional length
return reward
register_dataset("mini_sudoku", MiniSudokuDataset, MiniSudokuConfig) register_dataset("mini_sudoku", MiniSudokuDataset, MiniSudokuConfig)

View file

@ -14,14 +14,29 @@ from ..factory import ProceduralDataset, register_dataset
MIN_BOARD_SIZE = 4 MIN_BOARD_SIZE = 4
MAX_BOARD_SIZE = 12 MAX_BOARD_SIZE = 12
QUESTION_TEMPLATE = """Solve this N Queens puzzle: QUESTION_TEMPLATE = """Your job is to complete an n x n chess board with n Queens in total, such that no two attack each other.
{puzzle}
The board size is {n}x{n} and your job is to place {num_removed} queen(s) on the board such that no two queens attack each other.
No two queens attack each other if they are not in the same row, column, or diagonal. No two queens attack each other if they are not in the same row, column, or diagonal.
Place a queen by replacing an underscore (_) with a Q. You can place a queen by replacing an underscore (_) with a Q.
Example:
- Input: Given the below board of size 4 x 4 your job is to place 2 queen(s) on the board such that no two queens attack each other.
_ Q _ _
_ _ _ _
_ _ _ _
_ _ Q _
- Output:
_ Q _ _
_ _ _ Q
Q _ _ _
_ _ Q _
- Explanation
- None of the queens attack each other vertically, horizontally, or diagonally.
- The added queens are marked with Q at the positions (1, 3) and (2, 0).
Given the below board of size {n} x {n} your job is to place {num_removed} queen(s) on the board such that no two queens attack each other.
{puzzle}
""" """
@ -137,13 +152,16 @@ class NQueensDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
valid_solutions = entry["metadata"]["valid_answers"] valid_solutions = entry["metadata"]["valid_answers"]
reward = 0.0
if answer is not None: if answer is not None:
if answer in valid_solutions: if answer in valid_solutions:
reward = 1.0 return 1.0
else: try:
reward = 0.01 answer = self._board_to_string(eval(answer))
return reward if answer in valid_solutions:
return 0.5
except Exception as e:
return 0.01
return 0.0
register_dataset("n_queens", NQueensDataset, NQueensConfig) register_dataset("n_queens", NQueensDataset, NQueensConfig)

View file

@ -8,6 +8,23 @@ from typing import Any, Dict, List, Optional, Tuple
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Solve the Tower of Hanoi problem with {num_disks} disks and {num_pegs} pegs.
Move all disks from {start_peg} to {target_peg} following the rules:
- Only one disk can be moved at a time.
- A larger disk cannot be placed on top of a smaller disk.
- All disks must be on a peg at all times.
Example:
Move disk 1 from Peg 1 to Peg 3
Move disk 2 from Peg 1 to Peg 2
Move disk 1 from Peg 3 to Peg 2
Provide the sequence of moves.
Formatting guidelines:
Each instruction should be placed on a single line.
Each line should be formatted as 'Move disk X from Peg Y to Peg Z'
Do not include any other text or formatting.
"""
@dataclass @dataclass
class HanoiConfig: class HanoiConfig:
@ -245,22 +262,13 @@ class HanoiDataset(ProceduralDataset):
# Peg labels # Peg labels
peg_labels = {peg: f"Peg {peg}" for peg in pegs} peg_labels = {peg: f"Peg {peg}" for peg in pegs}
question_str = (
f"Solve the Tower of Hanoi problem with {num_disks} disks and {num_pegs} pegs.\n"
f"Move all disks from {peg_labels[start_peg]} to {peg_labels[target_peg]} following the rules:\n"
"- Only one disk can be moved at a time.\n"
"- A larger disk cannot be placed on top of a smaller disk.\n"
"- All disks must be on a peg at all times.\n"
"Example:\n"
"Move disk 1 from Peg 1 to Peg 3\n"
"Move disk 2 from Peg 1 to Peg 2\n"
"Move disk 1 from Peg 3 to Peg 2\n"
"\n"
"Provide the sequence of moves."
)
result = { result = {
"question": question_str, "question": QUESTION_TEMPLATE.format(
num_disks=num_disks,
num_pegs=num_pegs,
start_peg=peg_labels[start_peg],
target_peg=peg_labels[target_peg],
),
"answer": solution, "answer": solution,
"metadata": { "metadata": {
"num_disks": num_disks, "num_disks": num_disks,
@ -359,7 +367,7 @@ class HanoiDataset(ProceduralDataset):
tuple: (disk, from_peg, to_peg) tuple: (disk, from_peg, to_peg)
""" """
pattern = r"Move disk (\d+) from Peg (\d+) to Peg (\d+)" pattern = r"Move disk (\d+) from Peg (\d+) to Peg (\d+)"
match = re.match(pattern, move) match = re.search(pattern, move)
if not match: if not match:
raise ValueError(f"Unexpected move format: '{move}'") raise ValueError(f"Unexpected move format: '{move}'")

View file

@ -175,9 +175,9 @@ class FamilyRelationshipsDataset(ProceduralDataset):
def __init__(self, config: FamilyRelationshipsConfig): def __init__(self, config: FamilyRelationshipsConfig):
self._templates = [ self._templates = [
"What is {person1} to {person2}?", "What is {person1} to {person2}? Respond only with the word that describes their relationship.",
"How is {person1} related to {person2}?", "How is {person1} related to {person2}? Provide the relationship in one word.",
"What relation is {person1} to {person2}?", "What relation is {person1} to {person2}? Answer with a single word.",
] ]
super().__init__(config=config, seed=config.seed, size=config.size) super().__init__(config=config, seed=config.seed, size=config.size)

View file

@ -3,6 +3,7 @@ Logic tasks for training reasoning capabilities.
""" """
from .aiw import AliceInWonderlandConfig, AliceInWonderlandDataset from .aiw import AliceInWonderlandConfig, AliceInWonderlandDataset
from .circuit_logic import CircuitLogicConfig, CircuitLogicDataset
from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset
from .self_reference import SelfReferenceConfig, SelfReferenceDataset from .self_reference import SelfReferenceConfig, SelfReferenceDataset
from .syllogisms import SyllogismConfig, SyllogismDataset, Term from .syllogisms import SyllogismConfig, SyllogismDataset, Term
@ -20,5 +21,8 @@ __all__ = [
"ZebraConfig", "ZebraConfig",
"ZebraDataset", "ZebraDataset",
"SelfReference", "SelfReference",
"SelfReferenceConfig",
"SelfReferenceDataset", "SelfReferenceDataset",
"CircuitLogicConfig",
"CircuitLogicDataset",
] ]

View file

@ -187,7 +187,7 @@ class AliceInWonderlandDataset(ProceduralDataset):
num_female_colleagues_bob_circle=num_female_colleagues_bob_circle, num_female_colleagues_bob_circle=num_female_colleagues_bob_circle,
) )
return {"question": question, "answer": answer, "metadata": {"task_type": task_type.value}} return {"question": question, "answer": str(answer), "metadata": {"task_type": task_type.value}}
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
rng = Random(self.seed + idx) rng = Random(self.seed + idx)

View file

@ -0,0 +1,416 @@
from dataclasses import dataclass
from random import Random
from typing import Any, Dict, List, Optional, Tuple
from ..factory import ProceduralDataset, register_dataset
# Box-drawing glyphs used to render the ASCII circuit diagram.
# NOTE(review): these constants were empty strings in the file as received —
# the Unicode glyphs were evidently lost in an encoding round-trip. Restored
# to the standard box-drawing set consistent with how each constant is used
# below (e.g. LDOWN terminates a left-to-right run before a downward wire).
VERT = "│"  # vertical wire segment
HORIZ = "─"  # horizontal wire segment
RBRANCH = "├"  # branch to the right off a vertical wire
LUP = "┘"  # corner: wire arrives from the left, turns upward
LDOWN = "┐"  # corner: wire arrives from the left, turns downward
RUP = "└"  # corner: wire turns up, continues to the right
RDOWN = "┌"  # corner: wire turns down, continues to the right
def _repeat(s: str, n: int) -> str:
return s * n
def _matrix_put(matrix: List[List[str]], h: int, w: int, x: int, y: int, s: str, direction: str):
"""Place a string `s` into the 2D `matrix` starting at (x,y),
advancing in `direction` ('RIGHT' or 'DOWN')."""
if x >= w or y >= h:
raise IndexError(f"_matrix_put: point ({x}, {y}) out of bounds!")
for c in s:
if x < 0 or x >= w or y < 0 or y >= h:
break
matrix[y][x] = c
if direction == "RIGHT":
x += 1
elif direction == "DOWN":
y += 1
def _get_excel_name(index: int) -> str:
"""
Convert a zero-based integer `index` into an Excel-like column name:
0 -> A, 1 -> B, ..., 25 -> Z, 26 -> AA, etc.
"""
result = ""
index += 1
while index > 0:
rem = (index - 1) % 26
result = chr(ord("A") + rem) + result
index = (index - 1) // 26
return result
@dataclass
class CircuitLogicConfig:
    """
    Configuration for circuit logic task generation.

    :param num_terms: Number of terms (sub-expressions) to generate
    :param min_inputs: Minimum inputs per term
    :param max_inputs: Maximum inputs per term
    :param neg_prob: Probability (0.0-1.0) that an input is negated
    :param allow_reuse: Whether inputs can be reused
    :param size: Number of items in the dataset
    :param seed: Random seed
    """

    num_terms: int = 5
    min_inputs: int = 2
    max_inputs: int = 4
    neg_prob: float = 0.3
    allow_reuse: bool = True
    size: int = 100
    seed: Optional[int] = None

    def validate(self):
        """Raise AssertionError if any field is outside its legal range."""
        assert 1 <= self.min_inputs <= self.max_inputs, "Invalid input range"
        assert 1 <= self.num_terms, "Invalid number of terms"
        assert 0.0 <= self.neg_prob <= 1.0, "neg_prob must be between 0 and 1"
class CircuitLogicDataset(ProceduralDataset):
"""
Generates random digital logic circuits (in ASCII) together with:
- a random Boolean expression,
- random input assignments,
- the final evaluated output.
Each item in the dataset is a dict with:
{
"question": <str>,
"answer": <str>,
"metadata": {
"diagram": <ASCII circuit diagram>,
"expression": <str>,
"term_strings": <list of term_strings>,
"assignments": <dict of input->0/1>,
"final_gate": <str>,
"inputs": <list of input names>,
}
}
"""
def __init__(self, config: CircuitLogicConfig):
    """Validate `config` and set up the gate symbol tables."""
    super().__init__(config=config, seed=config.seed, size=config.size)
    self.config.validate()
    # Tuples of (gate name, expression join symbol, diagram draw symbol).
    # NOTE(review): the NAND/XOR symbol strings are empty here — they appear
    # to be Unicode glyphs (e.g. '⊼', '⊕') lost in an encoding round-trip;
    # confirm against the original source.
    self.internal_ops = [
        ("AND", "&", "&"),
        ("NAND", "", ""),
        ("XOR", "", ""),
    ]
    # Candidate (name, symbol) pairs for the final combining gate.
    # NOTE(review): NOR/XOR symbols are likewise empty — presumably lost
    # glyphs; verify.
    self.final_gate_options = [
        ("OR", "+"),
        ("NOR", ""),
        ("XOR", ""),
        ("AND", "&"),
    ]
def __len__(self) -> int:
    """Number of items in the dataset (fixed by config.size)."""
    return self.config.size

def __iter__(self):
    """Reset the iteration cursor and return self as the iterator."""
    self._current_idx = 0
    return self

def __next__(self) -> Dict[str, Any]:
    """Return the next generated item; raise StopIteration at the end."""
    if self._current_idx >= self.config.size:
        raise StopIteration
    item = self[self._current_idx]
    self._current_idx += 1
    return item
def __getitem__(self, idx: int) -> Dict[str, Any]:
    """
    Generate one random circuit logic item using ASCII drawing.

    The per-item RNG is seeded with (seed + idx) so items are reproducible;
    with no configured seed the RNG is seeded from system entropy.
    """
    derived_seed = None if self.seed is None else self.seed + idx
    cfg = self.config
    return self._generate_circuit(
        rng=Random(derived_seed),
        num_terms=cfg.num_terms,
        min_inputs=cfg.min_inputs,
        max_inputs=cfg.max_inputs,
        neg_prob=cfg.neg_prob,
        allow_reuse=cfg.allow_reuse,
    )
def _generate_circuit(
    self, rng: Random, num_terms: int, min_inputs: int, max_inputs: int, neg_prob: float, allow_reuse: bool
) -> Dict[str, Any]:
    """
    Generate circuit logic (ASCII drawing + expression + evaluation)

    Builds `num_terms` random sub-expressions, renders the circuit as an
    ASCII diagram, assigns random 0/1 values to the inputs, evaluates the
    final gate, and returns the question/answer/metadata dict.
    """
    # Pick the gate that combines all terms into the final output.
    final_gate_name, final_gate_sym = rng.choice(self.final_gate_options)
    final_gate_width = 2 + len(final_gate_sym)
    distinct_inputs: List[str] = []

    def get_random_input() -> str:
        # With reuse allowed, a coin flip decides between reusing an existing
        # input name and minting a fresh Excel-style name (A, B, ..., AA, ...).
        if allow_reuse and distinct_inputs and rng.random() < 0.5:
            return rng.choice(distinct_inputs)
        else:
            name = _get_excel_name(len(distinct_inputs))
            distinct_inputs.append(name)
            return name

    # --- Build the random terms -------------------------------------------
    term_ops: List[Tuple[str, str, str]] = []
    term_strings: List[str] = []
    for _ in range(num_terms):
        op = rng.choice(self.internal_ops)
        term_ops.append(op)
        term_length = rng.randint(min_inputs, max_inputs)
        parts = []
        for __ in range(term_length):
            inp = get_random_input()
            neg = rng.random() < neg_prob
            parts.append(inp + ("'" if neg else ""))
        # Join the parts with the operators join symbol.
        term_str = op[1].join(parts)
        term_strings.append(term_str)
    expression_for_display = final_gate_sym.join(f"({t})" for t in term_strings)
    # use || separator internally that doesn't clash with other symbols...
    separator = "||"
    expression_for_internal = separator.join(term_strings)
    # --- Parse the terms back into structured literals --------------------
    expr = []  # will hold a list of tuples (op_used, term_input_list)
    inputs_set = set()
    term_inputs_map = {}
    input_ypos = 0  # running vertical position of each literal in the diagram
    outer_terms = expression_for_internal.split(separator)
    for op_chosen, term in zip(term_ops, outer_terms):
        op_used = op_chosen
        # If the join symbol appears in the term, split by it; otherwise (single literal) use it as-is.
        # NOTE(review): with the op join symbols emptied by the encoding loss,
        # `op_used[1] in term` is vacuously true and split("") would raise —
        # confirm the original glyphs before relying on this path.
        if op_used[1] in term:
            input_strs = term.split(op_used[1])
        else:
            input_strs = [term]
        curr_term = []
        for part in input_strs:
            if not part:
                continue
            neg = part.endswith("'")
            name = part[:-1] if neg else part
            inputs_set.add(name)
            curr_term.append({"name": name, "ypos": input_ypos, "neg": neg})
            term_inputs_map.setdefault(name, []).append({"ypos": input_ypos, "neg": neg})
            input_ypos += 1
        expr.append((op_used, curr_term))
        # Add a gap after each term.
        input_ypos += 1
    # --- Compute the diagram dimensions -----------------------------------
    inputs_list = sorted(list(inputs_set))
    total_term_inputs = sum(len(t) for (_, t) in expr)
    height = len(inputs_list) + total_term_inputs + len(expr) - 1
    max_input_len = max((len(s) for s in inputs_list), default=0)
    input_width = max_input_len + 2
    width = 0
    width += input_width
    width += len(inputs_list) * 2
    not_and_start = width  # column where the negation+gate section begins
    width += 7  # space for gates ...
    width += len(expr) + 1
    gate_start = width  # column where the final gate is drawn
    width += final_gate_width
    width += 4  # additional wiring space
    width += 4
    width += len(expression_for_display)
    # Create an empty drawing matrix.
    matrix = [[" " for _ in range(width)] for _ in range(height)]
    base_y = len(inputs_list)  # rows above base_y hold the input labels
    # --- Draw the output stub and the final gate --------------------------
    x = width - 8 - len(expression_for_display)
    y = base_y + ((height - base_y) // 2)
    _matrix_put(matrix, height, width, x, y, _repeat(HORIZ, 4) + " OUT", "RIGHT")
    x = gate_start
    out_gate_center = base_y + ((height - base_y) // 2) - (len(expr) // 2)
    if len(expr) == 1:
        # Single term: the "final gate" is just a pass-through wire.
        _matrix_put(matrix, height, width, x, out_gate_center, _repeat(HORIZ, final_gate_width), "RIGHT")
    else:
        _matrix_put(matrix, height, width, x, out_gate_center, _repeat(HORIZ, len(expr)), "DOWN")
        _matrix_put(matrix, height, width, x + 1, out_gate_center, _repeat(VERT, len(expr)), "DOWN")
        for i, ch in enumerate(final_gate_sym):
            _matrix_put(matrix, height, width, x + 2 + i, out_gate_center, _repeat(ch, len(expr)), "DOWN")
            _matrix_put(matrix, height, width, x + 3 + i, out_gate_center, _repeat(ch, len(expr)), "DOWN")
    # Draw internal wiring (for the internal gate section).
    x = not_and_start
    y = base_y
    for op, term_inputs in expr:
        # Seven vertical "layers": lead-in, negation marker (">o"), lead-out,
        # vertical join, and two columns of the gate's draw symbol.
        layers = [""] * 7
        for ti in term_inputs:
            layers[0] += HORIZ
            layers[1] += ">" if ti["neg"] else HORIZ
            layers[2] += "o" if ti["neg"] else HORIZ
            layers[3] += HORIZ
            # If multiple inputs, we connect them vertically
            layers[4] += VERT if len(term_inputs) > 1 else HORIZ
            layers[5] += op[2] if (len(term_inputs) > 1) else HORIZ
            layers[6] += op[2] if (len(term_inputs) > 1) else HORIZ
        for i in range(7):
            _matrix_put(matrix, height, width, x + i, y, layers[i], "DOWN")
        y += len(term_inputs) + 1
    # --- Draw the input labels and fan-out wiring -------------------------
    x = 0
    y = 0
    for inp in inputs_list:
        label = f"{inp}: " + _repeat(HORIZ, input_width - (len(inp) + 2))
        _matrix_put(matrix, height, width, x, y, label, "RIGHT")
        y += 1
    x = input_width
    for idx, inp in enumerate(inputs_list):
        y = idx
        length = len(inputs_list) * 2 - 1 - (idx * 2)
        _matrix_put(matrix, height, width, x, y, _repeat(HORIZ, length) + LDOWN, "RIGHT")
    # Route each input down its dedicated column, branching right at every
    # row where a term consumes it.
    num = 0
    offset = len(inputs_list) * 2 - 1
    for inp in inputs_list:
        y_breaks = [base_y + ti["ypos"] for ti in term_inputs_map.get(inp, [])]
        y_breaks.sort()
        for yb in y_breaks:
            _matrix_put(
                matrix, height, width, x + offset, yb, _repeat(HORIZ, len(inputs_list) * 2 - offset), "RIGHT"
            )
        y_start = num + 1
        max_break = max(y_breaks) if y_breaks else y_start
        branch = list(_repeat(VERT, max_break - y_start + 1))
        for yb in y_breaks:
            pos = yb - y_start
            if 0 <= pos < len(branch):
                branch[pos] = RBRANCH
        branch[-1] = RUP
        _matrix_put(matrix, height, width, x + offset, y_start, "".join(branch), "DOWN")
        offset -= 2
        num += 1
    # --- Route each term's output into the final gate ---------------------
    x = not_and_start + 7
    out_y = out_gate_center
    breakx = len(expr) // 2
    for op, term_inputs in expr:
        in_y = base_y + (term_inputs[0]["ypos"] + term_inputs[-1]["ypos"]) // 2
        # horizontal to branch
        _matrix_put(matrix, height, width, x, in_y, _repeat(HORIZ, abs(breakx) + 1), "RIGHT")
        # horizontal from branch up/down to final gate column
        _matrix_put(
            matrix, height, width, x + abs(breakx) + 1, out_y, _repeat(HORIZ, len(expr) - abs(breakx)), "RIGHT"
        )
        if in_y < out_y:
            branch = LDOWN + _repeat(VERT, out_y - in_y - 1) + RUP
            _matrix_put(matrix, height, width, x + abs(breakx) + 1, in_y, branch, "DOWN")
        elif in_y > out_y:
            branch = RDOWN + _repeat(VERT, in_y - out_y - 1) + LUP
            _matrix_put(matrix, height, width, x + abs(breakx) + 1, out_y, branch, "DOWN")
        out_y += 1
        breakx -= 1
    ascii_diagram = "\n".join("".join(row).rstrip() for row in matrix)
    # --- Assign random input values and evaluate each term ----------------
    assignments = {}
    for inp in inputs_list:
        assignments[inp] = rng.choice([0, 1])
    term_values = []
    for op_used, term_inputs in expr:
        op_name = op_used[0]
        values = []
        for literal in term_inputs:
            val = assignments[literal["name"]]
            if literal["neg"]:
                val = 1 - val
            values.append(val)
        if op_name == "AND":
            term_val = 1 if all(v == 1 for v in values) else 0
        elif op_name == "NAND":
            term_val = 0 if all(v == 1 for v in values) else 1
        elif op_name == "XOR":
            tmp = 0
            for v in values:
                tmp ^= v
            term_val = tmp
        else:
            # Unknown internal op falls back to 0 rather than raising.
            term_val = 0
        term_values.append(term_val)
    # Evaluate final gate based on term values
    if final_gate_name == "OR":
        final_result = 1 if any(v == 1 for v in term_values) else 0
    elif final_gate_name == "NOR":
        final_result = 0 if any(v == 1 for v in term_values) else 1
    elif final_gate_name == "XOR":
        final_result = sum(term_values) % 2
    elif final_gate_name == "AND":
        final_result = 1 if all(v == 1 for v in term_values) else 0
    else:
        raise ValueError(f"Unknown gate type: {final_gate_name}")
    # --- Assemble the question text ---------------------------------------
    lines = []
    lines.append("Below is a randomly generated logic circuit.\n")
    lines.append(ascii_diagram)
    lines.append("\n")
    legend_lines = []
    legend_lines.append("Legend for gates:")
    for op_name, _, draw_sym in self.internal_ops:
        legend_lines.append(f"{draw_sym*2}: {op_name}")
    if neg_prob > 0:
        legend_lines.append(f">o: Negate")
    # NOTE(review): comparing a symbol string against a list of tuples is
    # always True, so the final-gate legend line is emitted unconditionally —
    # the intended check was probably against the internal op *symbols*.
    if final_gate_sym not in self.internal_ops:
        legend_lines.append(f"{final_gate_sym*2}: {final_gate_name}")
    legend_str = "\n".join(legend_lines)
    lines.append(legend_str)
    lines.append("")
    lines.append("Given the following input assignments:")
    for inp in inputs_list:
        lines.append(f" {inp} = {assignments[inp]}")
    lines.append("")
    lines.append("What is the final output?")
    answer_str = str(final_result)
    question_str = "\n".join(lines)
    return {
        "question": question_str,
        "answer": answer_str,
        "metadata": {
            "expression": expression_for_display,
            "assignments": assignments,
            "term_strings": term_strings,
            "final_gate": final_gate_name,
            "inputs": inputs_list,
        },
    }
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
    """
    Score an answer against the oracle in `entry["answer"]`.

    Exact match scores 1.0; a match after stripping whitespace scores the
    ratio of oracle length to answer length; any other non-empty answer
    scores 0.01; None/empty scores 0.0.
    """
    if not answer:
        return 0.0
    expected = entry["answer"]
    if answer == expected:
        return 1.0
    if answer.strip() == expected:
        # Correct value padded with whitespace: penalize by relative length.
        return len(expected) / len(answer)
    return 0.01
# Expose the circuit logic task under the "circuit_logic" id in the registry.
register_dataset("circuit_logic", CircuitLogicDataset, CircuitLogicConfig)

View file

@ -8,12 +8,12 @@ SYSTEM_PROMPTS = {
"DeepSeekZero": """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. "DeepSeekZero": """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think>
<answer>answer here</answer> <answer>answer here</answer>
Do not explain your reasoning inside the answer tags, provide only the final answer. Do not explain your reasoning inside the answer tags, provide only the final answer. When an example is provided, you should strictly follow the format of the output/answer in that example.
""", """,
"default": """Given a problem, your task is to answer the question by thinking step-by-step in a clear and specific manner. "default": """Given a problem, your task is to answer the question by thinking step-by-step in a clear and specific manner.
Once you have thought about the reasoning process, provide the answer in the following format: Once you have thought about the reasoning process, provide the answer in the following format:
<answer>answer here</answer> <answer>answer here</answer>
Do not explain your reasoning inside the answer tags, provide only the final answer. Do not explain your reasoning inside the answer tags, provide only the final answer. When an example is provided, you should strictly follow the format of the output/answer in that example.
""", """,
} }

View file

@ -137,3 +137,32 @@ def test_arc_agi_dataset_modes():
both_ds = ArcAgiDataset(both_config) both_ds = ArcAgiDataset(both_config)
assert len(both_ds._task_ids) > len(train_ds._task_ids) assert len(both_ds._task_ids) > len(train_ds._task_ids)
assert len(both_ds._task_ids) > len(eval_ds._task_ids) assert len(both_ds._task_ids) > len(eval_ds._task_ids)
def test_arc_agi_shuffled_order():
config_unshuffled = ArcAgiConfig(
shuffle_example_order=False,
use_train=True,
use_eval=False,
rotations=[],
mirrors=[],
use_color_permutation=False,
size=3,
seed=42,
)
config_shuffled = ArcAgiConfig(
shuffle_example_order=True,
use_train=True,
use_eval=False,
rotations=[],
mirrors=[],
use_color_permutation=False,
size=3,
seed=42,
)
unshuffled = ArcAgiDataset(config_unshuffled)
shuffled = ArcAgiDataset(config_shuffled)
for a, b in zip(shuffled, unshuffled):
assert a["question"] != b["question"]
assert a["answer"] == b["answer"]

View file

@ -1,5 +1,3 @@
from random import Random
import pytest import pytest
from reasoning_gym.arithmetic.basic_arithmetic import ( from reasoning_gym.arithmetic.basic_arithmetic import (
@ -64,11 +62,19 @@ def test_arithmetic_dataset_format_styles():
max_digits=2, max_digits=2,
) )
dataset = BasicArithmeticDataset(config) dataset = BasicArithmeticDataset(config)
assert all(item["question"].endswith("=") for item in dataset) assert all(item["question"].strip().endswith(".") for item in dataset)
config.format_style = "natural" config = BasicArithmeticDatasetConfig(
size=10,
seed=42,
format_style="natural",
min_terms=2,
max_terms=3, # Keep expressions simple for testing
min_digits=1,
max_digits=2,
)
dataset = BasicArithmeticDataset(config) dataset = BasicArithmeticDataset(config)
assert all("=" not in item["question"] for item in dataset) assert all(item["question"].strip().endswith(".") for item in dataset)
def test_arithmetic_dataset_iteration(): def test_arithmetic_dataset_iteration():

View file

@ -15,6 +15,14 @@ def test_binary_matrix_config_validation():
config = BinaryMatrixConfig(max_n=0) # Zero not allowed config = BinaryMatrixConfig(max_n=0) # Zero not allowed
config.validate() config.validate()
with pytest.raises(AssertionError):
config = BinaryMatrixConfig(min_n=-1) # Negative not allowed
config.validate()
with pytest.raises(AssertionError):
config = BinaryMatrixConfig(min_n=0) # Zero not allowed
config.validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = BinaryMatrixConfig(p_zero=0) # <= 0 not allowed config = BinaryMatrixConfig(p_zero=0) # <= 0 not allowed
config.validate() config.validate()
@ -98,3 +106,18 @@ def test_binary_matrix_answer():
# Empty matrix # Empty matrix
matrix = [[0, 0, 0], [0, 0, 0], [0, 0, 0]] matrix = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
assert dataset._get_distances(matrix) == [[0, 0, 0], [0, 0, 0], [0, 0, 0]] assert dataset._get_distances(matrix) == [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
# String representation of answer
answer = "0 0 0\n0 1 0\n1 2 1"
entry = {"answer": "0 0 0\n0 1 0\n1 2 1"}
assert dataset.score_answer(answer, entry) == 1.0
# Answer is a python list (partially correct answer)
answer = "[[0, 0, 0], [0, 1, 0], [1, 2, 1]]"
entry = {"answer": "0 0 0\n0 1 0\n1 2 1"}
assert dataset.score_answer(answer, entry) == 0.5
# Answer is null
answer = None
entry = {"answer": "0 0 0\n0 1 0\n1 2 1"}
assert dataset.score_answer(answer, entry) == 0.0

224
tests/test_circuit_logic.py Normal file
View file

@ -0,0 +1,224 @@
import pytest
from reasoning_gym.logic import CircuitLogicConfig, CircuitLogicDataset
def test_circuit_logic_config_validation():
    """Test that invalid configs raise appropriate errors"""
    invalid_kwargs = [
        {"min_inputs": 3, "max_inputs": 2},  # min above max
        {"num_terms": 0},  # at least one term required
        {"neg_prob": -0.1},  # probability below valid range
        {"neg_prob": 1.1},  # probability above valid range
    ]
    for kwargs in invalid_kwargs:
        with pytest.raises(AssertionError):
            CircuitLogicConfig(**kwargs).validate()
def test_circuit_logic_deterministic():
    """Test that dataset generates same items with same seed"""
    config = CircuitLogicConfig(seed=42, size=10)
    first = CircuitLogicDataset(config)
    second = CircuitLogicDataset(config)
    # Two datasets built from the same seeded config must agree item-by-item.
    for idx in range(len(first)):
        assert first[idx] == second[idx]
def test_circuit_logic_items():
    """Test basic properties of generated items"""
    config = CircuitLogicConfig(num_terms=3, min_inputs=2, max_inputs=3, neg_prob=0.3, size=50, seed=42)
    dataset = CircuitLogicDataset(config)
    for i in range(len(dataset)):
        item = dataset[i]
        # Every item must carry the standard question/answer/metadata keys.
        assert isinstance(item, dict)
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item
        # Verify metadata contents
        metadata = item["metadata"]
        assert "expression" in metadata
        assert "assignments" in metadata
        assert "final_gate" in metadata
        assert "inputs" in metadata
        # Verify answer is binary
        assert item["answer"] in ("0", "1")
        # Verify assignments are binary
        for input_name, value in metadata["assignments"].items():
            assert value in (0, 1)
        # Verify final gate is valid
        assert metadata["final_gate"] in ("OR", "NOR", "XOR", "AND")
        # Verify inputs list matches assignments
        assert set(metadata["inputs"]) == set(metadata["assignments"].keys())
def test_circuit_logic_expression_validity():
    """Test that generated expressions follow logical circuit rules"""
    config = CircuitLogicConfig(
        num_terms=2, min_inputs=2, max_inputs=2, neg_prob=0.0, size=20, seed=42  # Disable negation for simpler testing
    )
    dataset = CircuitLogicDataset(config)
    for i in range(len(dataset)):
        item = dataset[i]
        metadata = item["metadata"]
        # Expression should contain valid operators
        expr = metadata["expression"]
        # NOTE(review): two of these operator strings are empty (glyphs lost
        # in an encoding round-trip), and "" in expr is always True — this
        # assertion is currently vacuous; restore the original symbols.
        assert any(op in expr for op in ("&", "", "", "+", ""))
        # Input names should be valid Excel-style names
        for input_name in metadata["inputs"]:
            assert input_name.isalpha()
            assert input_name.isupper()
def test_circuit_logic_answer_verification():
    """Test that answers match logical evaluation of circuits"""
    config = CircuitLogicConfig(num_terms=2, min_inputs=2, max_inputs=2, size=20, seed=42)
    dataset = CircuitLogicDataset(config)

    def evaluate_term(term: str, assignments: dict) -> int:
        """Evaluate a single term with given assignments"""
        # NOTE(review): the NAND/XOR operator strings below are empty (glyphs
        # lost in an encoding round-trip). As written, '"" in term' is always
        # True and term.split("") raises ValueError — restore the original
        # symbols before relying on this helper.
        if "" in term:  # NAND
            parts = term.split("")
            values = []
            for p in parts:
                if p.endswith("'"):
                    values.append(1 - assignments[p[:-1]])
                else:
                    values.append(assignments[p])
            return 0 if all(v == 1 for v in values) else 1
        elif "&" in term:  # AND
            parts = term.split("&")
            values = []
            for p in parts:
                if p.endswith("'"):
                    values.append(1 - assignments[p[:-1]])
                else:
                    values.append(assignments[p])
            return 1 if all(v == 1 for v in values) else 0
        elif "" in term:  # XOR
            parts = term.split("")
            values = []
            for p in parts:
                if p.endswith("'"):
                    values.append(1 - assignments[p[:-1]])
                else:
                    values.append(assignments[p])
            return sum(values) % 2
        else:
            raise ValueError(f"Unknown operator in term: {term}")

    def evaluate_final_gate(gate_type: str, term_values: list) -> int:
        """Evaluate the final gate with given term values"""
        if gate_type == "AND":
            return 1 if all(v == 1 for v in term_values) else 0
        elif gate_type == "OR":
            return 1 if any(v == 1 for v in term_values) else 0
        elif gate_type == "XOR":
            return sum(term_values) % 2
        elif gate_type == "NOR":
            return 0 if any(v == 1 for v in term_values) else 1
        else:
            raise ValueError(f"Unknown gate type: {gate_type}")

    for i in range(len(dataset)):
        item = dataset[i]
        metadata = item["metadata"]
        assignments = metadata["assignments"]
        final_gate = metadata["final_gate"]
        term_strings = metadata["term_strings"]
        # First evaluate each term
        term_values = [evaluate_term(term, assignments) for term in term_strings]
        # Then combine terms with final gate
        expected = evaluate_final_gate(final_gate, term_values)
        # Compare with actual result
        result = int(item["answer"])
        assert (
            result == expected
        ), f"Item {i}: Expected {expected} but got {result} for terms {term_strings} with assignments {assignments} and final gate {final_gate}"
def test_circuit_logic_ascii_diagram():
    """Test properties of the ASCII circuit diagram.

    FIX: the gate tuple contained mojibake-stripped empty strings, making
    the `any(gate in diagram_str ...)` check vacuously true. Restored the
    gate glyphs & (AND), ↑ (NAND), ⊕ (XOR).
    """
    config = CircuitLogicConfig(num_terms=2, min_inputs=2, max_inputs=2, size=10, seed=42)
    dataset = CircuitLogicDataset(config)
    for i in range(len(dataset)):
        item = dataset[i]
        # Split question to get diagram: it starts two lines after the intro
        # sentence and runs until the next blank line.
        parts = item["question"].split("\n")
        diagram_start = parts.index("Below is a randomly generated logic circuit.") + 2
        diagram_end = parts.index("", diagram_start)
        diagram = parts[diagram_start:diagram_end]
        # Basic diagram validation: non-empty, no empty rows
        assert len(diagram) > 0
        assert all(len(row) > 0 for row in diagram)
        # Check for required circuit elements
        diagram_str = "\n".join(diagram)
        assert "OUT" in diagram_str
        assert any(gate in diagram_str for gate in ("&", "↑", "⊕"))
        # Verify every input label appears in the diagram
        for input_name in item["metadata"]["inputs"]:
            assert f"{input_name}:" in diagram_str
def test_circuit_logic_scoring():
    """Test the answer scoring mechanism"""
    dataset = CircuitLogicDataset(CircuitLogicConfig(size=5, seed=42))
    entry = dataset[0]
    # The reference answer earns full credit
    assert dataset.score_answer(entry["answer"], entry) == 1.0
    # Flipping the binary answer must reduce the score
    flipped = "1" if entry["answer"] == "0" else "0"
    assert dataset.score_answer(flipped, entry) < 1.0
    # None or empty answer should score 0.0
    assert dataset.score_answer(None, entry) == 0.0
    assert dataset.score_answer("", entry) == 0.0  # Empty string should score 0.0 like None
def test_circuit_logic_iteration():
    """Test that iteration works correctly"""
    config = CircuitLogicConfig(size=5, seed=42)
    dataset = CircuitLogicDataset(config)
    # Manual iteration yields exactly `size` items
    collected = [entry for entry in dataset]
    assert len(collected) == config.size
    # list() conversion agrees
    assert len(list(dataset)) == config.size
    # Iterating twice produces identical items in identical order
    assert list(dataset) == list(dataset)

105
tests/test_cryptarithm.py Normal file
View file

@ -0,0 +1,105 @@
import pytest
from reasoning_gym import create_dataset
from reasoning_gym.algorithmic.cryptarithm import CryptarithmConfig, CryptarithmDataset
def test_cryptarithm_generation():
    """Check structure, metadata, and arithmetic validity of generated cryptarithms."""
    dataset = create_dataset("cryptarithm", seed=42, size=10)
    assert isinstance(dataset, CryptarithmDataset)

    seen_sums = set()
    for entry in dataset:
        # Required top-level keys
        for key in ("question", "answer", "metadata"):
            assert key in entry

        # Question wording
        assert "Solve this cryptarithm:" in entry["question"]
        assert "Each letter stands for a unique digit (0-9)" in entry["question"]

        # Metadata structure
        metadata = entry["metadata"]
        for key in ("letters", "letter_to_digit", "words_letters", "result_letters", "word_values", "sum_number"):
            assert key in metadata

        # Letter -> digit mapping must be injective into 0..9
        mapping = metadata["letter_to_digit"]
        digits = set(mapping.values())
        assert len(digits) == len(mapping), "Each letter should map to a unique digit"
        assert all(0 <= d <= 9 for d in digits), "All digits should be between 0 and 9"

        # The arithmetic must hold
        assert sum(metadata["word_values"]) == metadata["sum_number"], "Sum of word values should equal result value"
        seen_sums.add(metadata["sum_number"])

    # Every generated puzzle has a distinct sum
    assert len(seen_sums) == len(dataset)
def test_cryptarithm_config():
    """Invalid configurations must raise AssertionError.

    The original bound each call to an unused `dataset` variable; the
    constructor raises before returning, so the calls are made purely for
    their side effect.
    """
    with pytest.raises(AssertionError):
        create_dataset("cryptarithm", min_words=1)  # min_words must be >= 2
    with pytest.raises(AssertionError):
        create_dataset("cryptarithm", min_words=4, max_words=3)  # min must be <= max
    with pytest.raises(AssertionError):
        create_dataset("cryptarithm", size=0)  # size must be positive
def test_leading_zero_constraint():
    """With allow_leading_zero=False, no addend word (nor the result) may start with digit 0."""
    dataset = create_dataset("cryptarithm", seed=42, size=5, allow_leading_zero=False, max_words=10, min_words=5)
    for entry in dataset:
        meta = entry["metadata"]
        mapping = meta["letter_to_digit"]
        # First letter of every addend word plus the first letter of the result
        first_letters = [word[0] for word in meta["words_letters"]]
        first_letters.append(meta["result_letters"][0])
        for letter in first_letters:
            assert mapping[letter] != 0, "Leading letters cannot be zero when allow_leading_zero=False"
def test_deterministic_generation():
    """Two cryptarithm datasets built with the same seed must match item-for-item."""
    first = create_dataset("cryptarithm", seed=42, size=5)
    second = create_dataset("cryptarithm", seed=42, size=5)
    for a, b in zip(first, second):
        assert a["question"] == b["question"]
        assert a["answer"] == b["answer"]
        assert a["metadata"] == b["metadata"]
def test_word_length_constraints():
    """Every word in a generated puzzle must be 3-5 letters long."""
    dataset = create_dataset("cryptarithm", seed=42, size=10)
    for entry in dataset:
        for word in entry["metadata"]["words_letters"]:
            assert 3 <= len(word) <= 5, "Each word should be between 3 and 5 letters long"
def test_max_letters_constraint():
    """A puzzle may never use more than 10 distinct letters (one per digit 0-9)."""
    dataset = create_dataset("cryptarithm", seed=42, size=10)
    for entry in dataset:
        mapping = entry["metadata"]["letter_to_digit"]
        assert len(mapping) <= 10, "Total unique letters should not exceed 10"

188
tests/test_futoshiki.py Normal file
View file

@ -0,0 +1,188 @@
import pytest
from reasoning_gym.games import FutoshikiConfig, FutoshikiDataset
def test_futoshiki_config_validation():
    """Test that invalid configs raise appropriate errors"""
    # Each config is constructible but must fail validate()
    bad_configs = [
        FutoshikiConfig(board_size=3),  # Too small
        FutoshikiConfig(board_size=10),  # Too large
        FutoshikiConfig(difficulty=-1),  # Invalid difficulty
        FutoshikiConfig(difficulty=4),  # Invalid difficulty
    ]
    for config in bad_configs:
        with pytest.raises(AssertionError):
            config.validate()
def test_futoshiki_deterministic():
    """Test that dataset generates same puzzles with same seed"""
    config = FutoshikiConfig(seed=42, size=10)
    first = FutoshikiDataset(config)
    second = FutoshikiDataset(config)
    # Same seed and size -> identical items at every index
    for a, b in zip(first, second):
        assert a == b
def test_futoshiki_items():
    """Test basic properties of generated items"""
    config = FutoshikiConfig(board_size=4, difficulty=1, size=10, seed=42)
    dataset = FutoshikiDataset(config)
    n = config.board_size
    for entry in dataset:
        assert isinstance(entry, dict)
        for key in ("question", "answer", "metadata"):
            assert key in entry

        # Verify metadata contents
        metadata = entry["metadata"]
        for key in ("puzzle", "solution", "constraints", "board_size", "difficulty"):
            assert key in metadata

        # Both grids must be n x n
        for grid in (metadata["puzzle"], metadata["solution"]):
            assert len(grid) == n
            for row in grid:
                assert len(row) == n

        # Each constraint links two in-bounds cells with '<' or '>'
        for ((r1, c1), (r2, c2)), rel in metadata["constraints"].items():
            for coord in (r1, c1, r2, c2):
                assert 0 <= coord < n
            assert rel in ("<", ">")
def test_futoshiki_solution_validity():
    """Test that solutions are valid according to Futoshiki rules"""
    config = FutoshikiConfig(board_size=4, difficulty=1, size=10, seed=42)
    dataset = FutoshikiDataset(config)
    n = config.board_size
    expected_values = list(range(1, n + 1))

    def latin_square_ok(solution) -> bool:
        """Each row and each column is a permutation of 1..n."""
        if any(sorted(row) != expected_values for row in solution):
            return False
        for col in range(n):
            if sorted(solution[row][col] for row in range(n)) != expected_values:
                return False
        return True

    def constraints_ok(solution, constraints) -> bool:
        """Every '<'/'>' inequality holds between its two cells."""
        for ((r1, c1), (r2, c2)), rel in constraints.items():
            v1, v2 = solution[r1][c1], solution[r2][c2]
            if rel == "<" and not (v1 < v2):
                return False
            if rel == ">" and not (v1 > v2):
                return False
        return True

    for entry in dataset:
        meta = entry["metadata"]
        assert latin_square_ok(meta["solution"]) and constraints_ok(meta["solution"], meta["constraints"])
def test_futoshiki_puzzle_solvability():
    """Test that generated puzzles are solvable and have unique solutions"""
    config = FutoshikiConfig(board_size=4, difficulty=1, size=5, seed=42)
    dataset = FutoshikiDataset(config)
    for entry in dataset:
        meta = entry["metadata"]
        # Searching for up to two solutions must find exactly one
        assert dataset.count_solutions(meta["puzzle"], meta["constraints"], limit=2) == 1
def test_futoshiki_difficulty_levels():
    """Test that different difficulty levels affect puzzle complexity"""
    size = 5
    board_size = 4

    def avg_clues(dataset) -> float:
        """Mean number of pre-filled (non-zero) cells per puzzle."""
        return sum(sum(cell != 0 for row in item["metadata"]["puzzle"] for cell in row) for item in dataset) / size

    def avg_constraints(dataset) -> float:
        """Mean number of inequality constraints per puzzle."""
        return sum(len(item["metadata"]["constraints"]) for item in dataset) / size

    for seed in (42, 43, 44):  # Test multiple seeds for robustness
        clues = []
        constraints = []
        for difficulty in range(4):  # 0 to 3
            config = FutoshikiConfig(board_size=board_size, difficulty=difficulty, size=size, seed=seed)
            dataset = FutoshikiDataset(config)
            clues.append(avg_clues(dataset))
            constraints.append(avg_constraints(dataset))
        # Higher difficulty should generally mean fewer clues and/or more constraints
        assert all(a >= b for a, b in zip(clues, clues[1:]))
        assert all(a <= b for a, b in zip(constraints, constraints[1:]))
def test_futoshiki_answer_scoring():
    """Test the answer scoring mechanism.

    FIXES: the variable/string spelling "anwser" is corrected, and the
    whitespace-mismatch case is restored to an actual mismatch
    (single space -> double space); as rendered, ``replace(" ", " ")`` was a
    no-op, so the ``== 0.9`` assertion compared the unmodified reference
    answer and would have failed.
    """
    config = FutoshikiConfig(board_size=4, difficulty=0, size=5, seed=42)
    dataset = FutoshikiDataset(config)
    for item in dataset:
        # Correct answer should score 1.0
        assert dataset.score_answer(item["answer"], item) == 1.0
        # Wrong answer should score lower
        wrong_answer = item["answer"].replace("1", "2")
        assert dataset.score_answer(wrong_answer, item) < 1.0
        # None or empty answer should score 0.0
        assert dataset.score_answer(None, item) == 0.0
        assert dataset.score_answer("", item) == 0.0
        answer = item["answer"]
        # Only whitespace differs -> small fixed penalty
        white_space_mismatch = answer.replace(" ", "  ")
        assert dataset.score_answer(white_space_mismatch, item) == 0.9
        # Correct grid embedded in surrounding prose -> partial credit
        answer_with_additional_text = "This is an answer " + answer + "\nwith surrounding text."
        assert 0 < dataset.score_answer(answer_with_additional_text, item) < 0.9
        # Partially-correct wrapped answer still beats the floor
        partially_correct = answer_with_additional_text.replace("1", "2")
        assert dataset.score_answer(partially_correct, item) > 0.1
        # Reversing the lines destroys the grid structure -> near-zero score
        bad_answer = "\n".join(answer_with_additional_text.split("\n")[::-1])
        assert dataset.score_answer(bad_answer, item) < 0.1

View file

@ -122,6 +122,11 @@ def test_nqueens_score_answer():
# Test None answer gets score 0.0 # Test None answer gets score 0.0
assert dataset.score_answer(None, item) == 0.0 assert dataset.score_answer(None, item) == 0.0
# Test python list representation of board (partial solution)
answer = "[['_', 'Q', '_', '_'], ['_', '_', '_', 'Q'], ['Q', '_', '_', '_'], ['_', '_', 'Q', '_']]"
entry = {"metadata": {"valid_answers": {"_ Q _ _\n_ _ _ Q\nQ _ _ _\n_ _ Q _"}}}
assert dataset.score_answer(answer, entry) == 0.5
def is_valid_solution(board: list[list[str]]) -> bool: def is_valid_solution(board: list[list[str]]) -> bool:
"""Helper function to verify N Queens solution validity""" """Helper function to verify N Queens solution validity"""

View file

@ -0,0 +1,111 @@
"""Tests for Palindrome Partitioning questions generation"""
import json
from reasoning_gym.algorithmic.palindrome_partitioning import (
PalindromePartitioningConfig,
PalindromePartitioningDataset,
)
def test_palindrome_partitioning_dataset_deterministic():
    """Test that dataset generates same items with same seed"""
    config = PalindromePartitioningConfig(seed=42, size=10)
    first = PalindromePartitioningDataset(config)
    second = PalindromePartitioningDataset(config)
    # Same seed -> identical items at every index
    for a, b in zip(first, second):
        assert a == b
def test_palindrome_partitioning_dataset_items():
    """Test basic properties of generated items"""
    config = PalindromePartitioningConfig(size=10, seed=42)
    dataset = PalindromePartitioningDataset(config)
    for entry in dataset:
        # Check item structure
        assert isinstance(entry, dict)
        for key in ("question", "answer", "metadata"):
            assert key in entry
        # Check metadata keys
        assert "string" in entry["metadata"]
        assert "solution" in entry["metadata"]
        text = entry["metadata"]["string"]
        partitions = entry["metadata"]["solution"]
        # The source string is never empty
        assert len(text) > 0
        # At least one partitioning exists (each letter is a palindrome)
        assert len(partitions) >= 1
        # Every partitioning is non-empty and reconstructs the original string
        for partition in partitions:
            assert len(partition) > 0
            assert "".join(partition) == text
def test_palindrome_partitioning_dataset_iteration():
    """Test that iteration respects dataset size"""
    config = PalindromePartitioningConfig(size=5, seed=42)
    dataset = PalindromePartitioningDataset(config)
    collected = list(dataset)
    assert len(collected) == config.size
    # Repeated iteration yields the same items in the same order
    assert collected == list(dataset)
def test_palindrome_partitioning_answer():
    """Test the _palindrome_partitioning method"""
    dataset = PalindromePartitioningDataset(PalindromePartitioningConfig(seed=42))

    cases = [
        # General use case
        (
            "afternoon",
            [
                ["a", "f", "t", "e", "r", "n", "o", "o", "n"],
                ["a", "f", "t", "e", "r", "n", "oo", "n"],
                ["a", "f", "t", "e", "r", "noon"],
            ],
        ),
        # Single letter word
        ("a", [["a"]]),
        # Empty string
        ("", []),
    ]
    for word, expected in cases:
        # Compare via JSON so the partition ordering must match exactly
        assert json.dumps(dataset._palindrome_partitioning(word)) == json.dumps(expected)
def test_palindrome_partitioning_score_answer():
    """Test the score_answer method"""
    dataset = PalindromePartitioningDataset(PalindromePartitioningConfig(seed=42))

    # Scoring is permutation-invariant: same partitions, different order -> full credit
    reordered = json.dumps([["n", "o", "o", "n"], ["no", "on"], ["noon"]])
    entry = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}}
    assert dataset.score_answer(reordered, entry) == 1

    # A missing partitioning earns only the floor score of 0.01
    incomplete = json.dumps([["n", "o", "o", "n"], ["no", "on"]])
    entry = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}}
    assert dataset.score_answer(incomplete, entry) == 0.01

    # No answer at all scores zero
    entry = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}}
    assert dataset.score_answer(None, entry) == 0

    # Malformed JSON scores zero
    malformed = '["n", "o", "o", "n"], ["no", "on"], ["noon"]'
    entry = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}}
    assert dataset.score_answer(malformed, entry) == 0

View file

@ -1,3 +1,5 @@
import string
import pytest import pytest
import sympy as sp import sympy as sp
@ -17,7 +19,7 @@ def test_polynomial_config_validation():
PolynomialMultiplicationConfig(min_value=0).validate() PolynomialMultiplicationConfig(min_value=0).validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(min_degree=0, max_degree=3).validate() PolynomialMultiplicationConfig(min_degree=-1, max_degree=3).validate()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(min_degree=4, max_degree=3).validate() PolynomialMultiplicationConfig(min_degree=4, max_degree=3).validate()
@ -28,6 +30,17 @@ def test_polynomial_config_validation():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(min_polynomials=5, max_polynomials=2).validate() PolynomialMultiplicationConfig(min_polynomials=5, max_polynomials=2).validate()
with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(variables="").validate()
with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(
allow_cross_variable_product=False, allow_multivariate_polynomials=True
).validate()
with pytest.raises(AssertionError):
PolynomialMultiplicationConfig(min_polynomials=5, max_polynomials=2).validate()
def test_polynomial_multiplication_dataset_basic(): def test_polynomial_multiplication_dataset_basic():
"""Test dataset creation and length""" """Test dataset creation and length"""
@ -41,7 +54,9 @@ def test_polynomial_multiplication_dataset_basic():
max_degree=2, max_degree=2,
min_polynomials=2, min_polynomials=2,
max_polynomials=3, max_polynomials=3,
single_variable=True, variables=tuple(string.ascii_lowercase),
allow_cross_variable_product=False,
allow_multivariate_polynomials=False,
seed=42, seed=42,
size=dataset_size, size=dataset_size,
) )
@ -63,7 +78,9 @@ def test_polynomial_equations_dataset_items():
max_degree=2, max_degree=2,
min_polynomials=2, min_polynomials=2,
max_polynomials=5, max_polynomials=5,
single_variable=False, variables=tuple("xyz"),
allow_cross_variable_product=False,
allow_multivariate_polynomials=False,
size=3, size=3,
seed=100, seed=100,
) )
@ -75,7 +92,113 @@ def test_polynomial_equations_dataset_items():
# Check metadata # Check metadata
assert isinstance(item["metadata"]["polynomial_expr"], str) assert isinstance(item["metadata"]["polynomial_expr"], str)
assert isinstance(item["metadata"]["single_variable"], bool) assert isinstance(item["metadata"]["result"], str)
assert isinstance(item["metadata"]["variables"], list)
# Check polynomial_expr existence
poly_str = item["metadata"]["polynomial_expr"]
# Ensure it can parse with sympy
sp.sympify(poly_str)
def test_cross_polynomial_equations_dataset_items():
"""Test that generated items have correct structure"""
ds = create_dataset(
"polynomial_multiplication",
min_terms=2,
max_terms=3,
min_value=1,
max_value=5,
min_degree=1,
max_degree=2,
min_polynomials=2,
max_polynomials=5,
variables=tuple("xyz"),
allow_cross_variable_product=True,
allow_multivariate_polynomials=False,
size=3,
seed=100,
)
for item in ds:
assert "question" in item
assert "answer" in item
assert "metadata" in item
# Check metadata
assert isinstance(item["metadata"]["polynomial_expr"], str)
assert isinstance(item["metadata"]["result"], str)
assert isinstance(item["metadata"]["variables"], list)
# Check polynomial_expr existence
poly_str = item["metadata"]["polynomial_expr"]
# Ensure it can parse with sympy
sp.sympify(poly_str)
# NOTE(review): this function is a byte-for-byte duplicate of
# test_cross_polynomial_equations_dataset_items defined immediately above; at
# import time the second definition shadows the first, so only one copy is
# collected/run. One of the two should be deleted.
def test_cross_polynomial_equations_dataset_items():
    """Test that generated items have correct structure"""
    ds = create_dataset(
        "polynomial_multiplication",
        min_terms=2,
        max_terms=3,
        min_value=1,
        max_value=5,
        min_degree=1,
        max_degree=2,
        min_polynomials=2,
        max_polynomials=5,
        variables=tuple("xyz"),
        allow_cross_variable_product=True,
        allow_multivariate_polynomials=False,
        size=3,
        seed=100,
    )
    for item in ds:
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item
        # Check metadata
        assert isinstance(item["metadata"]["polynomial_expr"], str)
        assert isinstance(item["metadata"]["result"], str)
        assert isinstance(item["metadata"]["variables"], list)
        # Check polynomial_expr existence
        poly_str = item["metadata"]["polynomial_expr"]
        # Ensure it can parse with sympy
        sp.sympify(poly_str)
def test_multivariate_polynomial_equations_dataset_items():
"""Test that generated items have correct structure"""
ds = create_dataset(
"polynomial_multiplication",
min_terms=2,
max_terms=3,
min_value=1,
max_value=5,
min_degree=1,
max_degree=2,
min_polynomials=2,
max_polynomials=5,
variables=tuple(["x", "y"]),
allow_cross_variable_product=True,
allow_multivariate_polynomials=True,
size=3,
seed=100,
)
for item in ds:
assert "question" in item
assert "answer" in item
assert "metadata" in item
# Check metadata
assert isinstance(item["metadata"]["polynomial_expr"], str)
assert isinstance(item["metadata"]["result"], str)
assert isinstance(item["metadata"]["variables"], list)
# Check polynomial_expr existence # Check polynomial_expr existence
poly_str = item["metadata"]["polynomial_expr"] poly_str = item["metadata"]["polynomial_expr"]
@ -105,7 +228,9 @@ def test_polynomial_solutions_evaluation():
max_degree=3, max_degree=3,
min_polynomials=2, min_polynomials=2,
max_polynomials=5, max_polynomials=5,
single_variable=False, variables=tuple(["x", "y"]),
allow_cross_variable_product=True,
allow_multivariate_polynomials=True,
size=5, size=5,
seed=42, seed=42,
) )
@ -125,42 +250,27 @@ def test_score_function():
ds = create_dataset( ds = create_dataset(
"polynomial_multiplication", "polynomial_multiplication",
min_terms=2, min_terms=2,
max_terms=4, max_terms=3,
min_value=1, min_value=1,
max_value=10, max_value=3,
min_degree=1, min_degree=1,
max_degree=3, max_degree=3,
min_polynomials=2, min_polynomials=3,
max_polynomials=5, max_polynomials=3,
single_variable=True, variables=tuple(["x", "y"]),
size=1, allow_cross_variable_product=True,
allow_multivariate_polynomials=True,
size=3,
seed=42, seed=42,
) )
assert ds.score_answer(None, ds[0]) == 0.00 for item in ds:
assert ds.score_answer("6*x**4 + 9*x**3 - 6*x**2 - 39*x - 45", ds[0]) == 1 poly_str = item["metadata"]["polynomial_expr"]
assert ds.score_answer("Not a polynomial", ds[0]) == 0.01 assert ds.score_answer(poly_str, item) == 0.05
assert ds.score_answer("x**4", ds[0]) == 0.05
poly_expr = str(sp.expand(poly_str))
assert ds.score_answer(poly_expr, item) == 1.0
def test_multivariate_score_function(): assert ds.score_answer(None, item) == 0.00
"""Test that solution satisfy the polynomial multiplication.""" assert ds.score_answer("Not a polynomial", item) == 0.01
ds = create_dataset( assert ds.score_answer("x**4", item) == 0.05
"polynomial_multiplication",
min_terms=2,
max_terms=4,
min_value=1,
max_value=10,
min_degree=1,
max_degree=3,
min_polynomials=2,
max_polynomials=5,
single_variable=False,
size=1,
seed=42,
)
assert ds.score_answer(None, ds[0]) == 0.00
assert ds.score_answer("-27*a**3*c - 27*a**3 + 144*a*c + 144*a", ds[0]) == 1
assert ds.score_answer("Not a polynomial", ds[0]) == 0.01
assert ds.score_answer("x**4", ds[0]) == 0.05

View file

@ -37,6 +37,7 @@ def test_getitem(dataset, config):
assert "metadata" in item assert "metadata" in item
assert item["metadata"]["word_count"] >= config.min_words_in_sentence assert item["metadata"]["word_count"] >= config.min_words_in_sentence
assert item["metadata"]["word_count"] <= config.max_words_in_sentence assert item["metadata"]["word_count"] <= config.max_words_in_sentence
assert len(item["answer"].split()) == item["metadata"]["word_count"]
def test_key_error_in_getitem(dataset): def test_key_error_in_getitem(dataset):

View file

@ -15,6 +15,10 @@ def test_spiral_matrix_config_validation():
config = SpiralMatrixConfig(max_n=0) # Zero not allowed config = SpiralMatrixConfig(max_n=0) # Zero not allowed
config.validate() config.validate()
with pytest.raises(AssertionError):
config = SpiralMatrixConfig(max_n=1) # One not allowed
config.validate()
def test_spiral_matrix_dataset_deterministic(): def test_spiral_matrix_dataset_deterministic():
"""Test that dataset generates same items with same seed""" """Test that dataset generates same items with same seed"""
@ -69,18 +73,26 @@ def test_spiral_matrix_answer():
config = SpiralMatrixConfig(seed=42) config = SpiralMatrixConfig(seed=42)
dataset = SpiralMatrixDataset(config) dataset = SpiralMatrixDataset(config)
# One element
matrix = [[0]]
assert dataset._get_spiral(matrix) == [0]
# One row
matrix = [[0, 1, 2]]
assert dataset._get_spiral(matrix) == [0, 1, 2]
# One column
matrix = [[0], [1], [2]]
assert dataset._get_spiral(matrix) == [0, 1, 2]
# 2D grid # 2D grid
matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
assert dataset._get_spiral(matrix) == [1, 2, 3, 6, 9, 8, 7, 4, 5] assert dataset._get_spiral(matrix) == [1, 2, 3, 6, 9, 8, 7, 4, 5]
# Answer is identical (up to trimming)
entry = {"answer": "1 2 3 6 9 8 7 4 5"}
answer = "\n\n1 2 3 6 9 8 7 4 5\n"
assert dataset.score_answer(answer, entry) == 1.0
# Score answer in list format (partially correct)
entry = {"answer": "1 2 3 6 9 8 7 4 5"}
answer = "[1, 2, 3, 6, 9, 8, 7, 4, 5]"
assert dataset.score_answer(answer, entry) == 0.5
# Answer is incorrect
entry = {"answer": "1 2 3 6 9 8 7 4 5"}
answer = "1 2 3"
assert dataset.score_answer(answer, entry) == 0.01
# Answer is none
entry = {"answer": "1 2 3 6 9 8 7 4 5"}
answer = None
assert dataset.score_answer(answer, entry) == 0.0

View file

@ -92,3 +92,13 @@ def test_string_insertion_answer():
# No reuse of newly inserted characters # No reuse of newly inserted characters
assert dataset._get_answer("ABCDBCD") == "ABCDABCD" assert dataset._get_answer("ABCDBCD") == "ABCDABCD"
# Test score_answer with correct answer
answer = "AABCDAEEEEEEEBCDEBAAAAA"
entry = {"answer": "AABCDAEEEEEEEBCDEBAAAAA"}
assert dataset.score_answer(answer, entry) == 1.0
# Test score_answer with correct answer as python list of characters (partial correct)
answer = "['A', 'A', 'B', 'C', 'D', 'A', 'E', 'E', 'E', 'E', 'E', 'E', 'E', 'B', 'C', 'D', 'E', 'B', 'A', 'A', 'A', 'A', 'A']"
entry = {"answer": "AABCDAEEEEEEEBCDEBAAAAA"}
assert dataset.score_answer(answer, entry) == 0.5

View file

@ -116,3 +116,35 @@ def test_word_sorting_dataset_iteration():
# Test multiple iterations yield same items # Test multiple iterations yield same items
assert items == list(dataset) assert items == list(dataset)
def test_word_sorting_scoring():
    """Test scoring function"""
    dataset = WordSortingDataset(WordSortingConfig(size=1, seed=42))
    item = {
        "metadata": {
            "sorted_words": ["apple", "banana", "cherry"],
        }
    }
    # Exactly right -> full credit
    assert dataset.score_answer("apple, banana, cherry", item) == 1.0
    # Spacing differences are forgiven
    assert dataset.score_answer("apple,banana, cherry", item) == 1.0
    # Right words, wrong order -> small partial credit
    assert dataset.score_answer("banana, cherry, apple", item) == 0.2
    # Unrelated text -> floor score
    assert dataset.score_answer("gibberish", item) == 0.01
    # No answer -> zero
    assert dataset.score_answer(None, item) == 0.0