diff --git a/reasoning_gym/logic/zebra_puzzles.py b/reasoning_gym/logic/zebra_puzzles.py index 2ba58994..3ba177a7 100644 --- a/reasoning_gym/logic/zebra_puzzles.py +++ b/reasoning_gym/logic/zebra_puzzles.py @@ -10,25 +10,21 @@ from .contrib.logic_puzzle.generate import generate_puzzle class ZebraConfig: """Configuration for zebra puzzle generation""" - k: int = 4 - m: int = 4 + num_people: int = 4 + num_characteristics: int = 4 seed: Optional[int] = None size: int = 500 def validate(self): """Validate configuration parameters""" - assert 2 <= self.k <= 7, "k must be between 2 and 7" - assert 2 <= self.m <= 7, "m must be between 2 and 7" + assert 2 <= self.num_people <= 7, "num_people must be between 2 and 7" + assert 2 <= self.num_characteristics <= 7, "num_characteristics must be between 2 and 7" class ZebraDataset(ProceduralDataset): - """Generates Game of Life games with configurable parameters""" + """Generates [Zebra Puzzles](https://en.wikipedia.org/wiki/Zebra_Puzzle) with configurable parameters""" def __init__(self, config: ZebraConfig): - self._prompt_templates = [ - "What will this Game of Life board look like after {simulation_steps} steps of simulation?\n\n{board}" - ] - super().__init__(config=config, seed=config.seed, size=config.size) def __getitem__(self, idx: int) -> dict: @@ -42,17 +38,20 @@ class ZebraDataset(ProceduralDataset): """ seed(self.seed + idx) - K = self.config.k - M = self.config.m + K = self.config.num_people + M = self.config.num_characteristics instance, puzzle = generate_puzzle(K, M, "train") q = instance["questions"][0]["question"] - a = instance["questions"][0]["answer"] + answer = instance["questions"][0]["answer"] question = str(puzzle) + "\n" + q return { "question": question, - "answer": a, - "metadata": {"K": K, "M": M, "answer": a}, + "answer": answer, + "metadata": { + "num_people": K, + "num_characteristics": M, + }, } def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: diff --git a/tests/test_zebra.py b/tests/test_zebra.py index 339ff3bc..d233c438 100644 --- a/tests/test_zebra.py +++ b/tests/test_zebra.py @@ -6,7 +6,7 @@ from reasoning_gym.logic.zebra_puzzles import ZebraConfig, ZebraDataset def test_zebra_puzzles(): """Test basic properties and solution of generated items""" - config = ZebraConfig(seed=42, size=10, k=4, m=4) + config = ZebraConfig(seed=42, size=10, num_people=4, num_characteristics=4) dataset = ZebraDataset(config) for item in dataset: