This commit is contained in:
Rich Jones 2025-02-03 16:47:29 +01:00
parent 7274f79c50
commit 4d950e562a
2 changed files with 14 additions and 15 deletions

View file

@ -10,25 +10,21 @@ from .contrib.logic_puzzle.generate import generate_puzzle
class ZebraConfig: class ZebraConfig:
"""Configuration for zebra puzzle generation""" """Configuration for zebra puzzle generation"""
k: int = 4 num_people: int = 4
m: int = 4 num_characteristics: int = 4
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 size: int = 500
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
assert 2 <= self.k <= 7, "k must be between 2 and 7" assert 2 <= self.num_people <= 7, "num_people must be between 2 and 7"
assert 2 <= self.m <= 7, "m must be between 2 and 7" assert 2 <= self.num_characteristics <= 7, "num_characteristics must be between 2 and 7"
class ZebraDataset(ProceduralDataset): class ZebraDataset(ProceduralDataset):
"""Generates Game of Life games with configurable parameters""" """Generates [Zebra Puzzles](https://en.wikipedia.org/wiki/Zebra_Puzzle) with configurable parameters"""
def __init__(self, config: ZebraConfig): def __init__(self, config: ZebraConfig):
self._prompt_templates = [
"What will this Game of Life board look like after {simulation_steps} steps of simulation?\n\n{board}"
]
super().__init__(config=config, seed=config.seed, size=config.size) super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
@ -42,17 +38,20 @@ class ZebraDataset(ProceduralDataset):
""" """
seed(self.seed + idx) seed(self.seed + idx)
K = self.config.k K = self.config.num_people
M = self.config.m M = self.config.num_characteristics
instance, puzzle = generate_puzzle(K, M, "train") instance, puzzle = generate_puzzle(K, M, "train")
q = instance["questions"][0]["question"] q = instance["questions"][0]["question"]
a = instance["questions"][0]["answer"] answer = instance["questions"][0]["answer"]
question = str(puzzle) + "\n" + q question = str(puzzle) + "\n" + q
return { return {
"question": question, "question": question,
"answer": a, "answer": answer,
"metadata": {"K": K, "M": M, "answer": a}, "metadata": {
"num_people": K,
"num_characteristics": M,
},
} }
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:

View file

@ -6,7 +6,7 @@ from reasoning_gym.logic.zebra_puzzles import ZebraConfig, ZebraDataset
def test_zebra_puzzles(): def test_zebra_puzzles():
"""Test basic properties and solution of generated items""" """Test basic properties and solution of generated items"""
config = ZebraConfig(seed=42, size=10, k=4, m=4) config = ZebraConfig(seed=42, size=10, num_people=4, num_characteristics=4)
dataset = ZebraDataset(config) dataset = ZebraDataset(config)
for item in dataset: for item in dataset: