mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-30 17:40:45 +00:00
Merge branch 'rich/decimalmath' of github.com:open-thought/reasoning-gym into rich/decimalmath
This commit is contained in:
commit
0cd2eb50d7
62 changed files with 4012 additions and 478 deletions
45
.github/workflows/generate-gallery.yml
vendored
45
.github/workflows/generate-gallery.yml
vendored
|
|
@ -1,45 +0,0 @@
|
||||||
name: Update GALLERY.md
|
|
||||||
|
|
||||||
on:
|
|
||||||
pull_request:
|
|
||||||
branches:
|
|
||||||
- main
|
|
||||||
types: [closed] # Trigger only when the PR is closed
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
contents: write
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
update-gallery:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
if: github.event.pull_request.merged == true # Ensure it was merged, not just closed
|
|
||||||
steps:
|
|
||||||
- name: Check out repository code
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 0
|
|
||||||
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
|
||||||
python-version: "3.11"
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
|
||||||
pip install -e .
|
|
||||||
|
|
||||||
- name: Run gallery script
|
|
||||||
run: |
|
|
||||||
python scripts/generate_gallery.py
|
|
||||||
|
|
||||||
- name: Commit and push changes
|
|
||||||
run: |
|
|
||||||
git config user.name 'github-actions[bot]'
|
|
||||||
git config user.email 'github-actions[bot]@users.noreply.github.com'
|
|
||||||
git add GALLERY.md
|
|
||||||
if [ -n "$(git status --porcelain)" ]; then
|
|
||||||
git commit -m "Update GALLERY.md [skip ci]"
|
|
||||||
git push
|
|
||||||
else
|
|
||||||
echo "No changes to commit."
|
|
||||||
fi
|
|
||||||
1057
GALLERY.md
1057
GALLERY.md
File diff suppressed because it is too large
Load diff
|
|
@ -1,19 +1,32 @@
|
||||||
### env setup
|
### env setup
|
||||||
|
|
||||||
```
|
```
|
||||||
conda create --name verl python=3.12 -y
|
conda create --name verl python=3.11 -y
|
||||||
conda activate verl
|
conda activate verl
|
||||||
|
|
||||||
pip install flash-attn --no-build-isolation
|
pip install flash-attn --no-build-isolation
|
||||||
pip install vllm==0.7.0 ray wandb
|
pip install ray wandb
|
||||||
|
# pip3 install vllm==0.7.0
|
||||||
|
pip3 install vllm --pre --extra-index-url https://wheels.vllm.ai/nightly
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Regarding vllm>0.7 see: [docs](https://verl.readthedocs.io/en/latest/README_vllm0.7.html)
|
||||||
|
|
||||||
|
|
||||||
### clone and install veRL
|
### clone and install veRL
|
||||||
|
|
||||||
tested with verl HEAD a65c9157bc0b85b64cd753de19f94e80a11bd871
|
tested with verl HEAD 0dfcb7f99e299940e1792a386df13c7591df351a
|
||||||
|
|
||||||
```
|
```
|
||||||
git clone https://github.com/volcengine/verl.git
|
git clone https://github.com/volcengine/verl.git
|
||||||
cd verl
|
cd verl
|
||||||
pip install -e .
|
pip install -e .
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
Optionally log in to huggingface hub and wandb with your keys:
|
||||||
|
|
||||||
|
```
|
||||||
|
huggingface-cli login
|
||||||
|
wandb login
|
||||||
|
```
|
||||||
|
|
|
||||||
171
examples/veRL/config/grpo_trainer.yaml
Normal file
171
examples/veRL/config/grpo_trainer.yaml
Normal file
|
|
@ -0,0 +1,171 @@
|
||||||
|
data:
|
||||||
|
tokenizer: null
|
||||||
|
train_files: ~/data/rlhf/gsm8k/train.parquet
|
||||||
|
val_files: ~/data/rlhf/gsm8k/test.parquet
|
||||||
|
prompt_key: prompt
|
||||||
|
max_prompt_length: 512
|
||||||
|
max_response_length: 512
|
||||||
|
train_batch_size: 1024
|
||||||
|
val_batch_size: 1312
|
||||||
|
return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
|
||||||
|
return_raw_chat: False
|
||||||
|
|
||||||
|
actor_rollout_ref:
|
||||||
|
hybrid_engine: True
|
||||||
|
model:
|
||||||
|
path: ~/models/deepseek-llm-7b-chat
|
||||||
|
external_lib: null
|
||||||
|
override_config: { }
|
||||||
|
enable_gradient_checkpointing: True
|
||||||
|
use_remove_padding: False
|
||||||
|
actor:
|
||||||
|
strategy: fsdp # This is for backward-compatibility
|
||||||
|
ppo_mini_batch_size: 256
|
||||||
|
ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
|
||||||
|
ppo_micro_batch_size_per_gpu: null
|
||||||
|
use_dynamic_bsz: False
|
||||||
|
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
|
||||||
|
grad_clip: 1.0
|
||||||
|
clip_ratio: 0.2
|
||||||
|
entropy_coeff: 0.001
|
||||||
|
use_kl_loss: True # True for GRPO
|
||||||
|
kl_loss_coef: 0.001 # for grpo
|
||||||
|
kl_loss_type: low_var_kl # for grpo
|
||||||
|
ppo_epochs: 1
|
||||||
|
shuffle: False
|
||||||
|
ulysses_sequence_parallel_size: 1 # sp size
|
||||||
|
optim:
|
||||||
|
lr: 1e-6
|
||||||
|
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
|
||||||
|
min_lr_ratio: null # only useful for warmup with cosine
|
||||||
|
warmup_style: constant # select from constant/cosine
|
||||||
|
total_training_steps: -1 # must be override by program
|
||||||
|
fsdp_config:
|
||||||
|
wrap_policy:
|
||||||
|
# transformer_layer_cls_to_wrap: None
|
||||||
|
min_num_params: 0
|
||||||
|
param_offload: False
|
||||||
|
optimizer_offload: False
|
||||||
|
fsdp_size: -1
|
||||||
|
ref:
|
||||||
|
fsdp_config:
|
||||||
|
param_offload: False
|
||||||
|
wrap_policy:
|
||||||
|
# transformer_layer_cls_to_wrap: None
|
||||||
|
min_num_params: 0
|
||||||
|
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
|
||||||
|
log_prob_micro_batch_size_per_gpu: null
|
||||||
|
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
|
||||||
|
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
|
||||||
|
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
|
||||||
|
rollout:
|
||||||
|
name: vllm
|
||||||
|
temperature: 1.0
|
||||||
|
top_k: -1 # 0 for hf rollout, -1 for vllm rollout
|
||||||
|
top_p: 1
|
||||||
|
prompt_length: ${data.max_prompt_length} # not use for opensource
|
||||||
|
response_length: ${data.max_response_length}
|
||||||
|
# for vllm rollout
|
||||||
|
dtype: bfloat16 # should align with FSDP
|
||||||
|
gpu_memory_utilization: 0.5
|
||||||
|
ignore_eos: False
|
||||||
|
enforce_eager: True
|
||||||
|
free_cache_engine: True
|
||||||
|
load_format: dummy_dtensor
|
||||||
|
tensor_model_parallel_size: 2
|
||||||
|
max_num_batched_tokens: 8192
|
||||||
|
max_num_seqs: 1024
|
||||||
|
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
|
||||||
|
log_prob_micro_batch_size_per_gpu: null
|
||||||
|
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
|
||||||
|
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
|
||||||
|
disable_log_stats: True
|
||||||
|
enable_chunked_prefill: True # could get higher throughput
|
||||||
|
# for hf rollout
|
||||||
|
do_sample: True
|
||||||
|
# number of responses (i.e. num sample times)
|
||||||
|
n: 16 # > 1 for grpo
|
||||||
|
|
||||||
|
critic:
|
||||||
|
strategy: fsdp
|
||||||
|
optim:
|
||||||
|
lr: 1e-5
|
||||||
|
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
|
||||||
|
min_lr_ratio: null # only useful for warmup with cosine
|
||||||
|
warmup_style: constant # select from constant/cosine
|
||||||
|
total_training_steps: -1 # must be override by program
|
||||||
|
model:
|
||||||
|
path: ~/models/deepseek-llm-7b-chat
|
||||||
|
tokenizer_path: ${actor_rollout_ref.model.path}
|
||||||
|
override_config: { }
|
||||||
|
external_lib: ${actor_rollout_ref.model.external_lib}
|
||||||
|
enable_gradient_checkpointing: True
|
||||||
|
use_remove_padding: False
|
||||||
|
fsdp_config:
|
||||||
|
param_offload: False
|
||||||
|
optimizer_offload: False
|
||||||
|
wrap_policy:
|
||||||
|
# transformer_layer_cls_to_wrap: None
|
||||||
|
min_num_params: 0
|
||||||
|
fsdp_size: -1
|
||||||
|
ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
|
||||||
|
ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
|
||||||
|
ppo_micro_batch_size_per_gpu: null
|
||||||
|
forward_micro_batch_size: ${critic.ppo_micro_batch_size}
|
||||||
|
forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
|
||||||
|
use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
|
||||||
|
ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
|
||||||
|
forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
|
||||||
|
ulysses_sequence_parallel_size: 1 # sp size
|
||||||
|
ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
|
||||||
|
shuffle: ${actor_rollout_ref.actor.shuffle}
|
||||||
|
grad_clip: 1.0
|
||||||
|
cliprange_value: 0.5
|
||||||
|
|
||||||
|
reward_model:
|
||||||
|
enable: False
|
||||||
|
strategy: fsdp
|
||||||
|
model:
|
||||||
|
input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
|
||||||
|
path: ~/models/FsfairX-LLaMA3-RM-v0.1
|
||||||
|
external_lib: ${actor_rollout_ref.model.external_lib}
|
||||||
|
use_remove_padding: False
|
||||||
|
fsdp_config:
|
||||||
|
min_num_params: 0
|
||||||
|
param_offload: False
|
||||||
|
fsdp_size: -1
|
||||||
|
micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
|
||||||
|
micro_batch_size_per_gpu: null # set a number
|
||||||
|
max_length: null
|
||||||
|
ulysses_sequence_parallel_size: 1 # sp size
|
||||||
|
use_dynamic_bsz: ${critic.use_dynamic_bsz}
|
||||||
|
forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
|
||||||
|
|
||||||
|
algorithm:
|
||||||
|
gamma: 1.0
|
||||||
|
lam: 1.0
|
||||||
|
adv_estimator: gae
|
||||||
|
kl_penalty: kl # how to estimate kl divergence
|
||||||
|
kl_ctrl:
|
||||||
|
type: fixed
|
||||||
|
kl_coef: 0.001
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
total_epochs: 30
|
||||||
|
total_training_steps: null
|
||||||
|
project_name: verl_examples
|
||||||
|
experiment_name: gsm8k
|
||||||
|
logger: [ 'console', 'wandb' ]
|
||||||
|
val_generations_to_log_to_wandb: 0
|
||||||
|
nnodes: 1
|
||||||
|
n_gpus_per_node: 8
|
||||||
|
save_freq: -1
|
||||||
|
# auto: find the last ckpt to resume. If can't find, start from scratch
|
||||||
|
resume_mode: auto # or auto or resume_path if
|
||||||
|
resume_from_path: False
|
||||||
|
test_freq: -1
|
||||||
|
critic_warmup: 0
|
||||||
|
default_hdfs_dir: null
|
||||||
|
remove_previous_ckpt_in_save: False
|
||||||
|
del_local_ckpt_after_load: False
|
||||||
|
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
|
||||||
|
|
@ -45,7 +45,6 @@ actor_rollout_ref:
|
||||||
# transformer_layer_cls_to_wrap: None
|
# transformer_layer_cls_to_wrap: None
|
||||||
min_num_params: 0
|
min_num_params: 0
|
||||||
param_offload: False
|
param_offload: False
|
||||||
grad_offload: False
|
|
||||||
optimizer_offload: False
|
optimizer_offload: False
|
||||||
fsdp_size: -1
|
fsdp_size: -1
|
||||||
ref:
|
ref:
|
||||||
|
|
@ -104,7 +103,6 @@ critic:
|
||||||
use_remove_padding: False
|
use_remove_padding: False
|
||||||
fsdp_config:
|
fsdp_config:
|
||||||
param_offload: False
|
param_offload: False
|
||||||
grad_offload: False
|
|
||||||
optimizer_offload: False
|
optimizer_offload: False
|
||||||
wrap_policy:
|
wrap_policy:
|
||||||
# transformer_layer_cls_to_wrap: None
|
# transformer_layer_cls_to_wrap: None
|
||||||
|
|
@ -158,10 +156,16 @@ trainer:
|
||||||
project_name: verl_examples
|
project_name: verl_examples
|
||||||
experiment_name: gsm8k
|
experiment_name: gsm8k
|
||||||
logger: [ 'console', 'wandb' ]
|
logger: [ 'console', 'wandb' ]
|
||||||
|
val_generations_to_log_to_wandb: 0
|
||||||
nnodes: 1
|
nnodes: 1
|
||||||
n_gpus_per_node: 8
|
n_gpus_per_node: 8
|
||||||
save_freq: -1
|
save_freq: -1
|
||||||
|
# auto: find the last ckpt to resume. If can't find, start from scratch
|
||||||
|
resume_mode: auto # or auto or resume_path if
|
||||||
|
resume_from_path: False
|
||||||
test_freq: -1
|
test_freq: -1
|
||||||
critic_warmup: 0
|
critic_warmup: 0
|
||||||
default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
|
default_hdfs_dir: null
|
||||||
|
remove_previous_ckpt_in_save: False
|
||||||
|
del_local_ckpt_after_load: False
|
||||||
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
|
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
|
||||||
|
|
|
||||||
|
|
@ -6,4 +6,4 @@ export ROLLOUT_TP_SIZE=2
|
||||||
export EXPERIMENT_NAME=chain_sum_llama
|
export EXPERIMENT_NAME=chain_sum_llama
|
||||||
export VLLM_ATTENTION_BACKEND=XFORMERS
|
export VLLM_ATTENTION_BACKEND=XFORMERS
|
||||||
|
|
||||||
bash ./train.sh
|
bash ./train_grpo.sh
|
||||||
|
|
|
||||||
39
examples/veRL/train_grpo.sh
Normal file
39
examples/veRL/train_grpo.sh
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
#!/bin/bash
|
||||||
|
set -x
|
||||||
|
|
||||||
|
python3 -u main_ppo_custom_reward.py \
|
||||||
|
algorithm.adv_estimator=grpo \
|
||||||
|
data.train_files=$DATA_DIR/train.parquet \
|
||||||
|
data.val_files=$DATA_DIR/test.parquet \
|
||||||
|
data.train_batch_size=1024 \
|
||||||
|
data.val_batch_size=1312 \
|
||||||
|
data.max_prompt_length=512 \
|
||||||
|
data.max_response_length=1024 \
|
||||||
|
actor_rollout_ref.model.path=$BASE_MODEL \
|
||||||
|
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||||
|
actor_rollout_ref.model.use_remove_padding=True \
|
||||||
|
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
|
||||||
|
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=80 \
|
||||||
|
actor_rollout_ref.actor.use_kl_loss=True \
|
||||||
|
actor_rollout_ref.actor.kl_loss_coef=0.001 \
|
||||||
|
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
|
||||||
|
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||||
|
actor_rollout_ref.actor.fsdp_config.param_offload=False \
|
||||||
|
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
|
||||||
|
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=160 \
|
||||||
|
actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \
|
||||||
|
actor_rollout_ref.rollout.name=vllm \
|
||||||
|
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
|
||||||
|
actor_rollout_ref.rollout.n=8 \
|
||||||
|
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=160 \
|
||||||
|
actor_rollout_ref.ref.fsdp_config.param_offload=True \
|
||||||
|
algorithm.kl_ctrl.kl_coef=0.001 \
|
||||||
|
trainer.critic_warmup=0 \
|
||||||
|
trainer.logger=['console'] \
|
||||||
|
trainer.project_name='verl_chain_sum_grpo' \
|
||||||
|
trainer.experiment_name=$EXPERIMENT_NAME \
|
||||||
|
trainer.n_gpus_per_node=$N_GPUS \
|
||||||
|
trainer.nnodes=1 \
|
||||||
|
trainer.save_freq=100 \
|
||||||
|
trainer.test_freq=100 \
|
||||||
|
trainer.total_epochs=15 $@ 2>&1 | tee verl_output.log
|
||||||
|
|
@ -25,6 +25,6 @@ trainer.n_gpus_per_node=$N_GPUS \
|
||||||
trainer.nnodes=1 \
|
trainer.nnodes=1 \
|
||||||
trainer.save_freq=100 \
|
trainer.save_freq=100 \
|
||||||
trainer.test_freq=100 \
|
trainer.test_freq=100 \
|
||||||
trainer.project_name=verl_chain_sum \
|
trainer.project_name='verl_chain_sum_ppo' \
|
||||||
trainer.experiment_name=$EXPERIMENT_NAME \
|
trainer.experiment_name=$EXPERIMENT_NAME \
|
||||||
trainer.total_epochs=15 2>&1 | tee verl_output.log
|
trainer.total_epochs=15 2>&1 | tee verl_output.log
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
import random
|
import random
|
||||||
import string
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import sympy as sp
|
import sympy as sp
|
||||||
|
from sympy.polys.monomials import itermonomials
|
||||||
|
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
@ -18,11 +19,13 @@ class PolynomialMultiplicationConfig:
|
||||||
max_terms: int = 4 # Maximum number of polynomial terms
|
max_terms: int = 4 # Maximum number of polynomial terms
|
||||||
min_value: int = 1 # Minimum value for coefficients
|
min_value: int = 1 # Minimum value for coefficients
|
||||||
max_value: int = 100 # Maximum value for coefficients
|
max_value: int = 100 # Maximum value for coefficients
|
||||||
min_degree: int = 1 # Minimum polynomial degree
|
min_degree: int = 0 # Minimum polynomial degree
|
||||||
max_degree: int = 3 # Maximum polynomial degree
|
max_degree: int = 3 # Maximum polynomial degree
|
||||||
min_polynomials: int = 2 # Minimum number of polynomials being multiplied
|
min_polynomials: int = 2 # Minimum number of polynomials being multiplied
|
||||||
max_polynomials: int = 3 # Maximum number of polynomials being multiplied
|
max_polynomials: int = 3 # Maximum number of polynomials being multiplied
|
||||||
single_variable: bool = True
|
variables: Tuple[str] = ("x", "y", "z") # Tuple of variable names, that will be chosen randomly
|
||||||
|
allow_cross_variable_product: bool = False # Generate tasks like "Multiply (x^2+3x-1)*(y^2-5)"
|
||||||
|
allow_multivariate_polynomials: bool = False # Generate multivariate tasks like "Multiply (2x^2 + 3y)*(5x^2+3x-1)"
|
||||||
operators: Tuple[str, ...] = (
|
operators: Tuple[str, ...] = (
|
||||||
"+",
|
"+",
|
||||||
"-",
|
"-",
|
||||||
|
|
@ -38,12 +41,17 @@ class PolynomialMultiplicationConfig:
|
||||||
assert self.min_value > 0, "min_value must be positive."
|
assert self.min_value > 0, "min_value must be positive."
|
||||||
assert self.max_value >= self.min_value, "max_value must be >= min_value."
|
assert self.max_value >= self.min_value, "max_value must be >= min_value."
|
||||||
|
|
||||||
assert self.min_degree >= 1, "min_degree must be >= 1."
|
assert self.min_degree >= 0, "min_degree must be >= 0."
|
||||||
assert self.max_degree >= self.min_degree, "max_degree must be >= min_degree."
|
assert self.max_degree >= self.min_degree, "max_degree must be >= min_degree."
|
||||||
|
|
||||||
assert self.min_polynomials >= 2, "min_polynomials must be >= 2."
|
assert self.min_polynomials >= 2, "min_polynomials must be >= 2."
|
||||||
assert self.max_polynomials >= self.min_polynomials, "max_polynomials must be >= min_polynomials."
|
assert self.max_polynomials >= self.min_polynomials, "max_polynomials must be >= min_polynomials."
|
||||||
|
|
||||||
|
assert len(self.variables) > 0, "The variable tuple is empty."
|
||||||
|
assert not (
|
||||||
|
self.allow_multivariate_polynomials and not self.allow_cross_variable_product
|
||||||
|
), "Multivariate polynomials require cross product."
|
||||||
|
|
||||||
allowed_ops = {"+", "-"}
|
allowed_ops = {"+", "-"}
|
||||||
assert len(self.operators) > 0, "operators tuple cannot be empty."
|
assert len(self.operators) > 0, "operators tuple cannot be empty."
|
||||||
assert all(op in allowed_ops for op in self.operators), "Invalid operator found. Must be a subset of {+, -}."
|
assert all(op in allowed_ops for op in self.operators), "Invalid operator found. Must be a subset of {+, -}."
|
||||||
|
|
@ -76,13 +84,24 @@ In addition, When doing calculation, Use the following instructions together wit
|
||||||
A dict with:
|
A dict with:
|
||||||
- question: str (e.g. "Multiply polynomials: (8x^3 + x + 2)*(x - 3)")
|
- question: str (e.g. "Multiply polynomials: (8x^3 + x + 2)*(x - 3)")
|
||||||
- answer: str (Product, e.g. "8x^4 - 24x^3 + x^2 - x - 6")
|
- answer: str (Product, e.g. "8x^4 - 24x^3 + x^2 - x - 6")
|
||||||
- metadata: dict with details (polynomial_expr, single_variable)
|
- metadata: dict with details (polynomial_expr, result, variables)
|
||||||
"""
|
"""
|
||||||
rng = random.Random(self.seed + idx)
|
|
||||||
number_polynomials = rng.randint(self.config.min_polynomials, self.config.max_polynomials)
|
|
||||||
polynomials = [self._generate_polynomial_expr(rng) for i in range(number_polynomials)]
|
|
||||||
|
|
||||||
polynomial_expr = sp.prod(polynomials)
|
rng = random.Random(self.seed + idx)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Three Monomial States:
|
||||||
|
- allow_multivariate_polynomials == 1: list of multivariate monomials (e.g "xy" --> [x, y, xy, x**2, y**2])
|
||||||
|
- allow_cross_variable_product == 1: None. Will generate a unique list of single variable monomials for each term
|
||||||
|
- allow_cross_variable_product == 0: A shared list of monomials for each term (e.g "x" --> [x, x**2, 1])
|
||||||
|
"""
|
||||||
|
monomials = self._get_monomials(rng) if self.config.allow_cross_variable_product else None
|
||||||
|
monomials = None if self.config.allow_cross_variable_product else self._get_monomials(rng)
|
||||||
|
|
||||||
|
number_polynomials = rng.randint(self.config.min_polynomials, self.config.max_polynomials)
|
||||||
|
|
||||||
|
polynomial_terms = [self._generate_polynomial(rng, monomials) for _ in range(number_polynomials)]
|
||||||
|
polynomial_expr = sp.prod(polynomial_terms)
|
||||||
product = sp.expand(polynomial_expr)
|
product = sp.expand(polynomial_expr)
|
||||||
question = rng.choice(self._prompt_templates).format(polynomial_expr=polynomial_expr) + self.added_instruction
|
question = rng.choice(self._prompt_templates).format(polynomial_expr=polynomial_expr) + self.added_instruction
|
||||||
|
|
||||||
|
|
@ -91,54 +110,39 @@ In addition, When doing calculation, Use the following instructions together wit
|
||||||
"answer": product,
|
"answer": product,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"polynomial_expr": str(polynomial_expr),
|
"polynomial_expr": str(polynomial_expr),
|
||||||
"single_variable": self.config.single_variable,
|
|
||||||
"result": str(product),
|
"result": str(product),
|
||||||
|
"variables": list(product.free_symbols),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def _get_variable(self, rng: random.Random) -> str:
|
def _get_monomials(self, rng: random.Random) -> str:
|
||||||
"""Get a random lowercase variable name"""
|
"""Get a list of monomials"""
|
||||||
if self.config.single_variable:
|
if self.config.allow_multivariate_polynomials:
|
||||||
return "x"
|
sym = sp.symbols(self.config.variables)
|
||||||
return rng.choice(string.ascii_lowercase)
|
else:
|
||||||
|
sym = [sp.symbols(rng.choice(self.config.variables))]
|
||||||
def _generate_polynomial_expr(self, rng: random.Random):
|
monomials = list(itermonomials(sym, self.config.max_degree, self.config.min_degree))
|
||||||
"""
|
return monomials
|
||||||
Randomly generate a polynomial expression of 'degree'.
|
|
||||||
We'll use the config parameters:
|
|
||||||
- min_terms, max_terms: how many total terms to combine
|
|
||||||
- min_value, max_value: range for coefficients
|
|
||||||
- operators: to decide sign flips or direct addition
|
|
||||||
|
|
||||||
Args:
|
|
||||||
rng: Random number generator
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Polynomial string
|
|
||||||
"""
|
|
||||||
variable = self._get_variable(rng)
|
|
||||||
degree = rng.randint(self.config.min_degree, self.config.max_degree)
|
|
||||||
|
|
||||||
x = sp.Symbol(variable)
|
|
||||||
|
|
||||||
|
def _generate_polynomial(self, rng: random.Random, monomials: Optional[list]):
|
||||||
|
"""Generates a random polynomial, returns expression."""
|
||||||
# Choose the number of terms and their respective degrees
|
# Choose the number of terms and their respective degrees
|
||||||
|
monomials = monomials if monomials else self._get_monomials(rng)
|
||||||
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
|
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
|
||||||
# Keep track of exponents, exponents can repeat or skip but we force the highest exponent
|
|
||||||
chosen_exponents = [degree]
|
|
||||||
# Fill the rest randomly in [0, degree]
|
|
||||||
for _ in range(num_terms - 1):
|
|
||||||
exp = rng.randint(0, degree)
|
|
||||||
chosen_exponents.append(exp)
|
|
||||||
|
|
||||||
# Now build the polynomial expression: sum_{term}( coeff * x^exponent ), with optional sign
|
|
||||||
polynomial_expr = 0
|
polynomial_expr = 0
|
||||||
for exp in chosen_exponents:
|
for _ in range(num_terms):
|
||||||
|
# Pick a nonzero random coefficient between min_value and max_value.
|
||||||
coeff = rng.randint(self.config.min_value, self.config.max_value)
|
coeff = rng.randint(self.config.min_value, self.config.max_value)
|
||||||
|
|
||||||
|
# Pick a random monomial
|
||||||
|
var = rng.choice(monomials)
|
||||||
|
|
||||||
# If '-' in operators, we can randomly flip the sign
|
# If '-' in operators, we can randomly flip the sign
|
||||||
if "-" in self.config.operators and rng.random() < 0.5:
|
if "-" in self.config.operators and rng.random() < 0.5:
|
||||||
coeff = -coeff
|
coeff = -coeff
|
||||||
term_expr = coeff * (x**exp)
|
|
||||||
polynomial_expr += term_expr
|
polynomial_expr += coeff * var
|
||||||
|
|
||||||
return polynomial_expr
|
return polynomial_expr
|
||||||
|
|
||||||
|
|
@ -151,7 +155,7 @@ In addition, When doing calculation, Use the following instructions together wit
|
||||||
target_poly = sp.parse_expr(metadata["result"])
|
target_poly = sp.parse_expr(metadata["result"])
|
||||||
|
|
||||||
# Check if the difference simplifies to zero (i.e. they are equivalent).
|
# Check if the difference simplifies to zero (i.e. they are equivalent).
|
||||||
if sp.simplify(predicted_poly - target_poly) == 0:
|
if predicted_poly == target_poly:
|
||||||
reward = 1.0
|
reward = 1.0
|
||||||
elif answer.strip():
|
elif answer.strip():
|
||||||
reward = 0.05
|
reward = 0.05
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ from .base_conversion import BaseConversionConfig, BaseConversionDataset
|
||||||
from .binary_matrix import BinaryMatrixConfig, BinaryMatrixDataset
|
from .binary_matrix import BinaryMatrixConfig, BinaryMatrixDataset
|
||||||
from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset
|
from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset
|
||||||
from .count_primes import CountPrimesConfig, CountPrimesDataset
|
from .count_primes import CountPrimesConfig, CountPrimesDataset
|
||||||
|
from .cryptarithm import CryptarithmConfig, CryptarithmDataset
|
||||||
from .game_of_life import GameOfLifeConfig, GameOfLifeDataset
|
from .game_of_life import GameOfLifeConfig, GameOfLifeDataset
|
||||||
from .graph_color import GraphColorConfig, GraphColorDataset
|
from .graph_color import GraphColorConfig, GraphColorDataset
|
||||||
from .group_anagrams import GroupAnagramsConfig, GroupAnagramsDataset
|
from .group_anagrams import GroupAnagramsConfig, GroupAnagramsDataset
|
||||||
|
|
@ -21,6 +22,7 @@ from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixDataset
|
||||||
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
|
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
|
||||||
from .number_sorting import NumberSortingConfig, NumberSortingDataset
|
from .number_sorting import NumberSortingConfig, NumberSortingDataset
|
||||||
from .palindrome_generation import PalindromeConfig, PalindromeDataset
|
from .palindrome_generation import PalindromeConfig, PalindromeDataset
|
||||||
|
from .palindrome_partitioning import PalindromePartitioningConfig, PalindromePartitioningDataset
|
||||||
from .pool_matrix import PoolMatrixConfig, PoolMatrixDataset
|
from .pool_matrix import PoolMatrixConfig, PoolMatrixDataset
|
||||||
from .ransom_note import RansomNoteConfig, RansomNoteDataset
|
from .ransom_note import RansomNoteConfig, RansomNoteDataset
|
||||||
from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset
|
from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset
|
||||||
|
|
@ -42,6 +44,8 @@ __all__ = [
|
||||||
"BaseConversionDataset",
|
"BaseConversionDataset",
|
||||||
"CaesarCipherConfig",
|
"CaesarCipherConfig",
|
||||||
"CaesarCipherDataset",
|
"CaesarCipherDataset",
|
||||||
|
"CryptarithmConfig",
|
||||||
|
"CryptarithmDataset",
|
||||||
"GameOfLifeConfig",
|
"GameOfLifeConfig",
|
||||||
"GameOfLifeDataset",
|
"GameOfLifeDataset",
|
||||||
"LetterCountingConfig",
|
"LetterCountingConfig",
|
||||||
|
|
@ -65,6 +69,8 @@ __all__ = [
|
||||||
"PalindromeDataset",
|
"PalindromeDataset",
|
||||||
"GroupAnagramsConfig",
|
"GroupAnagramsConfig",
|
||||||
"GroupAnagramsDataset",
|
"GroupAnagramsDataset",
|
||||||
|
"PalindromePartitioningConfig",
|
||||||
|
"PalindromePartitioningDataset",
|
||||||
"SpiralMatrixConfig",
|
"SpiralMatrixConfig",
|
||||||
"SpiralMatrixDataset",
|
"SpiralMatrixDataset",
|
||||||
"RansomNoteConfig",
|
"RansomNoteConfig",
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,26 @@ from typing import Optional, Tuple
|
||||||
|
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
QUESTION_TEMPLATE = """Your task is to convert a number between two different bases.
|
||||||
|
|
||||||
|
If the target base is > 10, use lowercase letters a-z for digits above 9.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
- Input: Convert the base-9 number 440 to base-5
|
||||||
|
- Output: 2420
|
||||||
|
- Explanation
|
||||||
|
- First, we convert the base-9 number 440 to base-10: 4 * 9**2 + 4 * 9**1 + 0 * 9**0 = 324 + 36 + 0 = 360
|
||||||
|
- Next, we convert the base-10 number 360 to base-5:
|
||||||
|
- 360 // 5 = 72 remainder 0
|
||||||
|
- 72 // 5 = 14 remainder 2
|
||||||
|
- 14 // 5 = 2 remainder 4
|
||||||
|
- 2 // 5 = 0 remainder 2
|
||||||
|
- Reading the remainders in reverse order gives us the base-5 number 2 4 2 0
|
||||||
|
- Hence, the final answer is 2420
|
||||||
|
|
||||||
|
Now, convert the {source_name} number {source_repr} to {target_name}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseConversionConfig:
|
class BaseConversionConfig:
|
||||||
|
|
@ -90,11 +110,10 @@ class BaseConversionDataset(ProceduralDataset):
|
||||||
source_name = self._format_base_name(source_base)
|
source_name = self._format_base_name(source_base)
|
||||||
target_name = self._format_base_name(target_base)
|
target_name = self._format_base_name(target_base)
|
||||||
|
|
||||||
# Add hint for bases > 10 about using lowercase letters
|
|
||||||
hint = " (use lowercase letters a-z for digits above 9)" if target_base > 10 else ""
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"Convert the {source_name} number {source_repr} to {target_name}{hint}",
|
"question": QUESTION_TEMPLATE.format(
|
||||||
|
source_name=source_name, source_repr=source_repr, target_name=target_name
|
||||||
|
),
|
||||||
"answer": target_repr,
|
"answer": target_repr,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"decimal_value": value,
|
"decimal_value": value,
|
||||||
|
|
|
||||||
|
|
@ -7,23 +7,28 @@ https://leetcode.com/problems/01-matrix/description/
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
QUESTION_TEMPLATE = """Given a square matrix, your job is to find the taxicab distance of the nearest 0 for each cell.
|
QUESTION_TEMPLATE = """Given a square matrix, your job is to find the taxicab (Manhattan) distance of the nearest 0 for each cell.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
- Input: Find the distance to the nearest 0 for each cell in the matrix below:
|
||||||
Input: Find the distance to the nearest 0 for each cell in the matrix below:
|
|
||||||
0 0 0
|
0 0 0
|
||||||
0 1 0
|
0 1 0
|
||||||
1 1 1
|
1 1 1
|
||||||
|
- Output:
|
||||||
Output:
|
|
||||||
0 0 0
|
0 0 0
|
||||||
0 1 0
|
0 1 0
|
||||||
1 2 1
|
1 2 1
|
||||||
|
- Explanation
|
||||||
|
- Each cell with a 0 has a distance of 0 to itself.
|
||||||
|
- The cell at (1, 1) has a distance of 1 to the nearest 0 (any of the three 0's at (1, 0), (0, 1), (1, 2)).
|
||||||
|
- The cell at (2, 0) has a distance of 1 to the nearest 0 (the 0 at (1, 0)).
|
||||||
|
- The cell at (2, 1) has a distance of 2 to the nearest 0 (any of the two 0's at (1, 0), (1, 2))
|
||||||
|
- The cell at (2, 2) has a distance of 1 to the nearest 0 (the 0 at (1, 2)).
|
||||||
|
- Hence, the final answer is the matrix is the output shown above, where each cell contains the distance to the nearest 0, in the same format as the input matrix.
|
||||||
|
|
||||||
Find the distance to the nearest 0 for each cell in the matrix below:
|
Find the distance to the nearest 0 for each cell in the matrix below:
|
||||||
{matrix}
|
{matrix}
|
||||||
|
|
@ -34,6 +39,7 @@ Find the distance to the nearest 0 for each cell in the matrix below:
|
||||||
class BinaryMatrixConfig:
|
class BinaryMatrixConfig:
|
||||||
"""Configuration for Binary Matrix dataset generation"""
|
"""Configuration for Binary Matrix dataset generation"""
|
||||||
|
|
||||||
|
min_n: int = 3 # Minimum size of the matrix
|
||||||
max_n: int = 10 # Maximum size of the matrix
|
max_n: int = 10 # Maximum size of the matrix
|
||||||
p_zero: float = 0.25 # Probability of a cell being 0
|
p_zero: float = 0.25 # Probability of a cell being 0
|
||||||
|
|
||||||
|
|
@ -42,7 +48,8 @@ class BinaryMatrixConfig:
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
assert 1 <= self.max_n, "max_n must be at least 1"
|
assert 1 <= self.min_n, "min_n must be at least 1"
|
||||||
|
assert self.min_n <= self.max_n, "min_n must be less than or equal to max_n"
|
||||||
assert 0 < self.p_zero <= 1, "p_zero must be between 0 and 1"
|
assert 0 < self.p_zero <= 1, "p_zero must be between 0 and 1"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -54,7 +61,7 @@ class BinaryMatrixDataset(ProceduralDataset):
|
||||||
|
|
||||||
def _get_binary_matrix(self, rng: Random) -> list[list[int]]:
|
def _get_binary_matrix(self, rng: Random) -> list[list[int]]:
|
||||||
"""Generate a random binary matrix"""
|
"""Generate a random binary matrix"""
|
||||||
n = rng.randint(1, self.config.max_n)
|
n = rng.randint(self.config.min_n, self.config.max_n)
|
||||||
# Ensure at least one 0 in the matrix, so that a solution exists
|
# Ensure at least one 0 in the matrix, so that a solution exists
|
||||||
numbers = [0] + [0 if rng.random() < self.config.p_zero else 1 for _ in range(n**2 - 1)]
|
numbers = [0] + [0 if rng.random() < self.config.p_zero else 1 for _ in range(n**2 - 1)]
|
||||||
rng.shuffle(numbers)
|
rng.shuffle(numbers)
|
||||||
|
|
@ -105,6 +112,22 @@ class BinaryMatrixDataset(ProceduralDataset):
|
||||||
"""Get a string representation of the matrix"""
|
"""Get a string representation of the matrix"""
|
||||||
return "\n".join(" ".join(str(x) for x in row) for row in matrix)
|
return "\n".join(" ".join(str(x) for x in row) for row in matrix)
|
||||||
|
|
||||||
|
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||||
|
"""Overwrite this method in derived classes if a single oracle answer is not available."""
|
||||||
|
oracle_answer = entry["answer"]
|
||||||
|
if answer is not None:
|
||||||
|
if answer == oracle_answer:
|
||||||
|
return 1.0
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
# check if answer is python list of lists
|
||||||
|
answer = self._matrix_to_str(eval(answer))
|
||||||
|
if answer == oracle_answer:
|
||||||
|
return 0.5
|
||||||
|
except Exception as e:
|
||||||
|
return 0.01
|
||||||
|
return 0.0
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
"""Generate a single Binary Matrix question"""
|
"""Generate a single Binary Matrix question"""
|
||||||
rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
|
||||||
206
reasoning_gym/algorithmic/cryptarithm.py
Normal file
206
reasoning_gym/algorithmic/cryptarithm.py
Normal file
|
|
@ -0,0 +1,206 @@
|
||||||
|
"""
|
||||||
|
Cryptarithm puzzle generator (numbers -> letters approach).
|
||||||
|
|
||||||
|
Generates puzzles such that:
|
||||||
|
WORD1
|
||||||
|
+ WORD2
|
||||||
|
[+ WORD3]
|
||||||
|
---------
|
||||||
|
RESULT
|
||||||
|
where each letter corresponds to exactly one digit (0..9).
|
||||||
|
No leading letter can be zero (unless allow_leading_zero=True).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from random import Random
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
EXAMPLE_CASE = """
|
||||||
|
BASE
|
||||||
|
+ BALL
|
||||||
|
------
|
||||||
|
GAMES
|
||||||
|
|
||||||
|
Answer (one possible solution):
|
||||||
|
|
||||||
|
B=7, A=8, S=2, E=9, L=1, G=1, M=0
|
||||||
|
Summation: 7829 + 7811 = 15640 (the puzzle might produce a different arrangement, but the principle is the same)."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CryptarithmConfig:
|
||||||
|
"""Configuration for Cryptarithm dataset generation."""
|
||||||
|
|
||||||
|
min_words: int = 2 # Minimum number of addends
|
||||||
|
max_words: int = 3 # Maximum number of addends
|
||||||
|
allow_leading_zero: bool = False
|
||||||
|
include_example: bool = True
|
||||||
|
seed: Optional[int] = None
|
||||||
|
size: int = 500 # Number of puzzle instances to generate
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
"""Validate configuration parameters."""
|
||||||
|
assert 2 <= self.min_words <= self.max_words, "min_words must be <= max_words, both >= 2."
|
||||||
|
assert self.size > 0, "Dataset size must be positive."
|
||||||
|
|
||||||
|
|
||||||
|
class CryptarithmDataset(ProceduralDataset):
|
||||||
|
"""
|
||||||
|
Generates cryptarithm puzzles by:
|
||||||
|
1) Randomly choosing integers for each "addend" (with no leading zero if not allowed),
|
||||||
|
2) Summing them,
|
||||||
|
3) Mapping distinct digits (0..9) to letters (A..Z),
|
||||||
|
4) Formatting the puzzle text.
|
||||||
|
|
||||||
|
This approach guarantees sum correctness and avoids repeated failures.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: CryptarithmConfig):
|
||||||
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> dict:
|
||||||
|
rng = Random(self.seed + idx)
|
||||||
|
return self._create_single_puzzle(rng)
|
||||||
|
|
||||||
|
def _create_single_puzzle(self, rng: Random) -> dict:
|
||||||
|
"""
|
||||||
|
Creates one puzzle with N addends (2..3) plus a result.
|
||||||
|
Ensures total distinct digits <= 10.
|
||||||
|
"""
|
||||||
|
# 1) Pick how many addends
|
||||||
|
n_words = rng.randint(self.config.min_words, self.config.max_words)
|
||||||
|
|
||||||
|
# 2) For each addend, pick a random length (3..5) and then pick a random integer with that many digits.
|
||||||
|
# If leading zero is disallowed, the first digit is from 1..9.
|
||||||
|
word_lengths = [rng.randint(3, 5) for _ in range(n_words)]
|
||||||
|
words_numbers = []
|
||||||
|
for length in word_lengths:
|
||||||
|
if self.config.allow_leading_zero:
|
||||||
|
# e.g. random integer in [0, 10^length - 1], then zero-pad to length
|
||||||
|
num = rng.randint(0, 10**length - 1)
|
||||||
|
else:
|
||||||
|
# leading digit is from 1..9, rest are from 0..9
|
||||||
|
# e.g. random integer in [10^(length-1), 10^length - 1]
|
||||||
|
num = rng.randint(10 ** (length - 1), 10**length - 1)
|
||||||
|
words_numbers.append(num)
|
||||||
|
|
||||||
|
# 3) Compute the sum
|
||||||
|
total_sum = sum(words_numbers)
|
||||||
|
# The sum can have up to (max_length+1) digits, which is normal in cryptarithms.
|
||||||
|
|
||||||
|
# 4) Gather all digits from the addends and the sum
|
||||||
|
digits_in_use = set()
|
||||||
|
|
||||||
|
def collect_digits(num: int):
|
||||||
|
return set(str(num))
|
||||||
|
|
||||||
|
for wn in words_numbers:
|
||||||
|
digits_in_use.update(collect_digits(wn))
|
||||||
|
digits_in_use.update(collect_digits(total_sum))
|
||||||
|
|
||||||
|
# If we exceed 10 distinct digits, try again (pick new random numbers).
|
||||||
|
# In practice, we can loop until success. But for demonstration, let's do a simple re-pick approach.
|
||||||
|
# We'll do a while loop up to some attempts:
|
||||||
|
if len(digits_in_use) > 10:
|
||||||
|
# Just do a recursion call to pick new numbers, ignoring current picks
|
||||||
|
return self._create_single_puzzle(rng)
|
||||||
|
|
||||||
|
# 5) Map each digit to a letter
|
||||||
|
# If no leading zero is allowed, the leading digit of each addend + result must not map to '0'.
|
||||||
|
# Actually, we are generating real numeric values, so there's no scenario of leading "0" for
|
||||||
|
# the addends we enforced (except if allow_leading_zero is True).
|
||||||
|
# For the puzzle's perspective, we simply create a random assignment from {digits_in_use} -> letters.
|
||||||
|
# Then the solver has to figure it out. They don't see the digits, only letters.
|
||||||
|
|
||||||
|
digits_in_use_list = sorted(list(digits_in_use)) # e.g. ['0', '1', '3', '9']
|
||||||
|
rng.shuffle(digits_in_use_list) # shuffle so mapping is random
|
||||||
|
letters_pool = [chr(i) for i in range(ord("A"), ord("Z") + 1)]
|
||||||
|
rng.shuffle(letters_pool)
|
||||||
|
chosen_letters = letters_pool[: len(digits_in_use_list)]
|
||||||
|
|
||||||
|
# digit -> letter mapping
|
||||||
|
digit_to_letter = {}
|
||||||
|
for d, letter in zip(digits_in_use_list, chosen_letters):
|
||||||
|
digit_to_letter[d] = letter
|
||||||
|
|
||||||
|
# If leading-zero is not allowed, we must ensure that the first digit of each addend and the sum
|
||||||
|
# does not map to the letter that is assigned to digit '0'. If we see a conflict, we can just re-pick
|
||||||
|
# or we can try to swap letters. The simplest is to re-pick for demonstration.
|
||||||
|
if not self.config.allow_leading_zero and "0" in digit_to_letter:
|
||||||
|
zero_letter = digit_to_letter["0"]
|
||||||
|
# Check the first digit of each addend and of the sum
|
||||||
|
for wn in words_numbers:
|
||||||
|
first_digit = str(wn)[0]
|
||||||
|
if digit_to_letter.get(first_digit) == zero_letter:
|
||||||
|
# Conflict => re-generate puzzle
|
||||||
|
return self._create_single_puzzle(rng)
|
||||||
|
sum_first_digit = str(total_sum)[0]
|
||||||
|
if digit_to_letter.get(sum_first_digit) == zero_letter:
|
||||||
|
return self._create_single_puzzle(rng)
|
||||||
|
|
||||||
|
# Now we have a stable digit->letter mapping. Let's create the letter->digit mapping for the answer.
|
||||||
|
letter_to_digit = {v: int(k) for k, v in digit_to_letter.items()}
|
||||||
|
|
||||||
|
# 6) Convert each integer to its letter representation
|
||||||
|
def int_to_letter_str(num: int) -> str:
|
||||||
|
return "".join(digit_to_letter[d] for d in str(num))
|
||||||
|
|
||||||
|
words_letters = [int_to_letter_str(num) for num in words_numbers]
|
||||||
|
result_letters = int_to_letter_str(total_sum)
|
||||||
|
|
||||||
|
# 7) Create the puzzle text
|
||||||
|
# We'll do the typical vertical format, with a plus sign before the last addend, dashes, then result
|
||||||
|
puzzle_lines = []
|
||||||
|
max_width = max(len(w) for w in words_letters + [result_letters])
|
||||||
|
for i, wl in enumerate(words_letters):
|
||||||
|
if i < len(words_letters) - 1:
|
||||||
|
# Right align with spaces, +2 for the " " prefix
|
||||||
|
puzzle_lines.append(f"{wl:>{max_width+2}}")
|
||||||
|
else:
|
||||||
|
# Right align with spaces, +2 for the "+ " prefix
|
||||||
|
puzzle_lines.append(f"+ {wl:>{max_width}}")
|
||||||
|
|
||||||
|
# The line of dashes should match the longest line
|
||||||
|
puzzle_lines.append("-" * (max_width + 2))
|
||||||
|
# Right align the result
|
||||||
|
puzzle_lines.append(f"{result_letters:>{max_width+2}}")
|
||||||
|
|
||||||
|
puzzle_text = "\n".join(puzzle_lines)
|
||||||
|
|
||||||
|
question_str = (
|
||||||
|
"Solve this cryptarithm:\n\n"
|
||||||
|
f"{puzzle_text}\n\n"
|
||||||
|
"Each letter stands for a unique digit (0-9). "
|
||||||
|
+ (
|
||||||
|
"Leading letters may be zero.\n"
|
||||||
|
if self.config.allow_leading_zero
|
||||||
|
else "No leading letter can be zero.\n"
|
||||||
|
)
|
||||||
|
+ "Provide a mapping from letters to digits that satisfies the equation.\n"
|
||||||
|
)
|
||||||
|
if self.config.include_example:
|
||||||
|
question_str += "Here's an example:\n" + EXAMPLE_CASE
|
||||||
|
|
||||||
|
# 8) Create a human-readable answer, e.g. "A=1,B=0,C=9,..."
|
||||||
|
sorted_letter_keys = sorted(letter_to_digit.keys())
|
||||||
|
answer_str = ",".join(f"{letter}={letter_to_digit[letter]}" for letter in sorted_letter_keys)
|
||||||
|
|
||||||
|
# 9) Return the final puzzle item
|
||||||
|
return {
|
||||||
|
"question": question_str,
|
||||||
|
"answer": answer_str,
|
||||||
|
"metadata": {
|
||||||
|
"letters": list(letter_to_digit.keys()),
|
||||||
|
"word_values": words_numbers,
|
||||||
|
"sum_number": total_sum,
|
||||||
|
"words_letters": words_letters,
|
||||||
|
"result_letters": result_letters,
|
||||||
|
"digit_to_letter": digit_to_letter,
|
||||||
|
"letter_to_digit": letter_to_digit,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
register_dataset("cryptarithm", CryptarithmDataset, CryptarithmConfig)
|
||||||
|
|
@ -32,7 +32,7 @@ class GameOfLifeDataset(ProceduralDataset):
|
||||||
|
|
||||||
def __init__(self, config: GameOfLifeConfig):
|
def __init__(self, config: GameOfLifeConfig):
|
||||||
self._prompt_templates = [
|
self._prompt_templates = [
|
||||||
"What will this Game of Life board look like after {simulation_steps} steps of simulation? Reply as array of array representing rows in the grid from top to bottom in JSON format. (An empty 3x3 grid would look like this: [[0,0,0],[0,0,0],[0,0,0]])\n\n{board}."
|
"What will this Game of Life board look like after {simulation_steps} steps of simulation? Reply as array of array representing rows in the grid from top to bottom in JSON format. Let your answer(array of array be on a single line). (An empty 3x3 grid would look like this: [[0,0,0],[0,0,0],[0,0,0]])\n\n{board}."
|
||||||
]
|
]
|
||||||
|
|
||||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
|
|
||||||
|
|
@ -200,7 +200,7 @@ Vertices: {puzzle["vertices"]}
|
||||||
Edges: {edges}
|
Edges: {edges}
|
||||||
Possible colors: {puzzle["color_options"]}
|
Possible colors: {puzzle["color_options"]}
|
||||||
|
|
||||||
Return your solution as a JSON map of verteces to colors. (For example: {{0: 1, 1: 2, 2: 3}})
|
Return your solution as a JSON map of vertices to colors. (For example: {{0: 1, 1: 2, 2: 3}})
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -90,7 +90,7 @@ class GroupAnagramsDataset(ProceduralDataset):
|
||||||
|
|
||||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||||
"""Score a single Group Anagrams question"""
|
"""Score a single Group Anagrams question"""
|
||||||
reward = 0
|
reward = 0.0
|
||||||
if answer is not None:
|
if answer is not None:
|
||||||
try:
|
try:
|
||||||
answer = json.loads(answer)
|
answer = json.loads(answer)
|
||||||
|
|
@ -98,11 +98,11 @@ class GroupAnagramsDataset(ProceduralDataset):
|
||||||
answer_str = json.dumps(self._sort_nested_list(answer))
|
answer_str = json.dumps(self._sort_nested_list(answer))
|
||||||
oracle_str = json.dumps(self._sort_nested_list(oracle))
|
oracle_str = json.dumps(self._sort_nested_list(oracle))
|
||||||
if answer_str == oracle_str:
|
if answer_str == oracle_str:
|
||||||
reward = 1
|
reward = 1.0
|
||||||
else:
|
else:
|
||||||
reward = 0.01
|
reward = 0.01
|
||||||
except Exception:
|
except Exception:
|
||||||
reward = 0
|
reward = 0.0
|
||||||
return reward
|
return reward
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,30 @@ from reasoning_gym.data import read_data_file
|
||||||
|
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
QUESTION_TEMPLATE = """Your task is to unsramble words in a sentence.
|
||||||
|
|
||||||
|
For each word in a sentence, the letter may have been randomly shuffled. Your task is to unscramble the words.
|
||||||
|
|
||||||
|
The order of the words in the sentence is preserved. Moreover, the style of the sentence is preserved (i.e. punctuation, capitalization, new lines, etc.).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
- Input: Unscramble these words: raendgmeins yWh nya hilcd anc od hatt
|
||||||
|
- Output: meanderings Why any child can do that
|
||||||
|
- Explanation
|
||||||
|
- We unscramble each of the words independently.
|
||||||
|
- raendgmeins -> meanderings
|
||||||
|
- yWh -> Why
|
||||||
|
- nya -> any
|
||||||
|
- hilcd -> child
|
||||||
|
- anc -> can
|
||||||
|
- od -> do
|
||||||
|
- hatt -> that
|
||||||
|
- The final answer is: meanderings Why any child can do that
|
||||||
|
- Notice that the order of the words is preserved, no new words / symbols (e.g. new lines) are added.
|
||||||
|
|
||||||
|
Now, unscramble these words: {words}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LetterJumbleConfig:
|
class LetterJumbleConfig:
|
||||||
|
|
@ -89,7 +113,7 @@ class LetterJumbleDataset(ProceduralDataset):
|
||||||
scrambled_words = [self._scramble_word(word, corruption_level, rng) for word in selected_words]
|
scrambled_words = [self._scramble_word(word, corruption_level, rng) for word in selected_words]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"Unscramble these words: {' '.join(scrambled_words)}",
|
"question": QUESTION_TEMPLATE.format(words=" ".join(scrambled_words)),
|
||||||
"answer": " ".join(selected_words),
|
"answer": " ".join(selected_words),
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"num_words": num_words,
|
"num_words": num_words,
|
||||||
|
|
@ -112,14 +136,16 @@ class LetterJumbleDataset(ProceduralDataset):
|
||||||
float: The computed score between 0.0 and 1.0.
|
float: The computed score between 0.0 and 1.0.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if answer == None:
|
oracle_answer = entry["answer"].strip()
|
||||||
return 0.0
|
if answer:
|
||||||
|
answer = answer.strip()
|
||||||
s_answer = answer.strip().lower()
|
if answer == oracle_answer:
|
||||||
if not s_answer == entry["answer"].strip().lower():
|
return 1.0
|
||||||
return 0.01
|
elif answer.lower() == oracle_answer.lower():
|
||||||
else:
|
return 0.5
|
||||||
return 1.0
|
else:
|
||||||
|
return 0.01
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig)
|
register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig)
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,23 @@ from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
QUESTION_TEMPALTE = """Your task is, given a list of letters, to form a valid palindrome.
|
||||||
|
|
||||||
|
A palindrome is a phrase that reads the same forwards and backwards.
|
||||||
|
|
||||||
|
If there are multiple possible answers, only respond with one of them. You must use all the letters provided.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
- Input: Form a valid palindrome using the following letters: a, a, b
|
||||||
|
- Output: aba
|
||||||
|
- Explanation:
|
||||||
|
- The phrase aba reads the same forwards and backwards.
|
||||||
|
- The output answer is a valid palindrome using all the letters provided.
|
||||||
|
- The answer is a string, rather than a list of characters.
|
||||||
|
|
||||||
|
Now, form a valid palindrome using the following letters: {letters}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PalindromeConfig:
|
class PalindromeConfig:
|
||||||
|
|
@ -51,16 +68,8 @@ class PalindromeDataset(ProceduralDataset):
|
||||||
letters = self._generate_palindrome_letters(rng, length)
|
letters = self._generate_palindrome_letters(rng, length)
|
||||||
scrambled_letters = rng.sample(letters, len(letters)) # Scramble the order
|
scrambled_letters = rng.sample(letters, len(letters)) # Scramble the order
|
||||||
palindrome = self._assemble_palindrome(letters)
|
palindrome = self._assemble_palindrome(letters)
|
||||||
|
|
||||||
question_str = (
|
|
||||||
"Rearrange these letters to form a palindrome. A palindrome is a word, phrase, or sequence that reads the same forward and backward. If there are multiple answers, only respond with one of them.\n\n"
|
|
||||||
"For example, if the letters are: a, a, b — a valid palindrome is: aba.\n\n"
|
|
||||||
f"Your letters: {', '.join(scrambled_letters)}\n\n"
|
|
||||||
"What palindrome can you form from these letters?"
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": question_str,
|
"question": QUESTION_TEMPALTE.format(letters=", ".join(scrambled_letters)),
|
||||||
"answer": palindrome,
|
"answer": palindrome,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"letters": scrambled_letters,
|
"letters": scrambled_letters,
|
||||||
|
|
|
||||||
144
reasoning_gym/algorithmic/palindrome_partitioning.py
Normal file
144
reasoning_gym/algorithmic/palindrome_partitioning.py
Normal file
|
|
@ -0,0 +1,144 @@
|
||||||
|
"""Given a string, return all possible partitions of the string such that each substring is a palindrome.
|
||||||
|
|
||||||
|
A popular Leetcode problem:
|
||||||
|
https://leetcode.com/problems/palindrome-partitioning/description/
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import string
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from random import Random
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
QUESTION_TEMPLATE = """Given a string, partition it such that every substring is a palindrome.
|
||||||
|
|
||||||
|
A palindrome is a word that reads the same backward as forward.
|
||||||
|
|
||||||
|
You may return all possible palindrome partitioning in any order.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
- Input: Partition the following string into palindromes: aab
|
||||||
|
- Output: [["a","a","b"],["aa","b"]]
|
||||||
|
- Explanation:
|
||||||
|
- One way to partition the string is "a" | "a" | "b", where each substring is a palindrome.
|
||||||
|
- Another way to partition the string is "aa" | "b", where again each substring is a palindrome.
|
||||||
|
- Therefore, the final result is a list of the two palindrome partitions.
|
||||||
|
|
||||||
|
Partition the following string into palindromes: {string}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PalindromePartitioningConfig:
|
||||||
|
"""Configuration for Palindrome Partitioning dataset generation"""
|
||||||
|
|
||||||
|
min_string_len: int = 5
|
||||||
|
max_string_len: int = 15
|
||||||
|
max_substring_palindome_len: int = 5
|
||||||
|
|
||||||
|
size: int = 500 # Virtual dataset size
|
||||||
|
seed: Optional[int] = None
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
"""Validate configuration parameters"""
|
||||||
|
assert 1 <= self.min_string_len, "Minimum string length must be at least 1"
|
||||||
|
assert self.min_string_len <= self.max_string_len, "Minimum string length must be less than or equal to maximum"
|
||||||
|
assert 1 <= self.max_substring_palindome_len, "Maximum substring palindrome length must be at least 1"
|
||||||
|
assert (
|
||||||
|
self.max_substring_palindome_len <= self.max_string_len
|
||||||
|
), "Maximum substring palindrome length must be less than or equal to maximum string length"
|
||||||
|
|
||||||
|
|
||||||
|
class PalindromePartitioningDataset(ProceduralDataset):
|
||||||
|
"""Generates Palindrome Partitioning exercises with configurable difficulty"""
|
||||||
|
|
||||||
|
def __init__(self, config: PalindromePartitioningConfig):
|
||||||
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
|
||||||
|
def _sort_list(self, lst: list[list[str]]) -> list[list[str]]:
|
||||||
|
"""Sort the list of palindrome partitions"""
|
||||||
|
return sorted(lst, key=lambda x: x[0] if x else "")
|
||||||
|
|
||||||
|
def to_set_of_tuples(self, list_of_lists: list[list[str]]) -> set[tuple[str]]:
|
||||||
|
"""Convert a list of lists to a set of tuples"""
|
||||||
|
return {tuple(lst) for lst in list_of_lists}
|
||||||
|
|
||||||
|
def _palindrome_partitioning(self, string: str) -> list[list[str]]:
|
||||||
|
"""Return all possible palindrome partitions of a string"""
|
||||||
|
if not string:
|
||||||
|
return []
|
||||||
|
dp = {}
|
||||||
|
|
||||||
|
def is_palindrome(i, j) -> bool:
|
||||||
|
if i >= j:
|
||||||
|
return True
|
||||||
|
if (i, j) in dp:
|
||||||
|
return dp[(i, j)]
|
||||||
|
dp[(i, j)] = string[i] == string[j] and is_palindrome(i + 1, j - 1)
|
||||||
|
return dp[(i, j)]
|
||||||
|
|
||||||
|
res, temp = [], []
|
||||||
|
|
||||||
|
def _partition(idx) -> None:
|
||||||
|
if idx >= len(string):
|
||||||
|
res.append(temp[:])
|
||||||
|
for i in range(idx, len(string)):
|
||||||
|
if is_palindrome(idx, i):
|
||||||
|
temp.append(string[idx : i + 1])
|
||||||
|
_partition(i + 1)
|
||||||
|
temp.pop()
|
||||||
|
|
||||||
|
_partition(0)
|
||||||
|
return self._sort_list(res)
|
||||||
|
|
||||||
|
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||||
|
"""Score a single Palindrome Partitioning question"""
|
||||||
|
if answer is not None:
|
||||||
|
try:
|
||||||
|
answer = self.to_set_of_tuples(json.loads(answer))
|
||||||
|
oracle = self.to_set_of_tuples(entry["metadata"]["solution"])
|
||||||
|
if answer == oracle:
|
||||||
|
return 1.0
|
||||||
|
return 0.01
|
||||||
|
except Exception:
|
||||||
|
return 0.0
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
def _generate_palindrome_letters(self, rng: Random, length: int) -> list[str]:
|
||||||
|
"""Generate a set of letters that can form a palindrome."""
|
||||||
|
half_length = length // 2
|
||||||
|
letters = rng.choices(string.ascii_lowercase, k=half_length)
|
||||||
|
if length % 2 == 1:
|
||||||
|
middle_letter = rng.choice(string.ascii_lowercase)
|
||||||
|
return letters + [middle_letter] + letters[::-1]
|
||||||
|
return letters + letters[::-1]
|
||||||
|
|
||||||
|
def _get_string(self, rng: Random) -> str:
|
||||||
|
"""Generate a random string"""
|
||||||
|
size = rng.randint(self.config.min_string_len, self.config.max_string_len)
|
||||||
|
output = ""
|
||||||
|
|
||||||
|
while len(output) < size:
|
||||||
|
palindrome_len = rng.randint(1, min(self.config.max_substring_palindome_len, size - len(output)))
|
||||||
|
substring = "".join(self._generate_palindrome_letters(rng, palindrome_len))
|
||||||
|
output += substring
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> dict:
|
||||||
|
"""Generate a single Palindrome Partitioning question"""
|
||||||
|
rng = Random(self.seed + idx)
|
||||||
|
string = self._get_string(rng)
|
||||||
|
answer = self._palindrome_partitioning(string)
|
||||||
|
answer_str = json.dumps(answer)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"question": QUESTION_TEMPLATE.format(string=string),
|
||||||
|
"answer": answer_str,
|
||||||
|
"metadata": {"string": string, "solution": answer},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
register_dataset("palindrome_partitioning", PalindromePartitioningDataset, PalindromePartitioningConfig)
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from ..data import read_data_file
|
from ..data import read_data_file
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
@ -92,5 +92,26 @@ class SentenceReorderingDataset(ProceduralDataset):
|
||||||
"metadata": {"word_count": word_count},
|
"metadata": {"word_count": word_count},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
|
||||||
|
reward = 0.0
|
||||||
|
expected_answer = entry["answer"]
|
||||||
|
if answer is not None:
|
||||||
|
try:
|
||||||
|
if expected_answer == answer:
|
||||||
|
return 1.0
|
||||||
|
goal_words = expected_answer.split()
|
||||||
|
answer_words = answer.split()
|
||||||
|
if len(goal_words) == len(answer_words):
|
||||||
|
credit = [
|
||||||
|
1 if goal_word.lower() == answer_word.lower() else 0
|
||||||
|
for goal_word, answer_word in zip(goal_words, answer_words)
|
||||||
|
]
|
||||||
|
reward = sum(credit) / len(credit)
|
||||||
|
else:
|
||||||
|
reward = 0.05
|
||||||
|
except:
|
||||||
|
reward = 0.01
|
||||||
|
return reward
|
||||||
|
|
||||||
|
|
||||||
register_dataset("sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig)
|
register_dataset("sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig)
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from ..data import read_data_file
|
from ..data import read_data_file
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
@ -49,5 +49,18 @@ class SpellBackwardDataset(ProceduralDataset):
|
||||||
"metadata": {"word": word, "word_len": len(word)},
|
"metadata": {"word": word, "word_len": len(word)},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
|
||||||
|
reward = 0.0
|
||||||
|
expected_answer = entry["answer"]
|
||||||
|
if answer is not None:
|
||||||
|
try:
|
||||||
|
if expected_answer.lower() == answer.lower():
|
||||||
|
reward = 1.0
|
||||||
|
else:
|
||||||
|
reward = 0.05
|
||||||
|
except:
|
||||||
|
reward = 0.01
|
||||||
|
return reward
|
||||||
|
|
||||||
|
|
||||||
register_dataset("spell_backward", SpellBackwardDataset, SpellBackwardConfig)
|
register_dataset("spell_backward", SpellBackwardDataset, SpellBackwardConfig)
|
||||||
|
|
|
||||||
|
|
@ -6,20 +6,25 @@ https://leetcode.com/problems/spiral-matrix/description/
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
QUESTION_TEMPLATE = """Given a matrix, your job is to generate a list of elements in spiral order, starting from the top-left element.
|
QUESTION_TEMPLATE = """Given a matrix, your job is to generate a list of elements in spiral order, starting from the top-left element.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
- Input: For the matrix below, what is the list of elements in spiral order?
|
||||||
Input:
|
|
||||||
1 2 3
|
1 2 3
|
||||||
4 5 6
|
4 5 6
|
||||||
7 8 9
|
7 8 9
|
||||||
|
- Output: 1 2 3 6 9 8 7 4 5
|
||||||
Output: 1 2 3 6 9 8 7 4 5
|
- Explanation:
|
||||||
|
- We start from the top-left element (1) and move right until we reach the end of the row: 1 2 3
|
||||||
|
- Then, we move down until we reach the last column: 1 2 3 6 9
|
||||||
|
- Next, we move left until we reach the first column: 1 2 3 6 9 8 7
|
||||||
|
- Then, we move up until we reach the second row (i.e. one below the previously traversed row): 1 2 3 6 9 8 7 4
|
||||||
|
- Finally, we move right until we reach the second to last column: 1 2 3 6 9 8 7 4 5
|
||||||
|
- The output format is a space-separated list of elements in spiral order (as opposed to a python list)
|
||||||
|
|
||||||
For the matrix below, what is the list of elements in spiral order?
|
For the matrix below, what is the list of elements in spiral order?
|
||||||
{matrix}
|
{matrix}
|
||||||
|
|
@ -37,7 +42,7 @@ class SpiralMatrixConfig:
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validate configuration parameters"""
|
"""Validate configuration parameters"""
|
||||||
assert 1 <= self.max_n, "max_n must be at least 1"
|
assert 2 <= self.max_n, "max_n must be at least 2"
|
||||||
|
|
||||||
|
|
||||||
class SpiralMatrixDataset(ProceduralDataset):
|
class SpiralMatrixDataset(ProceduralDataset):
|
||||||
|
|
@ -48,7 +53,7 @@ class SpiralMatrixDataset(ProceduralDataset):
|
||||||
|
|
||||||
def _get_matrix(self, rng: Random) -> list[list[int]]:
|
def _get_matrix(self, rng: Random) -> list[list[int]]:
|
||||||
"""Generate a random matrix"""
|
"""Generate a random matrix"""
|
||||||
n = rng.randint(1, self.config.max_n)
|
n = rng.randint(2, self.config.max_n)
|
||||||
numbers = [rng.randint(0, 9) for _ in range(n**2)]
|
numbers = [rng.randint(0, 9) for _ in range(n**2)]
|
||||||
rng.shuffle(numbers)
|
rng.shuffle(numbers)
|
||||||
matrix = [numbers[i * n : (i + 1) * n] for i in range(n)]
|
matrix = [numbers[i * n : (i + 1) * n] for i in range(n)]
|
||||||
|
|
@ -111,5 +116,28 @@ class SpiralMatrixDataset(ProceduralDataset):
|
||||||
"metadata": {"matrix": matrix, "solution": answer},
|
"metadata": {"matrix": matrix, "solution": answer},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||||
|
"""Overwrite this method in derived classes if a single oracle answer is not available."""
|
||||||
|
oracle_answer = entry["answer"].strip()
|
||||||
|
|
||||||
|
if answer is not None and len(answer) > 0:
|
||||||
|
answer = answer.strip()
|
||||||
|
|
||||||
|
# Exact match
|
||||||
|
if answer == oracle_answer:
|
||||||
|
return 1.0
|
||||||
|
|
||||||
|
# Try to see if the model's answer is a python list
|
||||||
|
try:
|
||||||
|
answer = " ".join(str(item) for item in eval(answer))
|
||||||
|
if answer == oracle_answer:
|
||||||
|
return 0.5
|
||||||
|
else:
|
||||||
|
return 0.01
|
||||||
|
except Exception as e:
|
||||||
|
return 0.01
|
||||||
|
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
register_dataset("spiral_matrix", SpiralMatrixDataset, SpiralMatrixConfig)
|
register_dataset("spiral_matrix", SpiralMatrixDataset, SpiralMatrixConfig)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ https://github.com/yongchao98/CodeSteer-v1.0/blob/main/create_dataset/create_dat
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
@ -26,6 +26,7 @@ Example
|
||||||
- First, we insert A after ABCD.
|
- First, we insert A after ABCD.
|
||||||
- Even though with the newly inserted 'A' we can obtain the substring BCD[A], we can't use it to insert another character.
|
- Even though with the newly inserted 'A' we can obtain the substring BCD[A], we can't use it to insert another character.
|
||||||
- Lastly, we insert D after DEAB.
|
- Lastly, we insert D after DEAB.
|
||||||
|
- Therefore, the final answer is DDABCDAEEDEABD (represented as a string, instead of a list of characters).
|
||||||
|
|
||||||
Given the following string, provide the answer after inserting the characters according to the pattern: {string}
|
Given the following string, provide the answer after inserting the characters according to the pattern: {string}
|
||||||
"""
|
"""
|
||||||
|
|
@ -79,12 +80,28 @@ class StringInsertionDataset(ProceduralDataset):
|
||||||
i += 1
|
i += 1
|
||||||
return "".join(output)
|
return "".join(output)
|
||||||
|
|
||||||
|
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||||
|
"""Overwrite this method in derived classes if a single oracle answer is not available."""
|
||||||
|
oracle_answer = entry["answer"]
|
||||||
|
if answer is not None:
|
||||||
|
if answer == oracle_answer:
|
||||||
|
return 1.0
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
# check if answer is python list of characters
|
||||||
|
answer = "".join(eval(answer))
|
||||||
|
if answer == oracle_answer:
|
||||||
|
return 0.5
|
||||||
|
except Exception as e:
|
||||||
|
return 0.01
|
||||||
|
return 0.0
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
"""Generate a single String Insertion question"""
|
"""Generate a single String Insertion question"""
|
||||||
rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
string_length = rng.randint(self.config.min_string_length, self.config.max_string_length)
|
string_length = rng.randint(self.config.min_string_length, self.config.max_string_length)
|
||||||
string = [rng.choice(self.vocabulary) for _ in range(string_length)]
|
string = "".join(rng.choice(self.vocabulary) for _ in range(string_length))
|
||||||
|
|
||||||
answer = self._get_answer(string)
|
answer = self._get_answer(string)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,10 @@ from typing import Dict, List, Optional, Set, Tuple
|
||||||
from ..data import get_data_file_path
|
from ..data import get_data_file_path
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
QUESTION_TEMPLATE = """Transform the word ladder '{start}' to '{end}' by changing one letter at a time.
|
||||||
|
Provide your answer as a comma-separated sequence of uppercase letters without spaces.
|
||||||
|
Each step must be a valid English word."""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class WordLadderConfig:
|
class WordLadderConfig:
|
||||||
|
|
@ -211,7 +215,7 @@ class WordLadderDataset(ProceduralDataset):
|
||||||
raise IndexError(f"Dataset exhausted at index {idx}. {str(e)}")
|
raise IndexError(f"Dataset exhausted at index {idx}. {str(e)}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"Transform the word ladder '{start}' to '{end}' by changing one letter at a time.",
|
"question": QUESTION_TEMPLATE.format(start=start, end=end),
|
||||||
"answer": ",".join(path),
|
"answer": ",".join(path),
|
||||||
"metadata": {"start_word": start, "end_word": end, "word_length": length, "chain_length": len(path)},
|
"metadata": {"start_word": start, "end_word": end, "word_length": length, "chain_length": len(path)},
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,23 @@ class TextTransformation(StrEnum):
|
||||||
RANDOMCASE = "randomcase"
|
RANDOMCASE = "randomcase"
|
||||||
|
|
||||||
|
|
||||||
|
QUESTION_TEMPLATE = """Your task is to sort words in ascending or descending order using ASCII/Unicode ordering.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
- Input: Sort these words in ascending order (using ASCII/Unicode ordering) and return them as a comma-separated list: freely, idea, indemnify, last, END, solving
|
||||||
|
- Output: END, freely, idea, indemnify, last, solving
|
||||||
|
- Explanation:
|
||||||
|
- Uppercase letters come before lowercase letters, hence why "END" comes first.
|
||||||
|
- "freely" comes before "idea" because "f" comes before "i".
|
||||||
|
- "idea" comes before "indemnify" because even though they both start with "i", "d" comes before "n".
|
||||||
|
- "indemnify" comes before "last" because "i" comes before "l".
|
||||||
|
- "last" comes before "solving" because "l" comes before "s".
|
||||||
|
- Finally, the output is provided as a comma separated list of the sorted words.
|
||||||
|
|
||||||
|
Now, sort these words in {direction} order (using ASCII/Unicode ordering) and return them as a comma-separated list: {words}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class WordSortingConfig:
|
class WordSortingConfig:
|
||||||
"""Configuration for word sorting task generation"""
|
"""Configuration for word sorting task generation"""
|
||||||
|
|
@ -94,7 +111,7 @@ class WordSortingDataset(ProceduralDataset):
|
||||||
answer = asc_words if is_ascending else desc_words
|
answer = asc_words if is_ascending else desc_words
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"Sort these words in {direction} order (using ASCII/Unicode ordering) and return them as a comma-separated list:\n{', '.join(transformed_words)}",
|
"question": QUESTION_TEMPLATE.format(direction=direction, words=", ".join(transformed_words)),
|
||||||
"answer": ", ".join(answer),
|
"answer": ", ".join(answer),
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"original_words": original_words,
|
"original_words": original_words,
|
||||||
|
|
@ -106,26 +123,17 @@ class WordSortingDataset(ProceduralDataset):
|
||||||
}
|
}
|
||||||
|
|
||||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||||
"""Determine if the solution provided solves this task.
|
oracle_answer = entry["metadata"]["sorted_words"]
|
||||||
|
if answer is not None and len(answer) > 0:
|
||||||
|
parsed_answer = [word.strip() for word in re.split(r",\s*", answer)]
|
||||||
|
if parsed_answer == oracle_answer:
|
||||||
|
return 1.0
|
||||||
|
elif sorted(parsed_answer) == oracle_answer:
|
||||||
|
return 0.2
|
||||||
|
else:
|
||||||
|
return 0.01
|
||||||
|
|
||||||
The function awards 1.0 for a correct answer.
|
return 0.0
|
||||||
|
|
||||||
Args:
|
|
||||||
answer (Optional[str]): The user's answer.
|
|
||||||
entry (Dict[str, any]): The original dataset entry containing the correct answer.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
float: The computed score between 0.0 and 1.0.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if answer == None:
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
s_answer = answer.strip().replace(" ", "")
|
|
||||||
if not s_answer == entry["answer"].strip().replace(" ", ""):
|
|
||||||
return 0.01
|
|
||||||
else:
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
|
|
||||||
register_dataset("word_sorting", WordSortingDataset, WordSortingConfig)
|
register_dataset("word_sorting", WordSortingDataset, WordSortingConfig)
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ class ArcAgiConfig:
|
||||||
default_factory=lambda: ["horizontal", "vertical", "diagonal", "counterdiagonal"]
|
default_factory=lambda: ["horizontal", "vertical", "diagonal", "counterdiagonal"]
|
||||||
) # empty list for no mirrors
|
) # empty list for no mirrors
|
||||||
use_color_permutation: bool = True
|
use_color_permutation: bool = True
|
||||||
|
shuffle_example_order: bool = True # whether to shuffle the order of example board pairs for each riddle
|
||||||
|
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500
|
size: int = 500
|
||||||
|
|
@ -87,8 +88,8 @@ def cmap(board: Board, colors: list[int]) -> Board:
|
||||||
return [[colors[c] for c in row] for row in board]
|
return [[colors[c] for c in row] for row in board]
|
||||||
|
|
||||||
|
|
||||||
ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270]
|
# ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270]
|
||||||
MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror]
|
# MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror]
|
||||||
|
|
||||||
|
|
||||||
class ArcAgiDataset(ProceduralDataset):
|
class ArcAgiDataset(ProceduralDataset):
|
||||||
|
|
@ -156,6 +157,9 @@ class ArcAgiDataset(ProceduralDataset):
|
||||||
for p in train:
|
for p in train:
|
||||||
augmented_train.append({"input": augment(p["input"]), "output": augment(p["output"])})
|
augmented_train.append({"input": augment(p["input"]), "output": augment(p["output"])})
|
||||||
|
|
||||||
|
if self.config.shuffle_example_order:
|
||||||
|
rng.shuffle(augmented_train)
|
||||||
|
|
||||||
examples = [
|
examples = [
|
||||||
format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts)
|
format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts)
|
||||||
for i, p in enumerate(augmented_train)
|
for i, p in enumerate(augmented_train)
|
||||||
|
|
|
||||||
|
|
@ -64,6 +64,9 @@ class BasicArithmeticDataset(ProceduralDataset):
|
||||||
|
|
||||||
def __init__(self, config: BasicArithmeticDatasetConfig):
|
def __init__(self, config: BasicArithmeticDatasetConfig):
|
||||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
self.added_instruction = (
|
||||||
|
" Ensure to report the answer as an integer. Do not add commas to the integer answers reported."
|
||||||
|
)
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict[str, Any]:
|
def __getitem__(self, idx: int) -> dict[str, Any]:
|
||||||
"""Generate a single arithmetic task
|
"""Generate a single arithmetic task
|
||||||
|
|
@ -88,7 +91,7 @@ class BasicArithmeticDataset(ProceduralDataset):
|
||||||
else:
|
else:
|
||||||
expression, result = self._generate_simple_task(rng, num_terms, num_digits)
|
expression, result = self._generate_simple_task(rng, num_terms, num_digits)
|
||||||
|
|
||||||
question = self._format_question(rng, expression)
|
question = self._format_question(rng, expression) + self.added_instruction
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": question,
|
"question": question,
|
||||||
|
|
@ -223,12 +226,14 @@ class BasicArithmeticDataset(ProceduralDataset):
|
||||||
return expression, result
|
return expression, result
|
||||||
|
|
||||||
def _format_question(self, rng: Random, expression: str) -> str:
|
def _format_question(self, rng: Random, expression: str) -> str:
|
||||||
"""Format the expression according to config style"""
|
"""Format the the question with the arithmetic expression"""
|
||||||
|
|
||||||
if self.config.format_style == "simple":
|
if self.config.format_style == "simple":
|
||||||
return f"{expression} ="
|
return f"Calculate {expression}."
|
||||||
else:
|
else:
|
||||||
templates = ["What is {0}?", "Calculate {0}", "Solve {0}", "Evaluate the expression: {0}"]
|
templates = ["What is {0}?", "Solve {0}.", "Compute {0}.", "Evaluate: {0}."]
|
||||||
return rng.choice(templates).format(expression)
|
template = rng.choice(templates)
|
||||||
|
return template.format(expression)
|
||||||
|
|
||||||
|
|
||||||
# Register the dataset
|
# Register the dataset
|
||||||
|
|
|
||||||
|
|
@ -63,7 +63,7 @@ class ChainSumDataset(ProceduralDataset):
|
||||||
expression, result = self._generate_task(rng, num_terms, min_value, max_value)
|
expression, result = self._generate_task(rng, num_terms, min_value, max_value)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"{expression} =",
|
"question": f"State the final answer to the following arithmetic problem: {expression} =",
|
||||||
"answer": str(result),
|
"answer": str(result),
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"difficulty": {
|
"difficulty": {
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,15 @@
|
||||||
"""Fraction simplification task generator"""
|
"""Fraction simplification task generator"""
|
||||||
|
|
||||||
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from math import gcd
|
from math import gcd
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Optional, Sequence, Tuple
|
from typing import Any, Dict, Optional, Sequence, Tuple
|
||||||
|
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
QUESTION_TEMPLATE = "Simplify the fraction {question_fraction} to its lowest terms. Give only the simplified fraction as your final answer."
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FractionSimplificationConfig:
|
class FractionSimplificationConfig:
|
||||||
|
|
@ -107,7 +110,7 @@ class FractionSimplificationDataset(ProceduralDataset):
|
||||||
answer_fraction = self._format_fraction(simple_num, simple_den, style)
|
answer_fraction = self._format_fraction(simple_num, simple_den, style)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"Simplify the fraction {question_fraction} to its lowest terms",
|
"question": QUESTION_TEMPLATE.format(question_fraction=question_fraction),
|
||||||
"answer": answer_fraction,
|
"answer": answer_fraction,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"numerator": num,
|
"numerator": num,
|
||||||
|
|
@ -119,5 +122,34 @@ class FractionSimplificationDataset(ProceduralDataset):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _extract_fraction(self, answer: Optional[str]):
|
||||||
|
try:
|
||||||
|
cleaned = answer.strip().strip("$").strip()
|
||||||
|
latex_match = re.match(r"\\(?:frac|dfrac)\s*{\s*(\d+)\s*}\s*{\s*(\d+)\s*}", cleaned, re.IGNORECASE)
|
||||||
|
if latex_match:
|
||||||
|
return int(latex_match.group(1)), int(latex_match.group(2))
|
||||||
|
if "/" in cleaned:
|
||||||
|
numerator, denominator = map(str.strip, cleaned.split("/", 1))
|
||||||
|
return int(numerator), int(denominator)
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]):
|
||||||
|
reward = 0.0
|
||||||
|
metadata = entry["metadata"]
|
||||||
|
try:
|
||||||
|
numerator, denominator = self._extract_fraction(answer)
|
||||||
|
if numerator == metadata["simplified_numerator"] and denominator == metadata["simplified_denominator"]:
|
||||||
|
reward = 1.0
|
||||||
|
elif numerator == metadata["numerator"] or denominator == metadata["denominator"]:
|
||||||
|
reward = 0.1
|
||||||
|
elif len(answer.strip()) > 0:
|
||||||
|
reward = 0.05
|
||||||
|
else:
|
||||||
|
reward = 0.01
|
||||||
|
except:
|
||||||
|
reward = 0.01
|
||||||
|
return reward
|
||||||
|
|
||||||
|
|
||||||
register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig)
|
register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig)
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,7 @@ class GCDDataset(ProceduralDataset):
|
||||||
numbers_str = ", ".join(str(n) for n in numbers)
|
numbers_str = ", ".join(str(n) for n in numbers)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}",
|
"question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}. Give only the GCD as your final answer.",
|
||||||
"answer": str(result),
|
"answer": str(result),
|
||||||
"metadata": {"numbers": numbers, "result": result},
|
"metadata": {"numbers": numbers, "result": result},
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -148,7 +148,9 @@ class GSMSymbolicDataset(ProceduralDataset):
|
||||||
rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
generator_idx = self.task_indices[idx]
|
generator_idx = self.task_indices[idx]
|
||||||
generator = self.generators[generator_idx]
|
generator = self.generators[generator_idx]
|
||||||
return generator(rng, self.config.difficulty)
|
example = generator(rng, self.config.difficulty)
|
||||||
|
example["question"] += " Give only the result as your final answer."
|
||||||
|
return example
|
||||||
|
|
||||||
|
|
||||||
register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig)
|
register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig)
|
||||||
|
|
|
||||||
|
|
@ -54,14 +54,29 @@ ANIMALS = {
|
||||||
"woodlouse": 14,
|
"woodlouse": 14,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
QUESTION_TEMPLATE = """Your task is to count how many legs there are in total when given a list of animals.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
- Input: How many legs are there in total if you have 1 duck, 2 deers, 1 spider, 3 cows?
|
||||||
|
- Output: 30
|
||||||
|
- Explanation:
|
||||||
|
- Ducks have 2 legs each, so 1 duck has 2 legs.
|
||||||
|
- Deers have 4 legs each, so 2 deers have 8 legs.
|
||||||
|
- Spiders have 8 legs each, so 1 spider has 8 legs.
|
||||||
|
- Cows have 4 legs each, so 3 cows have 12 legs.
|
||||||
|
- Therefore, the total number of legs is 2 + 8 + 8 + 12 = 30
|
||||||
|
|
||||||
|
Now, how many legs are there in total if you have {animals}?
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LegCountingConfig:
|
class LegCountingConfig:
|
||||||
"""Configuration for leg counting task generation"""
|
"""Configuration for leg counting task generation"""
|
||||||
|
|
||||||
min_animals: int = 2 # Minimum number of animals in problem
|
min_animals: int = 3 # Minimum number of animals in problem
|
||||||
max_animals: int = 5 # Maximum number of animals
|
max_animals: int = 10 # Maximum number of animals
|
||||||
max_instances: int = 3 # Maximum instances of each animal
|
max_instances: int = 15 # Maximum instances of each animal
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
|
|
@ -106,10 +121,8 @@ class LegCountingDataset(ProceduralDataset):
|
||||||
for animal, count in animals.items():
|
for animal, count in animals.items():
|
||||||
animal_list.append(f"{count} {animal}{'s' if count > 1 else ''}")
|
animal_list.append(f"{count} {animal}{'s' if count > 1 else ''}")
|
||||||
|
|
||||||
question = "How many legs are there in total if you have " + ", ".join(animal_list) + "?"
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": question,
|
"question": QUESTION_TEMPLATE.format(animals=", ".join(animal_list)),
|
||||||
"answer": str(total_legs),
|
"answer": str(total_legs),
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"difficulty": {
|
"difficulty": {
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,24 @@ from typing import Dict, Optional
|
||||||
|
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
QUESTION_TEMPLATE = """Compute {base}^{exponent}"""
|
QUESTION_TEMPLATE = """Your task is to compute an exponentiation of a number.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
- Input: Compute 2^3
|
||||||
|
- Output: 8
|
||||||
|
- Explanation:
|
||||||
|
- 2^3 = 2 * 2 * 2 = 8
|
||||||
|
- Therefore, the final answer is 8
|
||||||
|
|
||||||
|
Example:
|
||||||
|
- Input: Compute 412.5^3
|
||||||
|
- Output: 70189453.125
|
||||||
|
- Explanation:
|
||||||
|
- 412.5^3 = 412.5 * 412.5 * 412.5 = 70189453.125
|
||||||
|
- Therefore, the final answer is 70189453.125
|
||||||
|
|
||||||
|
Compute {base}^{exponent}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -32,28 +49,31 @@ class PowerFunctionDataset(ProceduralDataset):
|
||||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||||
"""Overwrite this method in derived classes if a single oracle answer is not available."""
|
"""Overwrite this method in derived classes if a single oracle answer is not available."""
|
||||||
oracle_answer = entry["answer"]
|
oracle_answer = entry["answer"]
|
||||||
reward = 0.0
|
|
||||||
if answer is not None:
|
if answer is not None:
|
||||||
difference = abs(float(answer) - float(oracle_answer))
|
try:
|
||||||
if difference < 1e-6:
|
answer = round(float(answer), 4)
|
||||||
reward = 1.0
|
oracle_answer = round(float(oracle_answer), 4)
|
||||||
elif difference < 1e-1:
|
difference = abs(float(answer) - float(oracle_answer))
|
||||||
reward = 0.5
|
if difference < 1e-4:
|
||||||
else:
|
return 1.0
|
||||||
reward = 0.01
|
elif difference < 1e-1:
|
||||||
|
return 0.5
|
||||||
return reward
|
else:
|
||||||
|
return 0.01
|
||||||
|
except Exception as e:
|
||||||
|
return 0.01
|
||||||
|
return 0.0
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
"""Generate a single Power Function question"""
|
"""Generate a single Power Function question"""
|
||||||
rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
base = rng.uniform(self.config.min_base, self.config.max_base)
|
base = round(rng.uniform(self.config.min_base, self.config.max_base), 4)
|
||||||
exponent = rng.randint(self.config.min_exponent, self.config.max_exponent)
|
exponent = rng.randint(self.config.min_exponent, self.config.max_exponent)
|
||||||
answer = pow(base, exponent)
|
answer = pow(base, exponent)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"Compute {base}^{exponent}",
|
"question": QUESTION_TEMPLATE.format(base=base, exponent=exponent),
|
||||||
"answer": str(answer),
|
"answer": str(answer),
|
||||||
"metadata": {"base": base, "exponent": exponent, "solution": answer},
|
"metadata": {"base": base, "exponent": exponent, "solution": answer},
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,7 @@ class ProductsDataset(ProceduralDataset):
|
||||||
expression, result = self._generate_task(rng, num_terms, min_value, max_value)
|
expression, result = self._generate_task(rng, num_terms, min_value, max_value)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"{expression} =",
|
"question": f"Solve the following multiplication: {expression}. Give only the result as your final answer.",
|
||||||
"answer": str(result),
|
"answer": str(result),
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"difficulty": {
|
"difficulty": {
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,8 @@ class BFDataset(ProceduralDataset):
|
||||||
|
|
||||||
def __init__(self, config: BFConfig):
|
def __init__(self, config: BFConfig):
|
||||||
self._prompt_templates = [
|
self._prompt_templates = [
|
||||||
"This is a BF (Brainf*ck) computer program. What is the output? Reply only with the program output, ex: 42. \n\n{bf_program}",
|
"This is a BF (Brainf*ck) computer program. What is the output?\n\n{bf_program}\n\nRespond only with the exact output of the program.",
|
||||||
|
"Consider the following BF (Brainf*ck) code. What would it output?\n\n{bf_program}\n\nProvide only the exact output of the code.",
|
||||||
]
|
]
|
||||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
|
||||||
|
|
@ -123,6 +124,13 @@ int main() {{
|
||||||
if answer == None:
|
if answer == None:
|
||||||
return 0.0
|
return 0.0
|
||||||
if answer != entry["answer"]:
|
if answer != entry["answer"]:
|
||||||
|
if entry["answer"] in answer.splitlines():
|
||||||
|
# We can be quite confident that the correct answer was given
|
||||||
|
# It was likely just given alongside an explanation
|
||||||
|
return max(0.9 * len(answer) / len(entry["answer"]), 0.1)
|
||||||
|
if entry["answer"] in answer:
|
||||||
|
# Since answers are English words, some risk of the response coincidentally containing the answer
|
||||||
|
return max(0.5 * len(answer) / len(entry["answer"]), 0.1)
|
||||||
return 0.01
|
return 0.01
|
||||||
else:
|
else:
|
||||||
return 1.0 # Yay
|
return 1.0 # Yay
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,42 @@ from typing import Dict, Optional
|
||||||
|
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
QUESTION_TEMPLATE = """Your task is to count how many rectangles are present in an ASCII grid.
|
||||||
|
|
||||||
|
Single rectangles are outlined with a '#', overlapping rectangles (max 2) are shown with '█'.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
- Input: How many rectangles are in the grid below?
|
||||||
|
|
||||||
|
####
|
||||||
|
# #
|
||||||
|
####
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#########
|
||||||
|
# █##
|
||||||
|
# █ #
|
||||||
|
########█ #
|
||||||
|
# #
|
||||||
|
###
|
||||||
|
- Output: 3
|
||||||
|
- Explanation:
|
||||||
|
- The first rectangle is the 3x4 rectangle in the top right.
|
||||||
|
- The other two rectangles are overlapping in the bottom left corner.
|
||||||
|
- Therefore, the final answer is 3.
|
||||||
|
|
||||||
|
Now, it's your turn. How many rectangles do you see in the grid below?
|
||||||
|
{puzzle}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def draw_rectangles_with_overlap(n, width, height, rng):
|
def draw_rectangles_with_overlap(n, width, height, rng):
|
||||||
# Create a grid that holds a count of how many times a cell is drawn.
|
# Create a grid that holds a count of how many times a cell is drawn.
|
||||||
|
|
@ -103,12 +139,10 @@ class RectangleCountDataset(ProceduralDataset):
|
||||||
target = rng.randint(1, self.config.max_rectangles)
|
target = rng.randint(1, self.config.max_rectangles)
|
||||||
puzzle, answer = draw_rectangles_with_overlap(target, self.config.width, self.config.height, rng)
|
puzzle, answer = draw_rectangles_with_overlap(target, self.config.width, self.config.height, rng)
|
||||||
|
|
||||||
puzz = f"How many rectangles do you see? Single rectangles are outlined with a '#', overlapping rectangles (max 2) are shown with '█'. \n\n {puzzle}"
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": puzz,
|
"question": QUESTION_TEMPLATE.format(puzzle=puzzle),
|
||||||
"answer": str(answer),
|
"answer": str(answer),
|
||||||
"metadata": {},
|
"metadata": {"puzzle": puzzle, "solution": answer},
|
||||||
}
|
}
|
||||||
|
|
||||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||||
|
|
|
||||||
|
|
@ -53,9 +53,10 @@ class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):
|
||||||
|
|
||||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||||
"""Overwrite this method in derived classes if a single oracle answer is not available."""
|
"""Overwrite this method in derived classes if a single oracle answer is not available."""
|
||||||
oracle_answer = entry["answer"]
|
oracle_answer = entry["answer"].strip()
|
||||||
reward = 0.0
|
reward = 0.0
|
||||||
if answer is not None:
|
if answer is not None and len(answer) > 0:
|
||||||
|
answer = answer.strip()
|
||||||
if answer == oracle_answer:
|
if answer == oracle_answer:
|
||||||
reward = 1.0
|
reward = 1.0
|
||||||
elif oracle_answer in answer:
|
elif oracle_answer in answer:
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ Game tasks for training reasoning capabilities:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .countdown import CountdownConfig, CountdownDataset
|
from .countdown import CountdownConfig, CountdownDataset
|
||||||
|
from .futoshiki import FutoshikiConfig, FutoshikiDataset
|
||||||
from .knight_swap import KnightSwapConfig, KnightSwapDataset
|
from .knight_swap import KnightSwapConfig, KnightSwapDataset
|
||||||
from .maze import MazeConfig, MazeDataset
|
from .maze import MazeConfig, MazeDataset
|
||||||
from .mini_sudoku import MiniSudokuConfig, MiniSudokuDataset
|
from .mini_sudoku import MiniSudokuConfig, MiniSudokuDataset
|
||||||
|
|
@ -20,6 +21,8 @@ from .tsumego import TsumegoConfig, TsumegoDataset
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CountdownConfig",
|
"CountdownConfig",
|
||||||
"CountdownDataset",
|
"CountdownDataset",
|
||||||
|
"FutoshikiConfig",
|
||||||
|
"FutoshikiDataset",
|
||||||
"MiniSudokuConfig",
|
"MiniSudokuConfig",
|
||||||
"MiniSudokuDataset",
|
"MiniSudokuDataset",
|
||||||
"SudokuConfig",
|
"SudokuConfig",
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,15 @@ from sympy.parsing.sympy_parser import parse_expr
|
||||||
|
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
QUESTION_FORMAT_TEMPLATE = """{question}
|
||||||
|
Final answer format instructions:
|
||||||
|
1. Provide your solution as a arithmetic expression (no '=' sign).
|
||||||
|
2. Do not include the target number in the expression.
|
||||||
|
3. Use '*' for multiplication.
|
||||||
|
4. Use '/' for division.
|
||||||
|
5. Do not include any other text or formatting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CountdownConfig:
|
class CountdownConfig:
|
||||||
|
|
@ -67,8 +76,11 @@ class CountdownDataset(ProceduralDataset):
|
||||||
|
|
||||||
numbers_str = ", ".join(map(str, numbers))
|
numbers_str = ", ".join(map(str, numbers))
|
||||||
|
|
||||||
|
question = rng.choice(self._prompt_templates)
|
||||||
|
question = question.format(numbers=numbers_str, target=target)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": rng.choice(self._prompt_templates).format(numbers=numbers_str, target=target),
|
"question": QUESTION_FORMAT_TEMPLATE.format(question=question),
|
||||||
"answer": expression,
|
"answer": expression,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"numbers": numbers,
|
"numbers": numbers,
|
||||||
|
|
|
||||||
656
reasoning_gym/games/futoshiki.py
Normal file
656
reasoning_gym/games/futoshiki.py
Normal file
|
|
@ -0,0 +1,656 @@
|
||||||
|
"""Futoshiki puzzle generator"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import itertools
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from random import Random
|
||||||
|
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FutoshikiConfig:
|
||||||
|
"""Configuration for Futoshiki puzzle generation"""
|
||||||
|
|
||||||
|
board_size: int = 4 # Board will be NxN where N is this value
|
||||||
|
difficulty: int = 1 # Possible values: 0, 1, 2, 3
|
||||||
|
seed: Optional[int] = None
|
||||||
|
size: int = 500 # Virtual dataset size
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
"""Validate configuration parameters"""
|
||||||
|
assert 4 <= self.board_size <= 9, "board_size must be between 4 and 9"
|
||||||
|
assert 0 <= self.difficulty <= 3, "difficulty must be between 0 and 3"
|
||||||
|
|
||||||
|
|
||||||
|
class FutoshikiDataset(ProceduralDataset):
|
||||||
|
"""Generates Futoshiki puzzles with configurable board size and difficulty"""
|
||||||
|
|
||||||
|
def __init__(self, config: FutoshikiConfig):
|
||||||
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return self.config.size
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
self._current_idx = 0
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self._current_idx >= self.config.size:
|
||||||
|
raise StopIteration
|
||||||
|
item = self[self._current_idx]
|
||||||
|
self._current_idx += 1
|
||||||
|
return item
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> dict:
|
||||||
|
"""
|
||||||
|
Generate a single Futoshiki puzzle with blanks, represented by 0s, and constraints.
|
||||||
|
Clues are pre-filled numbers in the grid.
|
||||||
|
Constraints are adjacent cell pairs which may have '<' or '>' relations.
|
||||||
|
Difficulty in [0..3] affects number of clues and constraints.
|
||||||
|
"""
|
||||||
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
|
# Generate random "solved" Futoshiki grid
|
||||||
|
solution = self._generate_random_solution(self.config.board_size, rng)
|
||||||
|
# Add random adjacency constraints consistent with generated solved grid
|
||||||
|
constraints = self._generate_random_constraints(solution, self.config.difficulty, rng)
|
||||||
|
# Starting with full solution, remove clues to desired difficulty
|
||||||
|
puzzle = self._remove_clues(copy.deepcopy(solution), constraints, rng)
|
||||||
|
|
||||||
|
# Format as strings
|
||||||
|
puzzle_str = self._puzzle_to_string(puzzle, constraints)
|
||||||
|
solution_str = self._puzzle_to_string(solution, constraints)
|
||||||
|
|
||||||
|
question = (
|
||||||
|
f"Solve the following {self.config.board_size}x{self.config.board_size} Futoshiki puzzle:\n\n"
|
||||||
|
f"{puzzle_str}\n\n"
|
||||||
|
"Ensure your answer follows the same format as the puzzle above, just replace blanks (_) with the correct value for the cell.\n"
|
||||||
|
"Use < and > for horizontal constraints. Use \u2227 and \u2228 for vertical constraints.\n"
|
||||||
|
f"Remember, in Futoshiki each row and column must contain each number from 1 to {self.config.board_size} exactly once."
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"question": question,
|
||||||
|
"answer": solution_str,
|
||||||
|
"metadata": {
|
||||||
|
"puzzle": puzzle,
|
||||||
|
"constraints": constraints,
|
||||||
|
"solution": solution,
|
||||||
|
"board_size": self.config.board_size,
|
||||||
|
"difficulty": self.config.difficulty,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def _puzzle_to_string(
|
||||||
|
self, puzzle_grid: List[List[int]], constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str]
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Formats a Futoshiki puzzle grid as a string with constraints.
|
||||||
|
Constraints are represented as '<', '>', '\u2227', or '\u2228' between adjacent cells.
|
||||||
|
"""
|
||||||
|
n = len(puzzle_grid)
|
||||||
|
|
||||||
|
def cell_str(val: int) -> str:
|
||||||
|
return str(val) if val != 0 else "_"
|
||||||
|
|
||||||
|
# Helper to look up constraints between two adjacent cells
|
||||||
|
# Ensures the first tuple is always the “lesser” in row-major order
|
||||||
|
# If order is reversed in the dict, invert the constraint
|
||||||
|
def get_constraint(r1, c1, r2, c2) -> Optional[str]:
|
||||||
|
if (r1, c1) == (r2, c2):
|
||||||
|
return None
|
||||||
|
if (r1, c1) < (r2, c2):
|
||||||
|
key = ((r1, c1), (r2, c2))
|
||||||
|
sign = constraints.get(key)
|
||||||
|
if sign == ">": # first is bigger
|
||||||
|
if r1 == r2: # horizontal
|
||||||
|
return ">"
|
||||||
|
else: # vertical
|
||||||
|
return "\u2228"
|
||||||
|
elif sign == "<": # first is smaller
|
||||||
|
if r1 == r2: # horizontal
|
||||||
|
return "<"
|
||||||
|
else:
|
||||||
|
return "\u2227"
|
||||||
|
else:
|
||||||
|
# reversed order in the dictionary -> invert the sign
|
||||||
|
key = ((r2, c2), (r1, c1))
|
||||||
|
sign = constraints.get(key)
|
||||||
|
if sign == ">":
|
||||||
|
if r1 == r2:
|
||||||
|
return "<"
|
||||||
|
else:
|
||||||
|
return "\u2227"
|
||||||
|
elif sign == "<":
|
||||||
|
if r1 == r2:
|
||||||
|
return ">"
|
||||||
|
else:
|
||||||
|
return "\u2228"
|
||||||
|
return None
|
||||||
|
|
||||||
|
lines = []
|
||||||
|
|
||||||
|
for r in range(n):
|
||||||
|
# Build the row string with horizontal constraints
|
||||||
|
row_cells = []
|
||||||
|
for c in range(n):
|
||||||
|
row_cells.append(cell_str(puzzle_grid[r][c]))
|
||||||
|
if c < n - 1:
|
||||||
|
hc = get_constraint(r, c, r, c + 1)
|
||||||
|
row_cells.append(hc if hc else " ")
|
||||||
|
lines.append(" ".join(row_cells))
|
||||||
|
|
||||||
|
# If not the last row, build the line of vertical constraints
|
||||||
|
if r < n - 1:
|
||||||
|
vert_cells = []
|
||||||
|
for c in range(n):
|
||||||
|
vc = get_constraint(r, c, r + 1, c)
|
||||||
|
if vc:
|
||||||
|
vert_cells.append(vc)
|
||||||
|
else:
|
||||||
|
vert_cells.append(" ")
|
||||||
|
# Space out columns so vertical symbols line up under the correct spot
|
||||||
|
if c < n - 1:
|
||||||
|
vert_cells.append(" ")
|
||||||
|
lines.append(" ".join(vert_cells))
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def _solve_logical(
|
||||||
|
self,
|
||||||
|
grid: List[List[int]],
|
||||||
|
constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str],
|
||||||
|
) -> Tuple[List[List[int]], List[List[Set[int]]]]:
|
||||||
|
"""
|
||||||
|
Apply logical rules to progress solution.
|
||||||
|
Returns current state if logical rules can't progress further.
|
||||||
|
Logical rules are implemented based on the descriptions here: https://futoshiki.uk/
|
||||||
|
"""
|
||||||
|
size, working_grid = len(grid), copy.deepcopy(grid)
|
||||||
|
|
||||||
|
# Starting point all numbers are candidates for all unfilled squares
|
||||||
|
candidates: List[List[Set[int]]] = [
|
||||||
|
[set(range(1, len(grid) + 1)) if grid[r][c] == 0 else {grid[r][c]} for c in range(len(grid))]
|
||||||
|
for r in range(len(grid))
|
||||||
|
]
|
||||||
|
|
||||||
|
# Any cells > another cannot be 1, and any cells < another cannot be `size`
|
||||||
|
# This is separated from the repeated function below to avoid redundant checks
|
||||||
|
for ((r1, c1), (_, _)), rel in constraints.items():
|
||||||
|
if rel == ">":
|
||||||
|
candidates[r1][c1].discard(1)
|
||||||
|
elif rel == "<":
|
||||||
|
candidates[r1][c1].discard(size)
|
||||||
|
|
||||||
|
def _update_grid():
|
||||||
|
"""Update solution wherever a cell's candidates set is reduced to a single element."""
|
||||||
|
for r in range(len(working_grid)):
|
||||||
|
for c in range(len(working_grid)):
|
||||||
|
if working_grid[r][c] == 0 and len(candidates[r][c]) == 1:
|
||||||
|
working_grid[r][c] = next(iter(candidates[r][c]))
|
||||||
|
|
||||||
|
def _try_solve_logical() -> bool:
|
||||||
|
progress = False
|
||||||
|
|
||||||
|
# Eliminate candidates based on numbers already placed
|
||||||
|
# If a number is placed in a cell, it cannot be a candidate for any other cell in the same row or column
|
||||||
|
for r in range(len(working_grid)):
|
||||||
|
for c in range(len(working_grid)):
|
||||||
|
if working_grid[r][c] == 0:
|
||||||
|
continue
|
||||||
|
for cc in range(len(working_grid)):
|
||||||
|
if cc != c and working_grid[r][c] in candidates[r][cc]:
|
||||||
|
candidates[r][cc].discard(working_grid[r][c])
|
||||||
|
progress = True
|
||||||
|
for rr in range(len(working_grid)):
|
||||||
|
if rr != r and working_grid[r][c] in candidates[rr][c]:
|
||||||
|
candidates[rr][c].discard(working_grid[r][c])
|
||||||
|
progress = True
|
||||||
|
|
||||||
|
_update_grid()
|
||||||
|
|
||||||
|
# Eliminate candidates based on constraints
|
||||||
|
# Based on currently filled values, eliminate candidates that violate constraints
|
||||||
|
def _eliminate_by_constraint(rc_less: Tuple[int, int], rc_greater: Tuple[int, int]) -> bool:
|
||||||
|
r_less, c_less = rc_less
|
||||||
|
r_greater, c_greater = rc_greater
|
||||||
|
progress = False
|
||||||
|
|
||||||
|
if working_grid[r_less][c_less] != 0:
|
||||||
|
# greater must only have candidates > less
|
||||||
|
for v in candidates[r_greater][c_greater].copy():
|
||||||
|
if v <= working_grid[r_less][c_less] and v in candidates[r_greater][c_greater]:
|
||||||
|
candidates[r_greater][c_greater].discard(v)
|
||||||
|
progress = True
|
||||||
|
|
||||||
|
if working_grid[r_greater][c_greater] != 0:
|
||||||
|
# less must only have candidates < greater
|
||||||
|
for v in candidates[r_less][c_less].copy():
|
||||||
|
if v >= working_grid[r_greater][c_greater] and v in candidates[r_less][c_less]:
|
||||||
|
candidates[r_less][c_less].discard(v)
|
||||||
|
progress = True
|
||||||
|
|
||||||
|
return progress
|
||||||
|
|
||||||
|
for ((r1, c1), (r2, c2)), rel in constraints.items():
|
||||||
|
v1, v2 = working_grid[r1][c1], working_grid[r2][c2]
|
||||||
|
if v1 != 0 and v2 != 0: # both already filled, skip
|
||||||
|
continue
|
||||||
|
if rel == "<":
|
||||||
|
progress |= _eliminate_by_constraint((r1, c1), (r2, c2))
|
||||||
|
elif rel == ">":
|
||||||
|
progress |= _eliminate_by_constraint((r2, c2), (r1, c1))
|
||||||
|
|
||||||
|
_update_grid()
|
||||||
|
|
||||||
|
# Seek "hidden singles" - cells where a candidate is unique in the row or column
|
||||||
|
for r in range(len(working_grid)):
|
||||||
|
for c in range(len(working_grid)):
|
||||||
|
if working_grid[r][c] != 0:
|
||||||
|
continue
|
||||||
|
for v in candidates[r][c]:
|
||||||
|
if sum(v in candidates[r][cc] for cc in range(len(working_grid))) == 1:
|
||||||
|
candidates[r][c] = {v} # candidate unique in row
|
||||||
|
break
|
||||||
|
if sum(v in candidates[rr][c] for rr in range(len(working_grid))) == 1:
|
||||||
|
candidates[r][c] = {v} # candidate unique in column
|
||||||
|
break
|
||||||
|
|
||||||
|
_update_grid()
|
||||||
|
|
||||||
|
# Seek "naked pairs" if same pair of candidates twice in a row or col, with nothing else in those two cells
|
||||||
|
# Remove them from other cells in row/col
|
||||||
|
for r in range(len(working_grid)):
|
||||||
|
for c in range(len(working_grid)):
|
||||||
|
if working_grid[r][c] != 0 or len(candidates[r][c]) != 2:
|
||||||
|
continue
|
||||||
|
for cc in range(len(working_grid)):
|
||||||
|
if cc != c and candidates[r][cc] == candidates[r][c]:
|
||||||
|
for ccc in range(len(working_grid)):
|
||||||
|
if ccc != c and ccc != cc and candidates[r][c].intersection(candidates[r][ccc]):
|
||||||
|
candidates[r][ccc] -= candidates[r][c]
|
||||||
|
progress = True
|
||||||
|
for rr in range(len(working_grid)):
|
||||||
|
if rr != r and candidates[rr][c] == candidates[r][c]:
|
||||||
|
for rrr in range(len(working_grid)):
|
||||||
|
if rrr != r and rrr != rr and candidates[r][c].intersection(candidates[rrr][c]):
|
||||||
|
candidates[rrr][c] -= candidates[r][c]
|
||||||
|
progress = True
|
||||||
|
|
||||||
|
_update_grid()
|
||||||
|
|
||||||
|
# Seek "hidden pairs" - same pair of candidates occurs in two cells in a line, but nowhere else in the line
|
||||||
|
# alongside other candidates (hence hidden). All other candidates can be removed from those two cells
|
||||||
|
for r in range(len(working_grid)):
|
||||||
|
for c in range(len(working_grid)):
|
||||||
|
if working_grid[r][c] != 0:
|
||||||
|
continue
|
||||||
|
for cc in range(c + 1, len(working_grid)):
|
||||||
|
if working_grid[r][cc] != 0:
|
||||||
|
continue
|
||||||
|
# Check if r, c shares a candidate pair with r, cc (maybe subset, not exact candidate set match)
|
||||||
|
r_c_pairs = itertools.permutations(candidates[r][c], 2)
|
||||||
|
r_cc_pairs = itertools.permutations(candidates[r][cc], 2)
|
||||||
|
for pair in r_c_pairs:
|
||||||
|
if pair not in r_cc_pairs:
|
||||||
|
continue
|
||||||
|
otherwise_unique = True
|
||||||
|
# If this pair occurs elsewhere in the row, skip
|
||||||
|
for ccc in range(len(working_grid)):
|
||||||
|
if ccc in (c, cc):
|
||||||
|
continue
|
||||||
|
if pair in itertools.permutations(candidates[r][ccc], 2):
|
||||||
|
otherwise_unique = False
|
||||||
|
break
|
||||||
|
if not otherwise_unique:
|
||||||
|
continue
|
||||||
|
# Found a hidden pair, remove all other candidates from these two cells
|
||||||
|
candidates[r][c] = set(pair)
|
||||||
|
candidates[r][cc] = set(pair)
|
||||||
|
|
||||||
|
for rr in range(r + 1, len(working_grid)):
|
||||||
|
if working_grid[rr][c] != 0:
|
||||||
|
continue
|
||||||
|
# Check if r, c shares a candidate pair with rr, c (maybe subset, not exact candidate set match)
|
||||||
|
r_c_pairs = itertools.permutations(candidates[r][c], 2)
|
||||||
|
rr_c_pairs = itertools.permutations(candidates[rr][c], 2)
|
||||||
|
for pair in r_c_pairs:
|
||||||
|
if pair not in rr_c_pairs:
|
||||||
|
continue
|
||||||
|
otherwise_unique = True
|
||||||
|
# If this pair occurs elsewhere in the col, skip
|
||||||
|
for rrr in range(len(working_grid)):
|
||||||
|
if rrr in (r, rr):
|
||||||
|
continue
|
||||||
|
if pair in itertools.permutations(candidates[rrr][c], 2):
|
||||||
|
otherwise_unique = False
|
||||||
|
break
|
||||||
|
if not otherwise_unique:
|
||||||
|
continue
|
||||||
|
# Found a hidden pair, remove all other candidates from these two cells
|
||||||
|
candidates[r][c] = set(pair)
|
||||||
|
candidates[rr][c] = set(pair)
|
||||||
|
|
||||||
|
_update_grid()
|
||||||
|
|
||||||
|
# Seek X-wings by rows
|
||||||
|
for v in range(1, size + 1):
|
||||||
|
# If candidate is in the same 2 positions in 2 rows, and nowhere else in those rows
|
||||||
|
# Delete from the 2 intersecting cols
|
||||||
|
|
||||||
|
# Find rows which have exactly 2 instances of the value in their candidate sets
|
||||||
|
rows_with_v = [r for r in range(size) if sum(v in candidates[r][c] for c in range(size)) == 2]
|
||||||
|
if len(rows_with_v) < 2:
|
||||||
|
continue
|
||||||
|
# Check whether the 2 columns with the value are the same in the 2 rows
|
||||||
|
cols_with_v_per_row = [set() for _ in range(len(rows_with_v))]
|
||||||
|
for i, r in enumerate(rows_with_v):
|
||||||
|
for c in range(size):
|
||||||
|
if v in candidates[r][c]:
|
||||||
|
cols_with_v_per_row[i].add(c)
|
||||||
|
# Check if there are a pair of tows with the same cols (there may be more than 2 rows)
|
||||||
|
for i in range(len(rows_with_v)):
|
||||||
|
for j in range(i + 1, len(rows_with_v)):
|
||||||
|
if cols_with_v_per_row[i] == cols_with_v_per_row[j]:
|
||||||
|
# Found an X-wing, remove candidate from the 2 cols
|
||||||
|
for c in cols_with_v_per_row[i]:
|
||||||
|
for rr in range(size):
|
||||||
|
if rr not in (rows_with_v[i], rows_with_v[j]) and v in candidates[rr][c]:
|
||||||
|
candidates[rr][c].discard(v)
|
||||||
|
progress = True
|
||||||
|
|
||||||
|
# Seek X-wings by cols
|
||||||
|
for v in range(1, size + 1):
|
||||||
|
# If candidate is in the same 2 positions in 2 cols, and nowhere else in those cols
|
||||||
|
# Delete from the 2 intersecting rows
|
||||||
|
|
||||||
|
# Find cols which have exactly 2 instances of the value in their candidate sets
|
||||||
|
cols_with_v = [c for c in range(size) if sum(v in candidates[r][c] for r in range(size)) == 2]
|
||||||
|
if len(cols_with_v) < 2:
|
||||||
|
continue
|
||||||
|
# Check whether the 2 rows with the value are the same in the 2 cols
|
||||||
|
rows_with_v_per_col = [set() for _ in range(len(cols_with_v))]
|
||||||
|
for i, c in enumerate(cols_with_v):
|
||||||
|
for r in range(size):
|
||||||
|
if v in candidates[r][c]:
|
||||||
|
rows_with_v_per_col[i].add(r)
|
||||||
|
# Check if there are a pair of cols with the same rows (there may be more than 2 cols)
|
||||||
|
for i in range(len(cols_with_v)):
|
||||||
|
for j in range(i + 1, len(cols_with_v)):
|
||||||
|
if rows_with_v_per_col[i] == rows_with_v_per_col[j]:
|
||||||
|
# Found an X-wing, remove candidate from the 2 rows
|
||||||
|
for r in rows_with_v_per_col[i]:
|
||||||
|
for cc in range(size):
|
||||||
|
if cc not in (cols_with_v[i], cols_with_v[j]) and v in candidates[r][cc]:
|
||||||
|
candidates[r][cc].discard(v)
|
||||||
|
progress = True
|
||||||
|
|
||||||
|
_update_grid()
|
||||||
|
|
||||||
|
return progress
|
||||||
|
|
||||||
|
while _try_solve_logical():
|
||||||
|
continue
|
||||||
|
|
||||||
|
return working_grid, candidates
|
||||||
|
|
||||||
|
def _solve(
|
||||||
|
self,
|
||||||
|
grid: List[List[int]],
|
||||||
|
constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str],
|
||||||
|
) -> List[List[int]] | None:
|
||||||
|
"""
|
||||||
|
Backtracking Futoshiki solver. Used to verify generated puzzles.
|
||||||
|
Applies logical rules first then backtracks to fill gaps.
|
||||||
|
Return solved grid, or None if unsolvable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
grid, candidates = self._solve_logical(grid, constraints)
|
||||||
|
|
||||||
|
size = len(grid)
|
||||||
|
empty_cell = None
|
||||||
|
|
||||||
|
# Find first empty cell
|
||||||
|
for r in range(size):
|
||||||
|
for c in range(size):
|
||||||
|
if grid[r][c] == 0:
|
||||||
|
empty_cell = (r, c)
|
||||||
|
break
|
||||||
|
if empty_cell:
|
||||||
|
break
|
||||||
|
|
||||||
|
# If no empty cell, solution is complete
|
||||||
|
if not empty_cell:
|
||||||
|
return copy.deepcopy(grid)
|
||||||
|
|
||||||
|
r, c = empty_cell
|
||||||
|
for val in range(1, size + 1):
|
||||||
|
if val not in candidates[r][c]:
|
||||||
|
continue
|
||||||
|
if self._is_valid(grid, r, c, val, constraints):
|
||||||
|
grid[r][c] = val
|
||||||
|
solution = self._solve(grid, constraints)
|
||||||
|
if solution is not None:
|
||||||
|
grid[r][c] = 0
|
||||||
|
return solution
|
||||||
|
grid[r][c] = 0
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _is_valid(
|
||||||
|
self,
|
||||||
|
grid: List[List[int]],
|
||||||
|
row: int,
|
||||||
|
col: int,
|
||||||
|
val: int,
|
||||||
|
constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str],
|
||||||
|
) -> bool:
|
||||||
|
"""Check row, col, and inequality constraints for placing val in grid[row][col]."""
|
||||||
|
size = len(grid)
|
||||||
|
|
||||||
|
# Row or column conflict?
|
||||||
|
if val in grid[row]:
|
||||||
|
return False
|
||||||
|
for r in range(size):
|
||||||
|
if grid[r][col] == val:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Temporarily place the val and check constraints
|
||||||
|
original_val = grid[row][col]
|
||||||
|
grid[row][col] = val
|
||||||
|
|
||||||
|
# Check all constraints involving this cell
|
||||||
|
for ((r1, c1), (r2, c2)), rel in constraints.items():
|
||||||
|
v1 = grid[r1][c1]
|
||||||
|
v2 = grid[r2][c2]
|
||||||
|
# If either is 0, skip
|
||||||
|
if v1 == 0 or v2 == 0:
|
||||||
|
continue
|
||||||
|
# If relation is '<', v1 < v2 must hold
|
||||||
|
if rel == "<":
|
||||||
|
if not (v1 < v2):
|
||||||
|
grid[row][col] = original_val
|
||||||
|
return False
|
||||||
|
elif rel == ">":
|
||||||
|
if not (v1 > v2):
|
||||||
|
grid[row][col] = original_val
|
||||||
|
return False
|
||||||
|
|
||||||
|
grid[row][col] = original_val
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _generate_random_solution(self, size: int, rng: Random) -> List[List[int]]:
|
||||||
|
"""
|
||||||
|
Generates a random valid completed Futoshiki solution with numbers 1..size.
|
||||||
|
Ensures each row and column has unique numbers.
|
||||||
|
"""
|
||||||
|
# Fill row by row with a random permutation, ensuring no column conflicts. Use backtracking
|
||||||
|
grid = [[0] * size for _ in range(size)]
|
||||||
|
|
||||||
|
def backtrack(r):
|
||||||
|
if r == size:
|
||||||
|
return True
|
||||||
|
nums = list(range(1, size + 1))
|
||||||
|
rng.shuffle(nums)
|
||||||
|
for permutation in itertools.permutations(nums):
|
||||||
|
# Place row if columns are valid
|
||||||
|
valid = True
|
||||||
|
for c in range(size):
|
||||||
|
if any(grid[rr][c] == permutation[c] for rr in range(r)):
|
||||||
|
valid = False
|
||||||
|
break
|
||||||
|
if valid:
|
||||||
|
grid[r] = list(permutation)
|
||||||
|
if backtrack(r + 1):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
if backtrack(0):
|
||||||
|
return grid
|
||||||
|
|
||||||
|
raise ValueError("Could not generate a random solution.")
|
||||||
|
|
||||||
|
def _generate_random_constraints(
|
||||||
|
self, solution: List[List[int]], difficulty: int, rng: Random
|
||||||
|
) -> Dict[Tuple[Tuple[int, int], Tuple[int, int]], str]:
|
||||||
|
"""
|
||||||
|
Randomly add inequality constraints that match the solution.
|
||||||
|
We only add constraints for adjacent cells (horizontal or vertical).
|
||||||
|
Probability of adding a constraint can scale with difficulty.
|
||||||
|
"""
|
||||||
|
size = len(solution)
|
||||||
|
constraints = {}
|
||||||
|
# For each pair of adjacent cells, we might add a constraint
|
||||||
|
# P(adding a constraint) increases with difficulty on an arbitrary scale
|
||||||
|
base_prob = 0.03 + 0.07 * difficulty
|
||||||
|
for r in range(size):
|
||||||
|
for c in range(size):
|
||||||
|
# Horizontal neighbor
|
||||||
|
if c < size - 1:
|
||||||
|
if rng.random() < base_prob:
|
||||||
|
if solution[r][c] < solution[r][c + 1]:
|
||||||
|
constraints[((r, c), (r, c + 1))] = "<"
|
||||||
|
else:
|
||||||
|
constraints[((r, c), (r, c + 1))] = ">"
|
||||||
|
# Vertical neighbor
|
||||||
|
if r < size - 1:
|
||||||
|
if rng.random() < base_prob:
|
||||||
|
if solution[r][c] < solution[r + 1][c]:
|
||||||
|
constraints[((r, c), (r + 1, c))] = "<"
|
||||||
|
else:
|
||||||
|
constraints[((r, c), (r + 1, c))] = ">"
|
||||||
|
return constraints
|
||||||
|
|
||||||
|
def count_solutions(self, grid, constraints, limit=2) -> int:
|
||||||
|
size = len(grid)
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
def backtrack():
|
||||||
|
nonlocal count
|
||||||
|
# Early exit if limit reached
|
||||||
|
if count >= limit:
|
||||||
|
return
|
||||||
|
# Find the next empty cell
|
||||||
|
for r in range(size):
|
||||||
|
for c in range(size):
|
||||||
|
if grid[r][c] == 0:
|
||||||
|
for val in range(1, size + 1):
|
||||||
|
if self._is_valid(grid, r, c, val, constraints):
|
||||||
|
grid[r][c] = val
|
||||||
|
backtrack()
|
||||||
|
grid[r][c] = 0
|
||||||
|
return
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
backtrack()
|
||||||
|
return count
|
||||||
|
|
||||||
|
def _remove_clues(
|
||||||
|
self,
|
||||||
|
grid: List[List[int]],
|
||||||
|
constraints: Dict[Tuple[Tuple[int, int], Tuple[int, int]], str],
|
||||||
|
rng: Random,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""
|
||||||
|
Remove clues from a full solution to try to maintain a unique-solution puzzle.
|
||||||
|
We remove in random order until we reach our target, or can't without losing uniqueness.
|
||||||
|
"""
|
||||||
|
size = len(grid)
|
||||||
|
fill_fraction = 0.1
|
||||||
|
target_filled = int(fill_fraction * (size * size))
|
||||||
|
|
||||||
|
coords = [(r, c) for r in range(size) for c in range(size)]
|
||||||
|
rng.shuffle(coords)
|
||||||
|
|
||||||
|
def _count_filled_cells(g):
|
||||||
|
return sum(g[r][c] != 0 for r in range(size) for c in range(size))
|
||||||
|
|
||||||
|
def _try_remove():
|
||||||
|
for r, c in coords:
|
||||||
|
if _count_filled_cells(grid) <= target_filled:
|
||||||
|
break # Removal target hit
|
||||||
|
|
||||||
|
saved = grid[r][c]
|
||||||
|
if saved == 0:
|
||||||
|
continue
|
||||||
|
# Try remove
|
||||||
|
grid[r][c] = 0
|
||||||
|
|
||||||
|
# Check if unsolvable
|
||||||
|
sol = self._solve(copy.deepcopy(grid), constraints)
|
||||||
|
if sol is None:
|
||||||
|
grid[r][c] = saved
|
||||||
|
continue
|
||||||
|
# Check if not unique
|
||||||
|
if self.count_solutions(copy.deepcopy(grid), constraints, limit=2) > 1:
|
||||||
|
grid[r][c] = saved
|
||||||
|
|
||||||
|
_try_remove()
|
||||||
|
|
||||||
|
# Second pass if we aren't close to target
|
||||||
|
if _count_filled_cells(grid) > 2 * target_filled:
|
||||||
|
rng.shuffle(coords)
|
||||||
|
_try_remove()
|
||||||
|
|
||||||
|
return grid
|
||||||
|
|
||||||
|
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||||
|
if not answer:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
oracle_answer = entry["answer"]
|
||||||
|
metadata = entry["metadata"]
|
||||||
|
solution: list[list[int]] = metadata["solution"]
|
||||||
|
board_size: int = len(solution[0])
|
||||||
|
|
||||||
|
# 1. match answer without trailing whitespaces
|
||||||
|
answer_stripped = "\n".join(l.rstrip() for l in answer.split("\n"))
|
||||||
|
oracle_answer_stripped = "\n".join(l.rstrip() for l in oracle_answer.split("\n"))
|
||||||
|
|
||||||
|
if answer_stripped == oracle_answer_stripped:
|
||||||
|
reward = 1.0
|
||||||
|
else:
|
||||||
|
# 2. accept answers with correct numeric sequence (ignoring non-numeric characters)
|
||||||
|
row = 0
|
||||||
|
num_matching = 0
|
||||||
|
for ln in answer.split("\n"):
|
||||||
|
numbers = [int(c) for c in ln if c.isnumeric()]
|
||||||
|
if len(numbers) != len(solution[0]):
|
||||||
|
continue # ignore lines without numbers
|
||||||
|
for a, b in zip(solution[row], numbers):
|
||||||
|
if a == b:
|
||||||
|
num_matching += 1
|
||||||
|
row += 1
|
||||||
|
|
||||||
|
reward = num_matching / (board_size * board_size)
|
||||||
|
reward *= 0.9 # penalty for not using standard format
|
||||||
|
|
||||||
|
if len(answer) > len(oracle_answer):
|
||||||
|
reward *= len(oracle_answer) / len(answer) # penalty for additional length
|
||||||
|
return reward
|
||||||
|
|
||||||
|
|
||||||
|
register_dataset("futoshiki", FutoshikiDataset, FutoshikiConfig)
|
||||||
|
|
@ -95,7 +95,8 @@ class MazeDataset(ProceduralDataset):
|
||||||
+ "\n```"
|
+ "\n```"
|
||||||
+ "\nLegend: "
|
+ "\nLegend: "
|
||||||
+ f"'{self.wall_char}' = Wall, '{self.path_char}' = Passage\n\n"
|
+ f"'{self.wall_char}' = Wall, '{self.path_char}' = Passage\n\n"
|
||||||
+ "What is the minimum number of steps to reach the goal?"
|
+ "What is the minimum number of steps to reach the goal?\n"
|
||||||
|
+ "Give only the number of steps as your final answer, no other text or formatting."
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Tuple
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
@ -141,11 +141,55 @@ class MiniSudokuDataset(ProceduralDataset):
|
||||||
puzzle_str = self._board_to_string(puzzle)
|
puzzle_str = self._board_to_string(puzzle)
|
||||||
solution_str = self._board_to_string(solved_board)
|
solution_str = self._board_to_string(solved_board)
|
||||||
|
|
||||||
|
question = (
|
||||||
|
"In 4x4 Mini Sudoku:\n"
|
||||||
|
"- Each row must contain each number from 1-4 exactly once\n"
|
||||||
|
"- Each column must contain each number 1-4 exactly once\n"
|
||||||
|
"- Each 2x2 subgrid must contain each number 1-4 exactly once\n"
|
||||||
|
f"Solve this 4x4 Mini Sudoku puzzle:\n{puzzle_str}\n"
|
||||||
|
"Format your response as the puzzle above, with spaces separating each number within a row, and newlines separating rows.\n"
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"Solve this 4x4 Mini Sudoku puzzle:\n{puzzle_str}",
|
"question": question,
|
||||||
"answer": solution_str,
|
"answer": solution_str,
|
||||||
"metadata": {"puzzle": puzzle, "solution": solved_board, "num_empty": num_empty},
|
"metadata": {"puzzle": puzzle, "solution": solved_board, "num_empty": num_empty},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||||
|
if not answer:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
oracle_answer = entry["answer"]
|
||||||
|
metadata = entry["metadata"]
|
||||||
|
solution: list[list[int]] = metadata["solution"]
|
||||||
|
board_size: int = len(solution[0])
|
||||||
|
|
||||||
|
# 1. match answer without trailing whitespaces
|
||||||
|
answer_stripped = "\n".join(l.rstrip() for l in answer.split("\n"))
|
||||||
|
oracle_answer_stripped = "\n".join(l.rstrip() for l in oracle_answer.split("\n"))
|
||||||
|
|
||||||
|
if answer_stripped == oracle_answer_stripped:
|
||||||
|
reward = 1.0
|
||||||
|
else:
|
||||||
|
# 2. accept answers with correct numeric sequence (ignoring non-numeric characters)
|
||||||
|
row = 0
|
||||||
|
num_matching = 0
|
||||||
|
for ln in answer.split("\n"):
|
||||||
|
numbers = [int(c) for c in ln if c.isnumeric()]
|
||||||
|
if len(numbers) != board_size:
|
||||||
|
continue # ignore lines without numbers
|
||||||
|
for a, b in zip(solution[row], numbers):
|
||||||
|
if a == b:
|
||||||
|
num_matching += 1
|
||||||
|
row += 1
|
||||||
|
|
||||||
|
reward = num_matching / (board_size * board_size)
|
||||||
|
reward *= 0.9 # penalty for not using standard format
|
||||||
|
|
||||||
|
if len(answer) > len(oracle_answer):
|
||||||
|
reward *= len(oracle_answer) / len(answer) # penalty for additional length
|
||||||
|
return reward
|
||||||
|
|
||||||
|
|
||||||
register_dataset("mini_sudoku", MiniSudokuDataset, MiniSudokuConfig)
|
register_dataset("mini_sudoku", MiniSudokuDataset, MiniSudokuConfig)
|
||||||
|
|
|
||||||
|
|
@ -14,14 +14,29 @@ from ..factory import ProceduralDataset, register_dataset
|
||||||
MIN_BOARD_SIZE = 4
|
MIN_BOARD_SIZE = 4
|
||||||
MAX_BOARD_SIZE = 12
|
MAX_BOARD_SIZE = 12
|
||||||
|
|
||||||
QUESTION_TEMPLATE = """Solve this N Queens puzzle:
|
QUESTION_TEMPLATE = """Your job is to complete an n x n chess board with n Queens in total, such that no two attack each other.
|
||||||
{puzzle}
|
|
||||||
|
|
||||||
The board size is {n}x{n} and your job is to place {num_removed} queen(s) on the board such that no two queens attack each other.
|
|
||||||
|
|
||||||
No two queens attack each other if they are not in the same row, column, or diagonal.
|
No two queens attack each other if they are not in the same row, column, or diagonal.
|
||||||
|
|
||||||
Place a queen by replacing an underscore (_) with a Q.
|
You can place a queen by replacing an underscore (_) with a Q.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
- Input: Given the below board of size 4 x 4 your job is to place 2 queen(s) on the board such that no two queens attack each other.
|
||||||
|
_ Q _ _
|
||||||
|
_ _ _ _
|
||||||
|
_ _ _ _
|
||||||
|
_ _ Q _
|
||||||
|
- Output:
|
||||||
|
_ Q _ _
|
||||||
|
_ _ _ Q
|
||||||
|
Q _ _ _
|
||||||
|
_ _ Q _
|
||||||
|
- Explanation
|
||||||
|
- None of the queens attack each other vertically, horizontally, or diagonally.
|
||||||
|
- The added queens are marked with Q at the positions (1, 3) and (2, 0).
|
||||||
|
|
||||||
|
Given the below board of size {n} x {n} your job is to place {num_removed} queen(s) on the board such that no two queens attack each other.
|
||||||
|
{puzzle}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -137,13 +152,16 @@ class NQueensDataset(ProceduralDataset):
|
||||||
|
|
||||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||||
valid_solutions = entry["metadata"]["valid_answers"]
|
valid_solutions = entry["metadata"]["valid_answers"]
|
||||||
reward = 0.0
|
|
||||||
if answer is not None:
|
if answer is not None:
|
||||||
if answer in valid_solutions:
|
if answer in valid_solutions:
|
||||||
reward = 1.0
|
return 1.0
|
||||||
else:
|
try:
|
||||||
reward = 0.01
|
answer = self._board_to_string(eval(answer))
|
||||||
return reward
|
if answer in valid_solutions:
|
||||||
|
return 0.5
|
||||||
|
except Exception as e:
|
||||||
|
return 0.01
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
register_dataset("n_queens", NQueensDataset, NQueensConfig)
|
register_dataset("n_queens", NQueensDataset, NQueensConfig)
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,23 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
QUESTION_TEMPLATE = """Solve the Tower of Hanoi problem with {num_disks} disks and {num_pegs} pegs.
|
||||||
|
Move all disks from {start_peg} to {target_peg} following the rules:
|
||||||
|
- Only one disk can be moved at a time.
|
||||||
|
- A larger disk cannot be placed on top of a smaller disk.
|
||||||
|
- All disks must be on a peg at all times.
|
||||||
|
Example:
|
||||||
|
Move disk 1 from Peg 1 to Peg 3
|
||||||
|
Move disk 2 from Peg 1 to Peg 2
|
||||||
|
Move disk 1 from Peg 3 to Peg 2
|
||||||
|
|
||||||
|
Provide the sequence of moves.
|
||||||
|
Formatting guidelines:
|
||||||
|
Each instruction should be placed on a single line.
|
||||||
|
Each line should be formatted as 'Move disk X from Peg Y to Peg Z'
|
||||||
|
Do not include any other text or formatting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class HanoiConfig:
|
class HanoiConfig:
|
||||||
|
|
@ -245,22 +262,13 @@ class HanoiDataset(ProceduralDataset):
|
||||||
# Peg labels
|
# Peg labels
|
||||||
peg_labels = {peg: f"Peg {peg}" for peg in pegs}
|
peg_labels = {peg: f"Peg {peg}" for peg in pegs}
|
||||||
|
|
||||||
question_str = (
|
|
||||||
f"Solve the Tower of Hanoi problem with {num_disks} disks and {num_pegs} pegs.\n"
|
|
||||||
f"Move all disks from {peg_labels[start_peg]} to {peg_labels[target_peg]} following the rules:\n"
|
|
||||||
"- Only one disk can be moved at a time.\n"
|
|
||||||
"- A larger disk cannot be placed on top of a smaller disk.\n"
|
|
||||||
"- All disks must be on a peg at all times.\n"
|
|
||||||
"Example:\n"
|
|
||||||
"Move disk 1 from Peg 1 to Peg 3\n"
|
|
||||||
"Move disk 2 from Peg 1 to Peg 2\n"
|
|
||||||
"Move disk 1 from Peg 3 to Peg 2\n"
|
|
||||||
"\n"
|
|
||||||
"Provide the sequence of moves."
|
|
||||||
)
|
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"question": question_str,
|
"question": QUESTION_TEMPLATE.format(
|
||||||
|
num_disks=num_disks,
|
||||||
|
num_pegs=num_pegs,
|
||||||
|
start_peg=peg_labels[start_peg],
|
||||||
|
target_peg=peg_labels[target_peg],
|
||||||
|
),
|
||||||
"answer": solution,
|
"answer": solution,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"num_disks": num_disks,
|
"num_disks": num_disks,
|
||||||
|
|
@ -359,7 +367,7 @@ class HanoiDataset(ProceduralDataset):
|
||||||
tuple: (disk, from_peg, to_peg)
|
tuple: (disk, from_peg, to_peg)
|
||||||
"""
|
"""
|
||||||
pattern = r"Move disk (\d+) from Peg (\d+) to Peg (\d+)"
|
pattern = r"Move disk (\d+) from Peg (\d+) to Peg (\d+)"
|
||||||
match = re.match(pattern, move)
|
match = re.search(pattern, move)
|
||||||
if not match:
|
if not match:
|
||||||
raise ValueError(f"Unexpected move format: '{move}'")
|
raise ValueError(f"Unexpected move format: '{move}'")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -175,9 +175,9 @@ class FamilyRelationshipsDataset(ProceduralDataset):
|
||||||
|
|
||||||
def __init__(self, config: FamilyRelationshipsConfig):
|
def __init__(self, config: FamilyRelationshipsConfig):
|
||||||
self._templates = [
|
self._templates = [
|
||||||
"What is {person1} to {person2}?",
|
"What is {person1} to {person2}? Respond only with the word that describes their relationship.",
|
||||||
"How is {person1} related to {person2}?",
|
"How is {person1} related to {person2}? Provide the relationship in one word.",
|
||||||
"What relation is {person1} to {person2}?",
|
"What relation is {person1} to {person2}? Answer with a single word.",
|
||||||
]
|
]
|
||||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ Logic tasks for training reasoning capabilities.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .aiw import AliceInWonderlandConfig, AliceInWonderlandDataset
|
from .aiw import AliceInWonderlandConfig, AliceInWonderlandDataset
|
||||||
|
from .circuit_logic import CircuitLogicConfig, CircuitLogicDataset
|
||||||
from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset
|
from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset
|
||||||
from .self_reference import SelfReferenceConfig, SelfReferenceDataset
|
from .self_reference import SelfReferenceConfig, SelfReferenceDataset
|
||||||
from .syllogisms import SyllogismConfig, SyllogismDataset, Term
|
from .syllogisms import SyllogismConfig, SyllogismDataset, Term
|
||||||
|
|
@ -20,5 +21,8 @@ __all__ = [
|
||||||
"ZebraConfig",
|
"ZebraConfig",
|
||||||
"ZebraDataset",
|
"ZebraDataset",
|
||||||
"SelfReference",
|
"SelfReference",
|
||||||
|
"SelfReferenceConfig",
|
||||||
"SelfReferenceDataset",
|
"SelfReferenceDataset",
|
||||||
|
"CircuitLogicConfig",
|
||||||
|
"CircuitLogicDataset",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -187,7 +187,7 @@ class AliceInWonderlandDataset(ProceduralDataset):
|
||||||
num_female_colleagues_bob_circle=num_female_colleagues_bob_circle,
|
num_female_colleagues_bob_circle=num_female_colleagues_bob_circle,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"question": question, "answer": answer, "metadata": {"task_type": task_type.value}}
|
return {"question": question, "answer": str(answer), "metadata": {"task_type": task_type.value}}
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
|
||||||
416
reasoning_gym/logic/circuit_logic.py
Normal file
416
reasoning_gym/logic/circuit_logic.py
Normal file
|
|
@ -0,0 +1,416 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from random import Random
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
VERT = "│"
|
||||||
|
HORIZ = "─"
|
||||||
|
RBRANCH = "├"
|
||||||
|
LUP = "┘"
|
||||||
|
LDOWN = "┐"
|
||||||
|
RUP = "└"
|
||||||
|
RDOWN = "┌"
|
||||||
|
|
||||||
|
|
||||||
|
def _repeat(s: str, n: int) -> str:
|
||||||
|
return s * n
|
||||||
|
|
||||||
|
|
||||||
|
def _matrix_put(matrix: List[List[str]], h: int, w: int, x: int, y: int, s: str, direction: str):
|
||||||
|
"""Place a string `s` into the 2D `matrix` starting at (x,y),
|
||||||
|
advancing in `direction` ('RIGHT' or 'DOWN')."""
|
||||||
|
if x >= w or y >= h:
|
||||||
|
raise IndexError(f"_matrix_put: point ({x}, {y}) out of bounds!")
|
||||||
|
for c in s:
|
||||||
|
if x < 0 or x >= w or y < 0 or y >= h:
|
||||||
|
break
|
||||||
|
matrix[y][x] = c
|
||||||
|
if direction == "RIGHT":
|
||||||
|
x += 1
|
||||||
|
elif direction == "DOWN":
|
||||||
|
y += 1
|
||||||
|
|
||||||
|
|
||||||
|
def _get_excel_name(index: int) -> str:
|
||||||
|
"""
|
||||||
|
Convert a zero-based integer `index` into an Excel-like column name:
|
||||||
|
0 -> A, 1 -> B, ..., 25 -> Z, 26 -> AA, etc.
|
||||||
|
"""
|
||||||
|
result = ""
|
||||||
|
index += 1
|
||||||
|
while index > 0:
|
||||||
|
rem = (index - 1) % 26
|
||||||
|
result = chr(ord("A") + rem) + result
|
||||||
|
index = (index - 1) // 26
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CircuitLogicConfig:
|
||||||
|
"""
|
||||||
|
Configuration for circuit logic task generation.
|
||||||
|
|
||||||
|
:param num_terms: Number of terms (sub-expressions) to generate
|
||||||
|
:param min_inputs: Minimum inputs per term
|
||||||
|
:param max_inputs: Maximum inputs per term
|
||||||
|
:param neg_prob: Probability (0.0-1.0) that an input is negated
|
||||||
|
:param allow_reuse: Whether inputs can be reused
|
||||||
|
:param size: Number of items in the dataset
|
||||||
|
:param seed: Random seed
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_terms: int = 5
|
||||||
|
min_inputs: int = 2
|
||||||
|
max_inputs: int = 4
|
||||||
|
neg_prob: float = 0.3
|
||||||
|
allow_reuse: bool = True
|
||||||
|
size: int = 100
|
||||||
|
seed: Optional[int] = None
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
assert 1 <= self.min_inputs <= self.max_inputs, "Invalid input range"
|
||||||
|
assert 1 <= self.num_terms, "Invalid number of terms"
|
||||||
|
assert 0.0 <= self.neg_prob <= 1.0, "neg_prob must be between 0 and 1"
|
||||||
|
|
||||||
|
|
||||||
|
class CircuitLogicDataset(ProceduralDataset):
|
||||||
|
"""
|
||||||
|
Generates random digital logic circuits (in ASCII) together with:
|
||||||
|
- a random Boolean expression,
|
||||||
|
- random input assignments,
|
||||||
|
- the final evaluated output.
|
||||||
|
|
||||||
|
Each item in the dataset is a dict with:
|
||||||
|
{
|
||||||
|
"question": <str>,
|
||||||
|
"answer": <str>,
|
||||||
|
"metadata": {
|
||||||
|
"diagram": <ASCII circuit diagram>,
|
||||||
|
"expression": <str>,
|
||||||
|
"term_strings": <list of term_strings>,
|
||||||
|
"assignments": <dict of input->0/1>,
|
||||||
|
"final_gate": <str>,
|
||||||
|
"inputs": <list of input names>,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: CircuitLogicConfig):
|
||||||
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
self.config.validate()
|
||||||
|
|
||||||
|
self.internal_ops = [
|
||||||
|
("AND", "&", "&"),
|
||||||
|
("NAND", "↑", "↑"),
|
||||||
|
("XOR", "⊕", "⊕"),
|
||||||
|
]
|
||||||
|
self.final_gate_options = [
|
||||||
|
("OR", "+"),
|
||||||
|
("NOR", "↓"),
|
||||||
|
("XOR", "⊕"),
|
||||||
|
("AND", "&"),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return self.config.size
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
self._current_idx = 0
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self) -> Dict[str, Any]:
|
||||||
|
if self._current_idx >= self.config.size:
|
||||||
|
raise StopIteration
|
||||||
|
item = self[self._current_idx]
|
||||||
|
self._current_idx += 1
|
||||||
|
return item
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Generate one random circuit logic item using ASCII drawing.
|
||||||
|
"""
|
||||||
|
rng = Random(self.seed + idx if self.seed is not None else None)
|
||||||
|
return self._generate_circuit(
|
||||||
|
rng=rng,
|
||||||
|
num_terms=self.config.num_terms,
|
||||||
|
min_inputs=self.config.min_inputs,
|
||||||
|
max_inputs=self.config.max_inputs,
|
||||||
|
neg_prob=self.config.neg_prob,
|
||||||
|
allow_reuse=self.config.allow_reuse,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate_circuit(
|
||||||
|
self, rng: Random, num_terms: int, min_inputs: int, max_inputs: int, neg_prob: float, allow_reuse: bool
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Generate circuit logic (ASCII drawing + expression + evaluation)
|
||||||
|
"""
|
||||||
|
final_gate_name, final_gate_sym = rng.choice(self.final_gate_options)
|
||||||
|
final_gate_width = 2 + len(final_gate_sym)
|
||||||
|
|
||||||
|
distinct_inputs: List[str] = []
|
||||||
|
|
||||||
|
def get_random_input() -> str:
|
||||||
|
if allow_reuse and distinct_inputs and rng.random() < 0.5:
|
||||||
|
return rng.choice(distinct_inputs)
|
||||||
|
else:
|
||||||
|
name = _get_excel_name(len(distinct_inputs))
|
||||||
|
distinct_inputs.append(name)
|
||||||
|
return name
|
||||||
|
|
||||||
|
term_ops: List[Tuple[str, str, str]] = []
|
||||||
|
term_strings: List[str] = []
|
||||||
|
for _ in range(num_terms):
|
||||||
|
op = rng.choice(self.internal_ops)
|
||||||
|
term_ops.append(op)
|
||||||
|
|
||||||
|
term_length = rng.randint(min_inputs, max_inputs)
|
||||||
|
parts = []
|
||||||
|
for __ in range(term_length):
|
||||||
|
inp = get_random_input()
|
||||||
|
neg = rng.random() < neg_prob
|
||||||
|
parts.append(inp + ("'" if neg else ""))
|
||||||
|
# Join the parts with the operator’s join symbol.
|
||||||
|
term_str = op[1].join(parts)
|
||||||
|
term_strings.append(term_str)
|
||||||
|
|
||||||
|
expression_for_display = final_gate_sym.join(f"({t})" for t in term_strings)
|
||||||
|
# use || separator internally that doesn't clash with other symbols...
|
||||||
|
separator = "||"
|
||||||
|
expression_for_internal = separator.join(term_strings)
|
||||||
|
|
||||||
|
expr = [] # will hold a list of tuples (op_used, term_input_list)
|
||||||
|
inputs_set = set()
|
||||||
|
term_inputs_map = {}
|
||||||
|
input_ypos = 0
|
||||||
|
|
||||||
|
outer_terms = expression_for_internal.split(separator)
|
||||||
|
for op_chosen, term in zip(term_ops, outer_terms):
|
||||||
|
op_used = op_chosen
|
||||||
|
# If the join symbol appears in the term, split by it; otherwise (single literal) use it as-is.
|
||||||
|
if op_used[1] in term:
|
||||||
|
input_strs = term.split(op_used[1])
|
||||||
|
else:
|
||||||
|
input_strs = [term]
|
||||||
|
|
||||||
|
curr_term = []
|
||||||
|
for part in input_strs:
|
||||||
|
if not part:
|
||||||
|
continue
|
||||||
|
neg = part.endswith("'")
|
||||||
|
name = part[:-1] if neg else part
|
||||||
|
inputs_set.add(name)
|
||||||
|
curr_term.append({"name": name, "ypos": input_ypos, "neg": neg})
|
||||||
|
term_inputs_map.setdefault(name, []).append({"ypos": input_ypos, "neg": neg})
|
||||||
|
input_ypos += 1
|
||||||
|
|
||||||
|
expr.append((op_used, curr_term))
|
||||||
|
# Add a gap after each term.
|
||||||
|
input_ypos += 1
|
||||||
|
|
||||||
|
inputs_list = sorted(list(inputs_set))
|
||||||
|
total_term_inputs = sum(len(t) for (_, t) in expr)
|
||||||
|
height = len(inputs_list) + total_term_inputs + len(expr) - 1
|
||||||
|
|
||||||
|
max_input_len = max((len(s) for s in inputs_list), default=0)
|
||||||
|
input_width = max_input_len + 2
|
||||||
|
width = 0
|
||||||
|
|
||||||
|
width += input_width
|
||||||
|
width += len(inputs_list) * 2
|
||||||
|
not_and_start = width
|
||||||
|
width += 7 # space for gates ...
|
||||||
|
width += len(expr) + 1
|
||||||
|
gate_start = width
|
||||||
|
width += final_gate_width
|
||||||
|
width += 4 # additional wiring space
|
||||||
|
width += 4
|
||||||
|
width += len(expression_for_display)
|
||||||
|
|
||||||
|
# Create an empty drawing matrix.
|
||||||
|
matrix = [[" " for _ in range(width)] for _ in range(height)]
|
||||||
|
base_y = len(inputs_list)
|
||||||
|
|
||||||
|
x = width - 8 - len(expression_for_display)
|
||||||
|
y = base_y + ((height - base_y) // 2)
|
||||||
|
_matrix_put(matrix, height, width, x, y, _repeat(HORIZ, 4) + " OUT", "RIGHT")
|
||||||
|
|
||||||
|
x = gate_start
|
||||||
|
out_gate_center = base_y + ((height - base_y) // 2) - (len(expr) // 2)
|
||||||
|
if len(expr) == 1:
|
||||||
|
_matrix_put(matrix, height, width, x, out_gate_center, _repeat(HORIZ, final_gate_width), "RIGHT")
|
||||||
|
else:
|
||||||
|
_matrix_put(matrix, height, width, x, out_gate_center, _repeat(HORIZ, len(expr)), "DOWN")
|
||||||
|
_matrix_put(matrix, height, width, x + 1, out_gate_center, _repeat(VERT, len(expr)), "DOWN")
|
||||||
|
for i, ch in enumerate(final_gate_sym):
|
||||||
|
_matrix_put(matrix, height, width, x + 2 + i, out_gate_center, _repeat(ch, len(expr)), "DOWN")
|
||||||
|
_matrix_put(matrix, height, width, x + 3 + i, out_gate_center, _repeat(ch, len(expr)), "DOWN")
|
||||||
|
|
||||||
|
# Draw internal wiring (for the internal gate section).
|
||||||
|
x = not_and_start
|
||||||
|
y = base_y
|
||||||
|
for op, term_inputs in expr:
|
||||||
|
layers = [""] * 7
|
||||||
|
for ti in term_inputs:
|
||||||
|
layers[0] += HORIZ
|
||||||
|
layers[1] += ">" if ti["neg"] else HORIZ
|
||||||
|
layers[2] += "o" if ti["neg"] else HORIZ
|
||||||
|
layers[3] += HORIZ
|
||||||
|
# If multiple inputs, we connect them vertically
|
||||||
|
layers[4] += VERT if len(term_inputs) > 1 else HORIZ
|
||||||
|
layers[5] += op[2] if (len(term_inputs) > 1) else HORIZ
|
||||||
|
layers[6] += op[2] if (len(term_inputs) > 1) else HORIZ
|
||||||
|
|
||||||
|
for i in range(7):
|
||||||
|
_matrix_put(matrix, height, width, x + i, y, layers[i], "DOWN")
|
||||||
|
y += len(term_inputs) + 1
|
||||||
|
|
||||||
|
x = 0
|
||||||
|
y = 0
|
||||||
|
for inp in inputs_list:
|
||||||
|
label = f"{inp}: " + _repeat(HORIZ, input_width - (len(inp) + 2))
|
||||||
|
_matrix_put(matrix, height, width, x, y, label, "RIGHT")
|
||||||
|
y += 1
|
||||||
|
|
||||||
|
x = input_width
|
||||||
|
for idx, inp in enumerate(inputs_list):
|
||||||
|
y = idx
|
||||||
|
length = len(inputs_list) * 2 - 1 - (idx * 2)
|
||||||
|
_matrix_put(matrix, height, width, x, y, _repeat(HORIZ, length) + LDOWN, "RIGHT")
|
||||||
|
|
||||||
|
num = 0
|
||||||
|
offset = len(inputs_list) * 2 - 1
|
||||||
|
for inp in inputs_list:
|
||||||
|
y_breaks = [base_y + ti["ypos"] for ti in term_inputs_map.get(inp, [])]
|
||||||
|
y_breaks.sort()
|
||||||
|
for yb in y_breaks:
|
||||||
|
_matrix_put(
|
||||||
|
matrix, height, width, x + offset, yb, _repeat(HORIZ, len(inputs_list) * 2 - offset), "RIGHT"
|
||||||
|
)
|
||||||
|
y_start = num + 1
|
||||||
|
max_break = max(y_breaks) if y_breaks else y_start
|
||||||
|
branch = list(_repeat(VERT, max_break - y_start + 1))
|
||||||
|
for yb in y_breaks:
|
||||||
|
pos = yb - y_start
|
||||||
|
if 0 <= pos < len(branch):
|
||||||
|
branch[pos] = RBRANCH
|
||||||
|
branch[-1] = RUP
|
||||||
|
_matrix_put(matrix, height, width, x + offset, y_start, "".join(branch), "DOWN")
|
||||||
|
offset -= 2
|
||||||
|
num += 1
|
||||||
|
|
||||||
|
x = not_and_start + 7
|
||||||
|
out_y = out_gate_center
|
||||||
|
breakx = len(expr) // 2
|
||||||
|
for op, term_inputs in expr:
|
||||||
|
in_y = base_y + (term_inputs[0]["ypos"] + term_inputs[-1]["ypos"]) // 2
|
||||||
|
# horizontal to branch
|
||||||
|
_matrix_put(matrix, height, width, x, in_y, _repeat(HORIZ, abs(breakx) + 1), "RIGHT")
|
||||||
|
# horizontal from branch up/down to final gate column
|
||||||
|
_matrix_put(
|
||||||
|
matrix, height, width, x + abs(breakx) + 1, out_y, _repeat(HORIZ, len(expr) - abs(breakx)), "RIGHT"
|
||||||
|
)
|
||||||
|
|
||||||
|
if in_y < out_y:
|
||||||
|
branch = LDOWN + _repeat(VERT, out_y - in_y - 1) + RUP
|
||||||
|
_matrix_put(matrix, height, width, x + abs(breakx) + 1, in_y, branch, "DOWN")
|
||||||
|
elif in_y > out_y:
|
||||||
|
branch = RDOWN + _repeat(VERT, in_y - out_y - 1) + LUP
|
||||||
|
_matrix_put(matrix, height, width, x + abs(breakx) + 1, out_y, branch, "DOWN")
|
||||||
|
|
||||||
|
out_y += 1
|
||||||
|
breakx -= 1
|
||||||
|
|
||||||
|
ascii_diagram = "\n".join("".join(row).rstrip() for row in matrix)
|
||||||
|
|
||||||
|
assignments = {}
|
||||||
|
for inp in inputs_list:
|
||||||
|
assignments[inp] = rng.choice([0, 1])
|
||||||
|
|
||||||
|
term_values = []
|
||||||
|
for op_used, term_inputs in expr:
|
||||||
|
op_name = op_used[0]
|
||||||
|
values = []
|
||||||
|
for literal in term_inputs:
|
||||||
|
val = assignments[literal["name"]]
|
||||||
|
if literal["neg"]:
|
||||||
|
val = 1 - val
|
||||||
|
values.append(val)
|
||||||
|
|
||||||
|
if op_name == "AND":
|
||||||
|
term_val = 1 if all(v == 1 for v in values) else 0
|
||||||
|
elif op_name == "NAND":
|
||||||
|
term_val = 0 if all(v == 1 for v in values) else 1
|
||||||
|
elif op_name == "XOR":
|
||||||
|
tmp = 0
|
||||||
|
for v in values:
|
||||||
|
tmp ^= v
|
||||||
|
term_val = tmp
|
||||||
|
else:
|
||||||
|
term_val = 0
|
||||||
|
term_values.append(term_val)
|
||||||
|
|
||||||
|
# Evaluate final gate based on term values
|
||||||
|
if final_gate_name == "OR":
|
||||||
|
final_result = 1 if any(v == 1 for v in term_values) else 0
|
||||||
|
elif final_gate_name == "NOR":
|
||||||
|
final_result = 0 if any(v == 1 for v in term_values) else 1
|
||||||
|
elif final_gate_name == "XOR":
|
||||||
|
final_result = sum(term_values) % 2
|
||||||
|
elif final_gate_name == "AND":
|
||||||
|
final_result = 1 if all(v == 1 for v in term_values) else 0
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown gate type: {final_gate_name}")
|
||||||
|
|
||||||
|
lines = []
|
||||||
|
lines.append("Below is a randomly generated logic circuit.\n")
|
||||||
|
lines.append(ascii_diagram)
|
||||||
|
lines.append("\n")
|
||||||
|
legend_lines = []
|
||||||
|
legend_lines.append("Legend for gates:")
|
||||||
|
for op_name, _, draw_sym in self.internal_ops:
|
||||||
|
legend_lines.append(f"{draw_sym*2}: {op_name}")
|
||||||
|
if neg_prob > 0:
|
||||||
|
legend_lines.append(f">o: Negate")
|
||||||
|
if final_gate_sym not in self.internal_ops:
|
||||||
|
legend_lines.append(f"{final_gate_sym*2}: {final_gate_name}")
|
||||||
|
legend_str = "\n".join(legend_lines)
|
||||||
|
|
||||||
|
lines.append(legend_str)
|
||||||
|
lines.append("")
|
||||||
|
lines.append("Given the following input assignments:")
|
||||||
|
for inp in inputs_list:
|
||||||
|
lines.append(f" {inp} = {assignments[inp]}")
|
||||||
|
lines.append("")
|
||||||
|
lines.append("What is the final output?")
|
||||||
|
|
||||||
|
answer_str = str(final_result)
|
||||||
|
question_str = "\n".join(lines)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"question": question_str,
|
||||||
|
"answer": answer_str,
|
||||||
|
"metadata": {
|
||||||
|
"expression": expression_for_display,
|
||||||
|
"assignments": assignments,
|
||||||
|
"term_strings": term_strings,
|
||||||
|
"final_gate": final_gate_name,
|
||||||
|
"inputs": inputs_list,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
|
||||||
|
if answer is None or len(answer) == 0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
oracle_answer = entry["answer"]
|
||||||
|
if oracle_answer == answer:
|
||||||
|
return 1.0
|
||||||
|
elif oracle_answer == answer.strip():
|
||||||
|
return len(oracle_answer) / len(answer)
|
||||||
|
|
||||||
|
return 0.01
|
||||||
|
|
||||||
|
|
||||||
|
register_dataset("circuit_logic", CircuitLogicDataset, CircuitLogicConfig)
|
||||||
|
|
@ -8,12 +8,12 @@ SYSTEM_PROMPTS = {
|
||||||
"DeepSeekZero": """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
|
"DeepSeekZero": """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
|
||||||
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think>
|
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think>
|
||||||
<answer>answer here</answer>
|
<answer>answer here</answer>
|
||||||
Do not explain your reasoning inside the answer tags, provide only the final answer.
|
Do not explain your reasoning inside the answer tags, provide only the final answer. When an example is provided, you should strictly follow the format of the output/answer in that example.
|
||||||
""",
|
""",
|
||||||
"default": """Given a problem, your task is to answer the question by thinking step-by-step in a clear and specific manner.
|
"default": """Given a problem, your task is to answer the question by thinking step-by-step in a clear and specific manner.
|
||||||
Once you have thought about the reasoning process, provide the answer in the following format:
|
Once you have thought about the reasoning process, provide the answer in the following format:
|
||||||
<answer>answer here</answer>
|
<answer>answer here</answer>
|
||||||
Do not explain your reasoning inside the answer tags, provide only the final answer.
|
Do not explain your reasoning inside the answer tags, provide only the final answer. When an example is provided, you should strictly follow the format of the output/answer in that example.
|
||||||
""",
|
""",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -137,3 +137,32 @@ def test_arc_agi_dataset_modes():
|
||||||
both_ds = ArcAgiDataset(both_config)
|
both_ds = ArcAgiDataset(both_config)
|
||||||
assert len(both_ds._task_ids) > len(train_ds._task_ids)
|
assert len(both_ds._task_ids) > len(train_ds._task_ids)
|
||||||
assert len(both_ds._task_ids) > len(eval_ds._task_ids)
|
assert len(both_ds._task_ids) > len(eval_ds._task_ids)
|
||||||
|
|
||||||
|
|
||||||
|
def test_arc_agi_shuffled_order():
|
||||||
|
config_unshuffled = ArcAgiConfig(
|
||||||
|
shuffle_example_order=False,
|
||||||
|
use_train=True,
|
||||||
|
use_eval=False,
|
||||||
|
rotations=[],
|
||||||
|
mirrors=[],
|
||||||
|
use_color_permutation=False,
|
||||||
|
size=3,
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
config_shuffled = ArcAgiConfig(
|
||||||
|
shuffle_example_order=True,
|
||||||
|
use_train=True,
|
||||||
|
use_eval=False,
|
||||||
|
rotations=[],
|
||||||
|
mirrors=[],
|
||||||
|
use_color_permutation=False,
|
||||||
|
size=3,
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
unshuffled = ArcAgiDataset(config_unshuffled)
|
||||||
|
shuffled = ArcAgiDataset(config_shuffled)
|
||||||
|
|
||||||
|
for a, b in zip(shuffled, unshuffled):
|
||||||
|
assert a["question"] != b["question"]
|
||||||
|
assert a["answer"] == b["answer"]
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
from random import Random
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.arithmetic.basic_arithmetic import (
|
from reasoning_gym.arithmetic.basic_arithmetic import (
|
||||||
|
|
@ -64,11 +62,19 @@ def test_arithmetic_dataset_format_styles():
|
||||||
max_digits=2,
|
max_digits=2,
|
||||||
)
|
)
|
||||||
dataset = BasicArithmeticDataset(config)
|
dataset = BasicArithmeticDataset(config)
|
||||||
assert all(item["question"].endswith("=") for item in dataset)
|
assert all(item["question"].strip().endswith(".") for item in dataset)
|
||||||
|
|
||||||
config.format_style = "natural"
|
config = BasicArithmeticDatasetConfig(
|
||||||
|
size=10,
|
||||||
|
seed=42,
|
||||||
|
format_style="natural",
|
||||||
|
min_terms=2,
|
||||||
|
max_terms=3, # Keep expressions simple for testing
|
||||||
|
min_digits=1,
|
||||||
|
max_digits=2,
|
||||||
|
)
|
||||||
dataset = BasicArithmeticDataset(config)
|
dataset = BasicArithmeticDataset(config)
|
||||||
assert all("=" not in item["question"] for item in dataset)
|
assert all(item["question"].strip().endswith(".") for item in dataset)
|
||||||
|
|
||||||
|
|
||||||
def test_arithmetic_dataset_iteration():
|
def test_arithmetic_dataset_iteration():
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,14 @@ def test_binary_matrix_config_validation():
|
||||||
config = BinaryMatrixConfig(max_n=0) # Zero not allowed
|
config = BinaryMatrixConfig(max_n=0) # Zero not allowed
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = BinaryMatrixConfig(min_n=-1) # Negative not allowed
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = BinaryMatrixConfig(min_n=0) # Zero not allowed
|
||||||
|
config.validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = BinaryMatrixConfig(p_zero=0) # <= 0 not allowed
|
config = BinaryMatrixConfig(p_zero=0) # <= 0 not allowed
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
@ -98,3 +106,18 @@ def test_binary_matrix_answer():
|
||||||
# Empty matrix
|
# Empty matrix
|
||||||
matrix = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
|
matrix = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
|
||||||
assert dataset._get_distances(matrix) == [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
|
assert dataset._get_distances(matrix) == [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
|
||||||
|
|
||||||
|
# String representation of answer
|
||||||
|
answer = "0 0 0\n0 1 0\n1 2 1"
|
||||||
|
entry = {"answer": "0 0 0\n0 1 0\n1 2 1"}
|
||||||
|
assert dataset.score_answer(answer, entry) == 1.0
|
||||||
|
|
||||||
|
# Answer is a python list (partially correct answer)
|
||||||
|
answer = "[[0, 0, 0], [0, 1, 0], [1, 2, 1]]"
|
||||||
|
entry = {"answer": "0 0 0\n0 1 0\n1 2 1"}
|
||||||
|
assert dataset.score_answer(answer, entry) == 0.5
|
||||||
|
|
||||||
|
# Answer is null
|
||||||
|
answer = None
|
||||||
|
entry = {"answer": "0 0 0\n0 1 0\n1 2 1"}
|
||||||
|
assert dataset.score_answer(answer, entry) == 0.0
|
||||||
|
|
|
||||||
224
tests/test_circuit_logic.py
Normal file
224
tests/test_circuit_logic.py
Normal file
|
|
@ -0,0 +1,224 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reasoning_gym.logic import CircuitLogicConfig, CircuitLogicDataset
|
||||||
|
|
||||||
|
|
||||||
|
def test_circuit_logic_config_validation():
|
||||||
|
"""Test that invalid configs raise appropriate errors"""
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = CircuitLogicConfig(min_inputs=3, max_inputs=2)
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = CircuitLogicConfig(num_terms=0)
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = CircuitLogicConfig(neg_prob=-0.1)
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = CircuitLogicConfig(neg_prob=1.1)
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
|
||||||
|
def test_circuit_logic_deterministic():
|
||||||
|
"""Test that dataset generates same items with same seed"""
|
||||||
|
config = CircuitLogicConfig(seed=42, size=10)
|
||||||
|
dataset1 = CircuitLogicDataset(config)
|
||||||
|
dataset2 = CircuitLogicDataset(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset1)):
|
||||||
|
assert dataset1[i] == dataset2[i]
|
||||||
|
|
||||||
|
|
||||||
|
def test_circuit_logic_items():
|
||||||
|
"""Test basic properties of generated items"""
|
||||||
|
config = CircuitLogicConfig(num_terms=3, min_inputs=2, max_inputs=3, neg_prob=0.3, size=50, seed=42)
|
||||||
|
dataset = CircuitLogicDataset(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item = dataset[i]
|
||||||
|
assert isinstance(item, dict)
|
||||||
|
assert "question" in item
|
||||||
|
assert "answer" in item
|
||||||
|
assert "metadata" in item
|
||||||
|
|
||||||
|
# Verify metadata contents
|
||||||
|
metadata = item["metadata"]
|
||||||
|
assert "expression" in metadata
|
||||||
|
assert "assignments" in metadata
|
||||||
|
assert "final_gate" in metadata
|
||||||
|
assert "inputs" in metadata
|
||||||
|
|
||||||
|
# Verify answer is binary
|
||||||
|
assert item["answer"] in ("0", "1")
|
||||||
|
|
||||||
|
# Verify assignments are binary
|
||||||
|
for input_name, value in metadata["assignments"].items():
|
||||||
|
assert value in (0, 1)
|
||||||
|
|
||||||
|
# Verify final gate is valid
|
||||||
|
assert metadata["final_gate"] in ("OR", "NOR", "XOR", "AND")
|
||||||
|
|
||||||
|
# Verify inputs list matches assignments
|
||||||
|
assert set(metadata["inputs"]) == set(metadata["assignments"].keys())
|
||||||
|
|
||||||
|
|
||||||
|
def test_circuit_logic_expression_validity():
|
||||||
|
"""Test that generated expressions follow logical circuit rules"""
|
||||||
|
config = CircuitLogicConfig(
|
||||||
|
num_terms=2, min_inputs=2, max_inputs=2, neg_prob=0.0, size=20, seed=42 # Disable negation for simpler testing
|
||||||
|
)
|
||||||
|
dataset = CircuitLogicDataset(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item = dataset[i]
|
||||||
|
metadata = item["metadata"]
|
||||||
|
|
||||||
|
# Expression should contain valid operators
|
||||||
|
expr = metadata["expression"]
|
||||||
|
assert any(op in expr for op in ("&", "↑", "⊕", "+", "↓"))
|
||||||
|
|
||||||
|
# Input names should be valid Excel-style names
|
||||||
|
for input_name in metadata["inputs"]:
|
||||||
|
assert input_name.isalpha()
|
||||||
|
assert input_name.isupper()
|
||||||
|
|
||||||
|
|
||||||
|
def test_circuit_logic_answer_verification():
|
||||||
|
"""Test that answers match logical evaluation of circuits"""
|
||||||
|
config = CircuitLogicConfig(num_terms=2, min_inputs=2, max_inputs=2, size=20, seed=42)
|
||||||
|
dataset = CircuitLogicDataset(config)
|
||||||
|
|
||||||
|
def evaluate_term(term: str, assignments: dict) -> int:
|
||||||
|
"""Evaluate a single term with given assignments"""
|
||||||
|
if "↑" in term: # NAND
|
||||||
|
parts = term.split("↑")
|
||||||
|
values = []
|
||||||
|
for p in parts:
|
||||||
|
if p.endswith("'"):
|
||||||
|
values.append(1 - assignments[p[:-1]])
|
||||||
|
else:
|
||||||
|
values.append(assignments[p])
|
||||||
|
return 0 if all(v == 1 for v in values) else 1
|
||||||
|
elif "&" in term: # AND
|
||||||
|
parts = term.split("&")
|
||||||
|
values = []
|
||||||
|
for p in parts:
|
||||||
|
if p.endswith("'"):
|
||||||
|
values.append(1 - assignments[p[:-1]])
|
||||||
|
else:
|
||||||
|
values.append(assignments[p])
|
||||||
|
return 1 if all(v == 1 for v in values) else 0
|
||||||
|
elif "⊕" in term: # XOR
|
||||||
|
parts = term.split("⊕")
|
||||||
|
values = []
|
||||||
|
for p in parts:
|
||||||
|
if p.endswith("'"):
|
||||||
|
values.append(1 - assignments[p[:-1]])
|
||||||
|
else:
|
||||||
|
values.append(assignments[p])
|
||||||
|
return sum(values) % 2
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown operator in term: {term}")
|
||||||
|
|
||||||
|
def evaluate_final_gate(gate_type: str, term_values: list) -> int:
|
||||||
|
"""Evaluate the final gate with given term values"""
|
||||||
|
if gate_type == "AND":
|
||||||
|
return 1 if all(v == 1 for v in term_values) else 0
|
||||||
|
elif gate_type == "OR":
|
||||||
|
return 1 if any(v == 1 for v in term_values) else 0
|
||||||
|
elif gate_type == "XOR":
|
||||||
|
return sum(term_values) % 2
|
||||||
|
elif gate_type == "NOR":
|
||||||
|
return 0 if any(v == 1 for v in term_values) else 1
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown gate type: {gate_type}")
|
||||||
|
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item = dataset[i]
|
||||||
|
metadata = item["metadata"]
|
||||||
|
assignments = metadata["assignments"]
|
||||||
|
final_gate = metadata["final_gate"]
|
||||||
|
term_strings = metadata["term_strings"]
|
||||||
|
|
||||||
|
# First evaluate each term
|
||||||
|
term_values = [evaluate_term(term, assignments) for term in term_strings]
|
||||||
|
|
||||||
|
# Then combine terms with final gate
|
||||||
|
expected = evaluate_final_gate(final_gate, term_values)
|
||||||
|
|
||||||
|
# Compare with actual result
|
||||||
|
result = int(item["answer"])
|
||||||
|
assert (
|
||||||
|
result == expected
|
||||||
|
), f"Item {i}: Expected {expected} but got {result} for terms {term_strings} with assignments {assignments} and final gate {final_gate}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_circuit_logic_ascii_diagram():
|
||||||
|
"""Test properties of the ASCII circuit diagram"""
|
||||||
|
config = CircuitLogicConfig(num_terms=2, min_inputs=2, max_inputs=2, size=10, seed=42)
|
||||||
|
dataset = CircuitLogicDataset(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item = dataset[i]
|
||||||
|
|
||||||
|
# Split question to get diagram
|
||||||
|
parts = item["question"].split("\n")
|
||||||
|
diagram_start = parts.index("Below is a randomly generated logic circuit.") + 2
|
||||||
|
diagram_end = parts.index("", diagram_start)
|
||||||
|
diagram = parts[diagram_start:diagram_end]
|
||||||
|
|
||||||
|
# Basic diagram validation
|
||||||
|
assert len(diagram) > 0
|
||||||
|
assert all(len(row) > 0 for row in diagram)
|
||||||
|
|
||||||
|
# Check for required circuit elements
|
||||||
|
diagram_str = "\n".join(diagram)
|
||||||
|
assert "OUT" in diagram_str
|
||||||
|
assert any(gate in diagram_str for gate in ("&", "↑", "⊕"))
|
||||||
|
|
||||||
|
# Verify input labels
|
||||||
|
for input_name in item["metadata"]["inputs"]:
|
||||||
|
assert f"{input_name}:" in diagram_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_circuit_logic_scoring():
|
||||||
|
"""Test the answer scoring mechanism"""
|
||||||
|
config = CircuitLogicConfig(size=5, seed=42)
|
||||||
|
dataset = CircuitLogicDataset(config)
|
||||||
|
|
||||||
|
item = dataset[0]
|
||||||
|
|
||||||
|
# Correct answer should score 1.0
|
||||||
|
assert dataset.score_answer(item["answer"], item) == 1.0
|
||||||
|
|
||||||
|
# Wrong answer should score lower
|
||||||
|
wrong_answer = "1" if item["answer"] == "0" else "0"
|
||||||
|
assert dataset.score_answer(wrong_answer, item) < 1.0
|
||||||
|
|
||||||
|
# None or empty answer should score 0.0
|
||||||
|
assert dataset.score_answer(None, item) == 0.0
|
||||||
|
assert dataset.score_answer("", item) == 0.0 # Empty string should score 0.0 like None
|
||||||
|
|
||||||
|
|
||||||
|
def test_circuit_logic_iteration():
|
||||||
|
"""Test that iteration works correctly"""
|
||||||
|
config = CircuitLogicConfig(size=5, seed=42)
|
||||||
|
dataset = CircuitLogicDataset(config)
|
||||||
|
|
||||||
|
# Test manual iteration
|
||||||
|
items = []
|
||||||
|
for item in dataset:
|
||||||
|
items.append(item)
|
||||||
|
assert len(items) == config.size
|
||||||
|
|
||||||
|
# Test list conversion
|
||||||
|
items = list(dataset)
|
||||||
|
assert len(items) == config.size
|
||||||
|
|
||||||
|
# Test multiple iterations yield same items
|
||||||
|
first_items = list(dataset)
|
||||||
|
second_items = list(dataset)
|
||||||
|
assert first_items == second_items
|
||||||
105
tests/test_cryptarithm.py
Normal file
105
tests/test_cryptarithm.py
Normal file
|
|
@ -0,0 +1,105 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reasoning_gym import create_dataset
|
||||||
|
from reasoning_gym.algorithmic.cryptarithm import CryptarithmConfig, CryptarithmDataset
|
||||||
|
|
||||||
|
|
||||||
|
def test_cryptarithm_generation():
|
||||||
|
dataset = create_dataset("cryptarithm", seed=42, size=10)
|
||||||
|
assert isinstance(dataset, CryptarithmDataset)
|
||||||
|
unique_number = set()
|
||||||
|
for item in dataset:
|
||||||
|
# Check required keys exist
|
||||||
|
assert "question" in item
|
||||||
|
assert "answer" in item
|
||||||
|
assert "metadata" in item
|
||||||
|
|
||||||
|
# Validate question format
|
||||||
|
question = item["question"]
|
||||||
|
assert "Solve this cryptarithm:" in question
|
||||||
|
assert "Each letter stands for a unique digit (0-9)" in question
|
||||||
|
|
||||||
|
# Validate metadata structure
|
||||||
|
metadata = item["metadata"]
|
||||||
|
assert "letters" in metadata
|
||||||
|
assert "letter_to_digit" in metadata
|
||||||
|
assert "words_letters" in metadata
|
||||||
|
assert "result_letters" in metadata
|
||||||
|
assert "word_values" in metadata
|
||||||
|
assert "sum_number" in metadata
|
||||||
|
|
||||||
|
# Validate letter to digit mapping
|
||||||
|
letter_to_digit = metadata["letter_to_digit"]
|
||||||
|
used_digits = set(letter_to_digit.values())
|
||||||
|
assert len(used_digits) == len(letter_to_digit), "Each letter should map to a unique digit"
|
||||||
|
assert all(0 <= digit <= 9 for digit in used_digits), "All digits should be between 0 and 9"
|
||||||
|
|
||||||
|
# Validate the arithmetic
|
||||||
|
word_values = metadata["word_values"]
|
||||||
|
result_value = metadata["sum_number"]
|
||||||
|
assert sum(word_values) == result_value, "Sum of word values should equal result value"
|
||||||
|
unique_number.add(result_value)
|
||||||
|
|
||||||
|
assert len(unique_number) == len(dataset)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cryptarithm_config():
|
||||||
|
# Test invalid configs raise assertions
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
dataset = create_dataset("cryptarithm", min_words=1) # min_words must be >= 2
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
dataset = create_dataset("cryptarithm", min_words=4, max_words=3) # min must be <= max
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
dataset = create_dataset("cryptarithm", size=0) # size must be positive
|
||||||
|
|
||||||
|
|
||||||
|
def test_leading_zero_constraint():
|
||||||
|
# Test with leading zeros not allowed
|
||||||
|
dataset = create_dataset("cryptarithm", seed=42, size=5, allow_leading_zero=False, max_words=10, min_words=5)
|
||||||
|
|
||||||
|
for item in dataset:
|
||||||
|
# print(item['question'])
|
||||||
|
metadata = item["metadata"]
|
||||||
|
letter_to_digit = metadata["letter_to_digit"]
|
||||||
|
words_letters = metadata["words_letters"]
|
||||||
|
result_letters = metadata["result_letters"]
|
||||||
|
|
||||||
|
# Check leading letters of all words and result
|
||||||
|
leading_letters = [word[0] for word in words_letters] + [result_letters[0]]
|
||||||
|
for letter in leading_letters:
|
||||||
|
assert letter_to_digit[letter] != 0, "Leading letters cannot be zero when allow_leading_zero=False"
|
||||||
|
|
||||||
|
|
||||||
|
def test_deterministic_generation():
|
||||||
|
dataset1 = create_dataset("cryptarithm", seed=42, size=5)
|
||||||
|
dataset2 = create_dataset("cryptarithm", seed=42, size=5)
|
||||||
|
|
||||||
|
for i in range(5):
|
||||||
|
assert dataset1[i]["question"] == dataset2[i]["question"]
|
||||||
|
assert dataset1[i]["answer"] == dataset2[i]["answer"]
|
||||||
|
assert dataset1[i]["metadata"] == dataset2[i]["metadata"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_word_length_constraints():
|
||||||
|
dataset = create_dataset("cryptarithm", seed=42, size=10)
|
||||||
|
|
||||||
|
for item in dataset:
|
||||||
|
metadata = item["metadata"]
|
||||||
|
words_letters = metadata["words_letters"]
|
||||||
|
|
||||||
|
# Check each word is between 3-5 letters as specified in the code
|
||||||
|
for word in words_letters:
|
||||||
|
assert 3 <= len(word) <= 5, "Each word should be between 3 and 5 letters long"
|
||||||
|
|
||||||
|
|
||||||
|
def test_max_letters_constraint():
|
||||||
|
dataset = create_dataset("cryptarithm", seed=42, size=10)
|
||||||
|
|
||||||
|
for item in dataset:
|
||||||
|
metadata = item["metadata"]
|
||||||
|
letter_to_digit = metadata["letter_to_digit"]
|
||||||
|
|
||||||
|
# Check total unique letters doesn't exceed 10 (digits 0-9)
|
||||||
|
assert len(letter_to_digit) <= 10, "Total unique letters should not exceed 10"
|
||||||
188
tests/test_futoshiki.py
Normal file
188
tests/test_futoshiki.py
Normal file
|
|
@ -0,0 +1,188 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reasoning_gym.games import FutoshikiConfig, FutoshikiDataset
|
||||||
|
|
||||||
|
|
||||||
|
def test_futoshiki_config_validation():
|
||||||
|
"""Test that invalid configs raise appropriate errors"""
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = FutoshikiConfig(board_size=3) # Too small
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = FutoshikiConfig(board_size=10) # Too large
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = FutoshikiConfig(difficulty=-1) # Invalid difficulty
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = FutoshikiConfig(difficulty=4) # Invalid difficulty
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
|
||||||
|
def test_futoshiki_deterministic():
|
||||||
|
"""Test that dataset generates same puzzles with same seed"""
|
||||||
|
config = FutoshikiConfig(seed=42, size=10)
|
||||||
|
dataset1 = FutoshikiDataset(config)
|
||||||
|
dataset2 = FutoshikiDataset(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset1)):
|
||||||
|
assert dataset1[i] == dataset2[i]
|
||||||
|
|
||||||
|
|
||||||
|
def test_futoshiki_items():
|
||||||
|
"""Test basic properties of generated items"""
|
||||||
|
config = FutoshikiConfig(board_size=4, difficulty=1, size=10, seed=42)
|
||||||
|
dataset = FutoshikiDataset(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item = dataset[i]
|
||||||
|
assert isinstance(item, dict)
|
||||||
|
assert "question" in item
|
||||||
|
assert "answer" in item
|
||||||
|
assert "metadata" in item
|
||||||
|
|
||||||
|
# Verify metadata contents
|
||||||
|
metadata = item["metadata"]
|
||||||
|
assert "puzzle" in metadata
|
||||||
|
assert "solution" in metadata
|
||||||
|
assert "constraints" in metadata
|
||||||
|
assert "board_size" in metadata
|
||||||
|
assert "difficulty" in metadata
|
||||||
|
|
||||||
|
# Verify board dimensions
|
||||||
|
puzzle = metadata["puzzle"]
|
||||||
|
solution = metadata["solution"]
|
||||||
|
assert len(puzzle) == config.board_size
|
||||||
|
assert len(solution) == config.board_size
|
||||||
|
for row in puzzle:
|
||||||
|
assert len(row) == config.board_size
|
||||||
|
for row in solution:
|
||||||
|
assert len(row) == config.board_size
|
||||||
|
|
||||||
|
# Verify constraints format
|
||||||
|
constraints = metadata["constraints"]
|
||||||
|
for ((r1, c1), (r2, c2)), rel in constraints.items():
|
||||||
|
assert 0 <= r1 < config.board_size
|
||||||
|
assert 0 <= c1 < config.board_size
|
||||||
|
assert 0 <= r2 < config.board_size
|
||||||
|
assert 0 <= c2 < config.board_size
|
||||||
|
assert rel in ("<", ">")
|
||||||
|
|
||||||
|
|
||||||
|
def test_futoshiki_solution_validity():
|
||||||
|
"""Test that solutions are valid according to Futoshiki rules"""
|
||||||
|
config = FutoshikiConfig(board_size=4, difficulty=1, size=10, seed=42)
|
||||||
|
dataset = FutoshikiDataset(config)
|
||||||
|
|
||||||
|
def is_valid_solution(solution, board_size, constraints):
|
||||||
|
# Check rows
|
||||||
|
for row in solution:
|
||||||
|
if sorted(row) != list(range(1, board_size + 1)):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check columns
|
||||||
|
for col in range(board_size):
|
||||||
|
column = [solution[row][col] for row in range(board_size)]
|
||||||
|
if sorted(column) != list(range(1, board_size + 1)):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check constraints
|
||||||
|
for ((r1, c1), (r2, c2)), rel in constraints.items():
|
||||||
|
v1, v2 = solution[r1][c1], solution[r2][c2]
|
||||||
|
if rel == "<" and not (v1 < v2):
|
||||||
|
return False
|
||||||
|
if rel == ">" and not (v1 > v2):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item = dataset[i]
|
||||||
|
metadata = item["metadata"]
|
||||||
|
solution = metadata["solution"]
|
||||||
|
constraints = metadata["constraints"]
|
||||||
|
|
||||||
|
assert is_valid_solution(solution, config.board_size, constraints)
|
||||||
|
|
||||||
|
|
||||||
|
def test_futoshiki_puzzle_solvability():
|
||||||
|
"""Test that generated puzzles are solvable and have unique solutions"""
|
||||||
|
config = FutoshikiConfig(board_size=4, difficulty=1, size=5, seed=42)
|
||||||
|
dataset = FutoshikiDataset(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item = dataset[i]
|
||||||
|
metadata = item["metadata"]
|
||||||
|
puzzle = metadata["puzzle"]
|
||||||
|
constraints = metadata["constraints"]
|
||||||
|
|
||||||
|
# Verify puzzle has exactly one solution
|
||||||
|
assert dataset.count_solutions(puzzle, constraints, limit=2) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_futoshiki_difficulty_levels():
|
||||||
|
"""Test that different difficulty levels affect puzzle complexity"""
|
||||||
|
size = 5
|
||||||
|
board_size = 4
|
||||||
|
seeds = [42, 43, 44] # Test multiple seeds for robustness
|
||||||
|
|
||||||
|
def count_clues(puzzle):
|
||||||
|
return sum(cell != 0 for row in puzzle for cell in row)
|
||||||
|
|
||||||
|
def count_constraints(constraints):
|
||||||
|
return len(constraints)
|
||||||
|
|
||||||
|
for seed in seeds:
|
||||||
|
clues_by_difficulty = []
|
||||||
|
constraints_by_difficulty = []
|
||||||
|
|
||||||
|
for difficulty in range(4): # 0 to 3
|
||||||
|
config = FutoshikiConfig(board_size=board_size, difficulty=difficulty, size=size, seed=seed)
|
||||||
|
dataset = FutoshikiDataset(config)
|
||||||
|
|
||||||
|
avg_clues = sum(count_clues(item["metadata"]["puzzle"]) for item in dataset) / size
|
||||||
|
avg_constraints = sum(count_constraints(item["metadata"]["constraints"]) for item in dataset) / size
|
||||||
|
|
||||||
|
clues_by_difficulty.append(avg_clues)
|
||||||
|
constraints_by_difficulty.append(avg_constraints)
|
||||||
|
|
||||||
|
# Higher difficulty should generally mean fewer clues and/or more constraints
|
||||||
|
assert all(clues_by_difficulty[i] >= clues_by_difficulty[i + 1] for i in range(len(clues_by_difficulty) - 1))
|
||||||
|
assert all(
|
||||||
|
constraints_by_difficulty[i] <= constraints_by_difficulty[i + 1]
|
||||||
|
for i in range(len(constraints_by_difficulty) - 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_futoshiki_answer_scoring():
|
||||||
|
"""Test the answer scoring mechanism"""
|
||||||
|
config = FutoshikiConfig(board_size=4, difficulty=0, size=5, seed=42)
|
||||||
|
dataset = FutoshikiDataset(config)
|
||||||
|
|
||||||
|
for item in dataset:
|
||||||
|
# Correct answer should score 1.0
|
||||||
|
assert dataset.score_answer(item["answer"], item) == 1.0
|
||||||
|
|
||||||
|
# Wrong answer should score lower
|
||||||
|
wrong_answer = item["answer"].replace("1", "2")
|
||||||
|
assert dataset.score_answer(wrong_answer, item) < 1.0
|
||||||
|
|
||||||
|
# None or empty answer should score 0.0
|
||||||
|
assert dataset.score_answer(None, item) == 0.0
|
||||||
|
assert dataset.score_answer("", item) == 0.0
|
||||||
|
|
||||||
|
answer = item["answer"]
|
||||||
|
white_space_mismatch = answer.replace(" ", " ")
|
||||||
|
assert dataset.score_answer(white_space_mismatch, item) == 0.9
|
||||||
|
|
||||||
|
anwser_with_additional_text = "This is an anwser " + answer + "\nwith surrounding text."
|
||||||
|
assert 0 < dataset.score_answer(anwser_with_additional_text, item) < 0.9
|
||||||
|
|
||||||
|
partially_correct = anwser_with_additional_text.replace("1", "2")
|
||||||
|
assert dataset.score_answer(partially_correct, item) > 0.1
|
||||||
|
|
||||||
|
bad_answer = "\n".join(anwser_with_additional_text.split("\n")[::-1])
|
||||||
|
assert dataset.score_answer(bad_answer, item) < 0.1
|
||||||
|
|
@ -122,6 +122,11 @@ def test_nqueens_score_answer():
|
||||||
# Test None answer gets score 0.0
|
# Test None answer gets score 0.0
|
||||||
assert dataset.score_answer(None, item) == 0.0
|
assert dataset.score_answer(None, item) == 0.0
|
||||||
|
|
||||||
|
# Test python list representation of board (partial solution)
|
||||||
|
answer = "[['_', 'Q', '_', '_'], ['_', '_', '_', 'Q'], ['Q', '_', '_', '_'], ['_', '_', 'Q', '_']]"
|
||||||
|
entry = {"metadata": {"valid_answers": {"_ Q _ _\n_ _ _ Q\nQ _ _ _\n_ _ Q _"}}}
|
||||||
|
assert dataset.score_answer(answer, entry) == 0.5
|
||||||
|
|
||||||
|
|
||||||
def is_valid_solution(board: list[list[str]]) -> bool:
|
def is_valid_solution(board: list[list[str]]) -> bool:
|
||||||
"""Helper function to verify N Queens solution validity"""
|
"""Helper function to verify N Queens solution validity"""
|
||||||
|
|
|
||||||
111
tests/test_palindrome_partitioning.py
Normal file
111
tests/test_palindrome_partitioning.py
Normal file
|
|
@ -0,0 +1,111 @@
|
||||||
|
"""Tests for Palindrome Partitioning questions generation"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
from reasoning_gym.algorithmic.palindrome_partitioning import (
|
||||||
|
PalindromePartitioningConfig,
|
||||||
|
PalindromePartitioningDataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_palindrome_partitioning_dataset_deterministic():
|
||||||
|
"""Test that dataset generates same items with same seed"""
|
||||||
|
config = PalindromePartitioningConfig(seed=42, size=10)
|
||||||
|
dataset1 = PalindromePartitioningDataset(config)
|
||||||
|
dataset2 = PalindromePartitioningDataset(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset1)):
|
||||||
|
assert dataset1[i] == dataset2[i]
|
||||||
|
|
||||||
|
|
||||||
|
def test_palindrome_partitioning_dataset_items():
|
||||||
|
"""Test basic properties of generated items"""
|
||||||
|
config = PalindromePartitioningConfig(size=10, seed=42)
|
||||||
|
dataset = PalindromePartitioningDataset(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item = dataset[i]
|
||||||
|
# Check item structure
|
||||||
|
assert isinstance(item, dict)
|
||||||
|
assert "question" in item
|
||||||
|
assert "answer" in item
|
||||||
|
assert "metadata" in item
|
||||||
|
|
||||||
|
# Check metadata
|
||||||
|
assert "string" in item["metadata"]
|
||||||
|
assert "solution" in item["metadata"]
|
||||||
|
string = item["metadata"]["string"]
|
||||||
|
solution = item["metadata"]["solution"]
|
||||||
|
|
||||||
|
# Verify string is not empty
|
||||||
|
assert len(string) > 0
|
||||||
|
|
||||||
|
# At least one partitioning exists (each letter is a palindrome)
|
||||||
|
assert len(solution) >= 1
|
||||||
|
|
||||||
|
# Verify each partitioning reconstructs the original string
|
||||||
|
assert all(len(partitioning) > 0 for partitioning in solution)
|
||||||
|
assert all("".join(partitioning) == string for partitioning in solution)
|
||||||
|
|
||||||
|
|
||||||
|
def test_palindrome_partitioning_dataset_iteration():
|
||||||
|
"""Test that iteration respects dataset size"""
|
||||||
|
config = PalindromePartitioningConfig(size=5, seed=42)
|
||||||
|
dataset = PalindromePartitioningDataset(config)
|
||||||
|
|
||||||
|
items = list(dataset)
|
||||||
|
assert len(items) == config.size
|
||||||
|
|
||||||
|
# Test multiple iterations yield same items
|
||||||
|
assert items == list(dataset)
|
||||||
|
|
||||||
|
|
||||||
|
def test_palindrome_partitioning_answer():
|
||||||
|
"""Test the _palindrome_partitioning method"""
|
||||||
|
config = PalindromePartitioningConfig(seed=42)
|
||||||
|
dataset = PalindromePartitioningDataset(config)
|
||||||
|
|
||||||
|
# General use case
|
||||||
|
word = "afternoon"
|
||||||
|
correct = [
|
||||||
|
["a", "f", "t", "e", "r", "n", "o", "o", "n"],
|
||||||
|
["a", "f", "t", "e", "r", "n", "oo", "n"],
|
||||||
|
["a", "f", "t", "e", "r", "noon"],
|
||||||
|
]
|
||||||
|
assert json.dumps(dataset._palindrome_partitioning(word)) == json.dumps(correct)
|
||||||
|
|
||||||
|
# Single letter word
|
||||||
|
word = "a"
|
||||||
|
correct = [["a"]]
|
||||||
|
assert json.dumps(dataset._palindrome_partitioning(word)) == json.dumps(correct)
|
||||||
|
|
||||||
|
# Empty string
|
||||||
|
word = ""
|
||||||
|
correct = []
|
||||||
|
assert json.dumps(dataset._palindrome_partitioning(word)) == json.dumps(correct)
|
||||||
|
|
||||||
|
|
||||||
|
def test_palindrome_partitioning_score_answer():
|
||||||
|
"""Test the score_answer method"""
|
||||||
|
config = PalindromePartitioningConfig(seed=42)
|
||||||
|
dataset = PalindromePartitioningDataset(config)
|
||||||
|
|
||||||
|
# Verify the scoring function is permutation invariant
|
||||||
|
answer = json.dumps([["n", "o", "o", "n"], ["no", "on"], ["noon"]])
|
||||||
|
item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}}
|
||||||
|
assert dataset.score_answer(answer, item) == 1
|
||||||
|
|
||||||
|
# Verify the score is 0.01 when incorrect
|
||||||
|
answer = json.dumps([["n", "o", "o", "n"], ["no", "on"]])
|
||||||
|
item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}}
|
||||||
|
assert dataset.score_answer(answer, item) == 0.01
|
||||||
|
|
||||||
|
# Verify the score is 0 when answer is None
|
||||||
|
answer = None
|
||||||
|
item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}}
|
||||||
|
assert dataset.score_answer(answer, item) == 0
|
||||||
|
|
||||||
|
# Verify the score is 0 when answer is malformed JSON
|
||||||
|
answer = '["n", "o", "o", "n"], ["no", "on"], ["noon"]'
|
||||||
|
item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}}
|
||||||
|
assert dataset.score_answer(answer, item) == 0
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
import string
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import sympy as sp
|
import sympy as sp
|
||||||
|
|
||||||
|
|
@ -17,7 +19,7 @@ def test_polynomial_config_validation():
|
||||||
PolynomialMultiplicationConfig(min_value=0).validate()
|
PolynomialMultiplicationConfig(min_value=0).validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
PolynomialMultiplicationConfig(min_degree=0, max_degree=3).validate()
|
PolynomialMultiplicationConfig(min_degree=-1, max_degree=3).validate()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
PolynomialMultiplicationConfig(min_degree=4, max_degree=3).validate()
|
PolynomialMultiplicationConfig(min_degree=4, max_degree=3).validate()
|
||||||
|
|
@ -28,6 +30,17 @@ def test_polynomial_config_validation():
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
PolynomialMultiplicationConfig(min_polynomials=5, max_polynomials=2).validate()
|
PolynomialMultiplicationConfig(min_polynomials=5, max_polynomials=2).validate()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
PolynomialMultiplicationConfig(variables="").validate()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
PolynomialMultiplicationConfig(
|
||||||
|
allow_cross_variable_product=False, allow_multivariate_polynomials=True
|
||||||
|
).validate()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
PolynomialMultiplicationConfig(min_polynomials=5, max_polynomials=2).validate()
|
||||||
|
|
||||||
|
|
||||||
def test_polynomial_multiplication_dataset_basic():
|
def test_polynomial_multiplication_dataset_basic():
|
||||||
"""Test dataset creation and length"""
|
"""Test dataset creation and length"""
|
||||||
|
|
@ -41,7 +54,9 @@ def test_polynomial_multiplication_dataset_basic():
|
||||||
max_degree=2,
|
max_degree=2,
|
||||||
min_polynomials=2,
|
min_polynomials=2,
|
||||||
max_polynomials=3,
|
max_polynomials=3,
|
||||||
single_variable=True,
|
variables=tuple(string.ascii_lowercase),
|
||||||
|
allow_cross_variable_product=False,
|
||||||
|
allow_multivariate_polynomials=False,
|
||||||
seed=42,
|
seed=42,
|
||||||
size=dataset_size,
|
size=dataset_size,
|
||||||
)
|
)
|
||||||
|
|
@ -63,7 +78,9 @@ def test_polynomial_equations_dataset_items():
|
||||||
max_degree=2,
|
max_degree=2,
|
||||||
min_polynomials=2,
|
min_polynomials=2,
|
||||||
max_polynomials=5,
|
max_polynomials=5,
|
||||||
single_variable=False,
|
variables=tuple("xyz"),
|
||||||
|
allow_cross_variable_product=False,
|
||||||
|
allow_multivariate_polynomials=False,
|
||||||
size=3,
|
size=3,
|
||||||
seed=100,
|
seed=100,
|
||||||
)
|
)
|
||||||
|
|
@ -75,7 +92,113 @@ def test_polynomial_equations_dataset_items():
|
||||||
|
|
||||||
# Check metadata
|
# Check metadata
|
||||||
assert isinstance(item["metadata"]["polynomial_expr"], str)
|
assert isinstance(item["metadata"]["polynomial_expr"], str)
|
||||||
assert isinstance(item["metadata"]["single_variable"], bool)
|
assert isinstance(item["metadata"]["result"], str)
|
||||||
|
assert isinstance(item["metadata"]["variables"], list)
|
||||||
|
|
||||||
|
# Check polynomial_expr existence
|
||||||
|
poly_str = item["metadata"]["polynomial_expr"]
|
||||||
|
# Ensure it can parse with sympy
|
||||||
|
sp.sympify(poly_str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cross_polynomial_equations_dataset_items():
|
||||||
|
"""Test that generated items have correct structure"""
|
||||||
|
ds = create_dataset(
|
||||||
|
"polynomial_multiplication",
|
||||||
|
min_terms=2,
|
||||||
|
max_terms=3,
|
||||||
|
min_value=1,
|
||||||
|
max_value=5,
|
||||||
|
min_degree=1,
|
||||||
|
max_degree=2,
|
||||||
|
min_polynomials=2,
|
||||||
|
max_polynomials=5,
|
||||||
|
variables=tuple("xyz"),
|
||||||
|
allow_cross_variable_product=True,
|
||||||
|
allow_multivariate_polynomials=False,
|
||||||
|
size=3,
|
||||||
|
seed=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
for item in ds:
|
||||||
|
assert "question" in item
|
||||||
|
assert "answer" in item
|
||||||
|
assert "metadata" in item
|
||||||
|
|
||||||
|
# Check metadata
|
||||||
|
assert isinstance(item["metadata"]["polynomial_expr"], str)
|
||||||
|
assert isinstance(item["metadata"]["result"], str)
|
||||||
|
assert isinstance(item["metadata"]["variables"], list)
|
||||||
|
|
||||||
|
# Check polynomial_expr existence
|
||||||
|
poly_str = item["metadata"]["polynomial_expr"]
|
||||||
|
# Ensure it can parse with sympy
|
||||||
|
sp.sympify(poly_str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cross_polynomial_equations_dataset_items():
|
||||||
|
"""Test that generated items have correct structure"""
|
||||||
|
ds = create_dataset(
|
||||||
|
"polynomial_multiplication",
|
||||||
|
min_terms=2,
|
||||||
|
max_terms=3,
|
||||||
|
min_value=1,
|
||||||
|
max_value=5,
|
||||||
|
min_degree=1,
|
||||||
|
max_degree=2,
|
||||||
|
min_polynomials=2,
|
||||||
|
max_polynomials=5,
|
||||||
|
variables=tuple("xyz"),
|
||||||
|
allow_cross_variable_product=True,
|
||||||
|
allow_multivariate_polynomials=False,
|
||||||
|
size=3,
|
||||||
|
seed=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
for item in ds:
|
||||||
|
assert "question" in item
|
||||||
|
assert "answer" in item
|
||||||
|
assert "metadata" in item
|
||||||
|
|
||||||
|
# Check metadata
|
||||||
|
assert isinstance(item["metadata"]["polynomial_expr"], str)
|
||||||
|
assert isinstance(item["metadata"]["result"], str)
|
||||||
|
assert isinstance(item["metadata"]["variables"], list)
|
||||||
|
|
||||||
|
# Check polynomial_expr existence
|
||||||
|
poly_str = item["metadata"]["polynomial_expr"]
|
||||||
|
# Ensure it can parse with sympy
|
||||||
|
sp.sympify(poly_str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multivariate_polynomial_equations_dataset_items():
|
||||||
|
"""Test that generated items have correct structure"""
|
||||||
|
ds = create_dataset(
|
||||||
|
"polynomial_multiplication",
|
||||||
|
min_terms=2,
|
||||||
|
max_terms=3,
|
||||||
|
min_value=1,
|
||||||
|
max_value=5,
|
||||||
|
min_degree=1,
|
||||||
|
max_degree=2,
|
||||||
|
min_polynomials=2,
|
||||||
|
max_polynomials=5,
|
||||||
|
variables=tuple(["x", "y"]),
|
||||||
|
allow_cross_variable_product=True,
|
||||||
|
allow_multivariate_polynomials=True,
|
||||||
|
size=3,
|
||||||
|
seed=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
for item in ds:
|
||||||
|
assert "question" in item
|
||||||
|
assert "answer" in item
|
||||||
|
assert "metadata" in item
|
||||||
|
|
||||||
|
# Check metadata
|
||||||
|
assert isinstance(item["metadata"]["polynomial_expr"], str)
|
||||||
|
assert isinstance(item["metadata"]["result"], str)
|
||||||
|
assert isinstance(item["metadata"]["variables"], list)
|
||||||
|
|
||||||
# Check polynomial_expr existence
|
# Check polynomial_expr existence
|
||||||
poly_str = item["metadata"]["polynomial_expr"]
|
poly_str = item["metadata"]["polynomial_expr"]
|
||||||
|
|
@ -105,7 +228,9 @@ def test_polynomial_solutions_evaluation():
|
||||||
max_degree=3,
|
max_degree=3,
|
||||||
min_polynomials=2,
|
min_polynomials=2,
|
||||||
max_polynomials=5,
|
max_polynomials=5,
|
||||||
single_variable=False,
|
variables=tuple(["x", "y"]),
|
||||||
|
allow_cross_variable_product=True,
|
||||||
|
allow_multivariate_polynomials=True,
|
||||||
size=5,
|
size=5,
|
||||||
seed=42,
|
seed=42,
|
||||||
)
|
)
|
||||||
|
|
@ -125,42 +250,27 @@ def test_score_function():
|
||||||
ds = create_dataset(
|
ds = create_dataset(
|
||||||
"polynomial_multiplication",
|
"polynomial_multiplication",
|
||||||
min_terms=2,
|
min_terms=2,
|
||||||
max_terms=4,
|
max_terms=3,
|
||||||
min_value=1,
|
min_value=1,
|
||||||
max_value=10,
|
max_value=3,
|
||||||
min_degree=1,
|
min_degree=1,
|
||||||
max_degree=3,
|
max_degree=3,
|
||||||
min_polynomials=2,
|
min_polynomials=3,
|
||||||
max_polynomials=5,
|
max_polynomials=3,
|
||||||
single_variable=True,
|
variables=tuple(["x", "y"]),
|
||||||
size=1,
|
allow_cross_variable_product=True,
|
||||||
|
allow_multivariate_polynomials=True,
|
||||||
|
size=3,
|
||||||
seed=42,
|
seed=42,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert ds.score_answer(None, ds[0]) == 0.00
|
for item in ds:
|
||||||
assert ds.score_answer("6*x**4 + 9*x**3 - 6*x**2 - 39*x - 45", ds[0]) == 1
|
poly_str = item["metadata"]["polynomial_expr"]
|
||||||
assert ds.score_answer("Not a polynomial", ds[0]) == 0.01
|
assert ds.score_answer(poly_str, item) == 0.05
|
||||||
assert ds.score_answer("x**4", ds[0]) == 0.05
|
|
||||||
|
|
||||||
|
poly_expr = str(sp.expand(poly_str))
|
||||||
|
assert ds.score_answer(poly_expr, item) == 1.0
|
||||||
|
|
||||||
def test_multivariate_score_function():
|
assert ds.score_answer(None, item) == 0.00
|
||||||
"""Test that solution satisfy the polynomial multiplication."""
|
assert ds.score_answer("Not a polynomial", item) == 0.01
|
||||||
ds = create_dataset(
|
assert ds.score_answer("x**4", item) == 0.05
|
||||||
"polynomial_multiplication",
|
|
||||||
min_terms=2,
|
|
||||||
max_terms=4,
|
|
||||||
min_value=1,
|
|
||||||
max_value=10,
|
|
||||||
min_degree=1,
|
|
||||||
max_degree=3,
|
|
||||||
min_polynomials=2,
|
|
||||||
max_polynomials=5,
|
|
||||||
single_variable=False,
|
|
||||||
size=1,
|
|
||||||
seed=42,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert ds.score_answer(None, ds[0]) == 0.00
|
|
||||||
assert ds.score_answer("-27*a**3*c - 27*a**3 + 144*a*c + 144*a", ds[0]) == 1
|
|
||||||
assert ds.score_answer("Not a polynomial", ds[0]) == 0.01
|
|
||||||
assert ds.score_answer("x**4", ds[0]) == 0.05
|
|
||||||
|
|
|
||||||
|
|
@ -37,6 +37,7 @@ def test_getitem(dataset, config):
|
||||||
assert "metadata" in item
|
assert "metadata" in item
|
||||||
assert item["metadata"]["word_count"] >= config.min_words_in_sentence
|
assert item["metadata"]["word_count"] >= config.min_words_in_sentence
|
||||||
assert item["metadata"]["word_count"] <= config.max_words_in_sentence
|
assert item["metadata"]["word_count"] <= config.max_words_in_sentence
|
||||||
|
assert len(item["answer"].split()) == item["metadata"]["word_count"]
|
||||||
|
|
||||||
|
|
||||||
def test_key_error_in_getitem(dataset):
|
def test_key_error_in_getitem(dataset):
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,10 @@ def test_spiral_matrix_config_validation():
|
||||||
config = SpiralMatrixConfig(max_n=0) # Zero not allowed
|
config = SpiralMatrixConfig(max_n=0) # Zero not allowed
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = SpiralMatrixConfig(max_n=1) # One not allowed
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
|
||||||
def test_spiral_matrix_dataset_deterministic():
|
def test_spiral_matrix_dataset_deterministic():
|
||||||
"""Test that dataset generates same items with same seed"""
|
"""Test that dataset generates same items with same seed"""
|
||||||
|
|
@ -69,18 +73,26 @@ def test_spiral_matrix_answer():
|
||||||
config = SpiralMatrixConfig(seed=42)
|
config = SpiralMatrixConfig(seed=42)
|
||||||
dataset = SpiralMatrixDataset(config)
|
dataset = SpiralMatrixDataset(config)
|
||||||
|
|
||||||
# One element
|
|
||||||
matrix = [[0]]
|
|
||||||
assert dataset._get_spiral(matrix) == [0]
|
|
||||||
|
|
||||||
# One row
|
|
||||||
matrix = [[0, 1, 2]]
|
|
||||||
assert dataset._get_spiral(matrix) == [0, 1, 2]
|
|
||||||
|
|
||||||
# One column
|
|
||||||
matrix = [[0], [1], [2]]
|
|
||||||
assert dataset._get_spiral(matrix) == [0, 1, 2]
|
|
||||||
|
|
||||||
# 2D grid
|
# 2D grid
|
||||||
matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
|
matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
|
||||||
assert dataset._get_spiral(matrix) == [1, 2, 3, 6, 9, 8, 7, 4, 5]
|
assert dataset._get_spiral(matrix) == [1, 2, 3, 6, 9, 8, 7, 4, 5]
|
||||||
|
|
||||||
|
# Answer is identical (up to trimming)
|
||||||
|
entry = {"answer": "1 2 3 6 9 8 7 4 5"}
|
||||||
|
answer = "\n\n1 2 3 6 9 8 7 4 5\n"
|
||||||
|
assert dataset.score_answer(answer, entry) == 1.0
|
||||||
|
|
||||||
|
# Score answer in list format (partially correct)
|
||||||
|
entry = {"answer": "1 2 3 6 9 8 7 4 5"}
|
||||||
|
answer = "[1, 2, 3, 6, 9, 8, 7, 4, 5]"
|
||||||
|
assert dataset.score_answer(answer, entry) == 0.5
|
||||||
|
|
||||||
|
# Answer is incorrect
|
||||||
|
entry = {"answer": "1 2 3 6 9 8 7 4 5"}
|
||||||
|
answer = "1 2 3"
|
||||||
|
assert dataset.score_answer(answer, entry) == 0.01
|
||||||
|
|
||||||
|
# Answer is none
|
||||||
|
entry = {"answer": "1 2 3 6 9 8 7 4 5"}
|
||||||
|
answer = None
|
||||||
|
assert dataset.score_answer(answer, entry) == 0.0
|
||||||
|
|
|
||||||
|
|
@ -92,3 +92,13 @@ def test_string_insertion_answer():
|
||||||
|
|
||||||
# No reuse of newly inserted characters
|
# No reuse of newly inserted characters
|
||||||
assert dataset._get_answer("ABCDBCD") == "ABCDABCD"
|
assert dataset._get_answer("ABCDBCD") == "ABCDABCD"
|
||||||
|
|
||||||
|
# Test score_answer with correct answer
|
||||||
|
answer = "AABCDAEEEEEEEBCDEBAAAAA"
|
||||||
|
entry = {"answer": "AABCDAEEEEEEEBCDEBAAAAA"}
|
||||||
|
assert dataset.score_answer(answer, entry) == 1.0
|
||||||
|
|
||||||
|
# Test score_answer with correct answer as python list of characters (partial correct)
|
||||||
|
answer = "['A', 'A', 'B', 'C', 'D', 'A', 'E', 'E', 'E', 'E', 'E', 'E', 'E', 'B', 'C', 'D', 'E', 'B', 'A', 'A', 'A', 'A', 'A']"
|
||||||
|
entry = {"answer": "AABCDAEEEEEEEBCDEBAAAAA"}
|
||||||
|
assert dataset.score_answer(answer, entry) == 0.5
|
||||||
|
|
|
||||||
|
|
@ -116,3 +116,35 @@ def test_word_sorting_dataset_iteration():
|
||||||
|
|
||||||
# Test multiple iterations yield same items
|
# Test multiple iterations yield same items
|
||||||
assert items == list(dataset)
|
assert items == list(dataset)
|
||||||
|
|
||||||
|
|
||||||
|
def test_word_sorting_scoring():
|
||||||
|
"""Test scoring function"""
|
||||||
|
config = WordSortingConfig(size=1, seed=42)
|
||||||
|
dataset = WordSortingDataset(config)
|
||||||
|
|
||||||
|
item = {
|
||||||
|
"metadata": {
|
||||||
|
"sorted_words": ["apple", "banana", "cherry"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Correct answer
|
||||||
|
answer = "apple, banana, cherry"
|
||||||
|
assert dataset.score_answer(answer, item) == 1.0
|
||||||
|
|
||||||
|
# Correct answer, with incorrect spaces
|
||||||
|
answer = "apple,banana, cherry"
|
||||||
|
assert dataset.score_answer(answer, item) == 1.0
|
||||||
|
|
||||||
|
# All words present, but not sorted
|
||||||
|
answer = "banana, cherry, apple"
|
||||||
|
assert dataset.score_answer(answer, item) == 0.2
|
||||||
|
|
||||||
|
# Garbage
|
||||||
|
answer = "gibberish"
|
||||||
|
assert dataset.score_answer(answer, item) == 0.01
|
||||||
|
|
||||||
|
# Empty answer
|
||||||
|
answer = None
|
||||||
|
assert dataset.score_answer(answer, item) == 0.0
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue