mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-29 17:35:16 +00:00
feat: Add scoring method & unit tests for circuit logic dataset
This commit is contained in:
parent
719369bce6
commit
63bd662acf
3 changed files with 246 additions and 8 deletions
|
|
@ -21,6 +21,7 @@ __all__ = [
|
||||||
"ZebraConfig",
|
"ZebraConfig",
|
||||||
"ZebraDataset",
|
"ZebraDataset",
|
||||||
"SelfReference",
|
"SelfReference",
|
||||||
|
"SelfReferenceConfig",
|
||||||
"SelfReferenceDataset",
|
"SelfReferenceDataset",
|
||||||
"CircuitLogicConfig",
|
"CircuitLogicConfig",
|
||||||
"CircuitLogicDataset",
|
"CircuitLogicDataset",
|
||||||
|
|
|
||||||
|
|
@ -175,7 +175,7 @@ class CircuitLogicDataset(ProceduralDataset):
|
||||||
term_str = op[1].join(parts)
|
term_str = op[1].join(parts)
|
||||||
term_strings.append(term_str)
|
term_strings.append(term_str)
|
||||||
|
|
||||||
expression_for_display = final_gate_sym.join(term_strings)
|
expression_for_display = final_gate_sym.join(f"({t})" for t in term_strings)
|
||||||
# use || separator internally that doesn't clash with other symbols...
|
# use || separator internally that doesn't clash with other symbols...
|
||||||
separator = "||"
|
separator = "||"
|
||||||
expression_for_internal = separator.join(term_strings)
|
expression_for_internal = separator.join(term_strings)
|
||||||
|
|
@ -351,19 +351,17 @@ class CircuitLogicDataset(ProceduralDataset):
|
||||||
term_val = 0
|
term_val = 0
|
||||||
term_values.append(term_val)
|
term_values.append(term_val)
|
||||||
|
|
||||||
|
# Evaluate final gate based on term values
|
||||||
if final_gate_name == "OR":
|
if final_gate_name == "OR":
|
||||||
final_result = 1 if any(v == 1 for v in term_values) else 0
|
final_result = 1 if any(v == 1 for v in term_values) else 0
|
||||||
elif final_gate_name == "NOR":
|
elif final_gate_name == "NOR":
|
||||||
final_result = 0 if any(v == 1 for v in term_values) else 1
|
final_result = 0 if any(v == 1 for v in term_values) else 1
|
||||||
elif final_gate_name == "XOR":
|
elif final_gate_name == "XOR":
|
||||||
tmp = 0
|
final_result = sum(term_values) % 2
|
||||||
for v in term_values:
|
|
||||||
tmp ^= v
|
|
||||||
final_result = tmp
|
|
||||||
elif final_gate_name == "AND":
|
elif final_gate_name == "AND":
|
||||||
final_result = 1 if all(v == 1 for v in term_values) else 0
|
final_result = 1 if all(v == 1 for v in term_values) else 0
|
||||||
else:
|
else:
|
||||||
final_result = 0
|
raise ValueError(f"Unknown gate type: {final_gate_name}")
|
||||||
|
|
||||||
lines = []
|
lines = []
|
||||||
lines.append("Below is a randomly generated logic circuit.\n")
|
lines.append("Below is a randomly generated logic circuit.\n")
|
||||||
|
|
@ -373,7 +371,10 @@ class CircuitLogicDataset(ProceduralDataset):
|
||||||
legend_lines.append("Legend for gates:")
|
legend_lines.append("Legend for gates:")
|
||||||
for op_name, _, draw_sym in self.internal_ops:
|
for op_name, _, draw_sym in self.internal_ops:
|
||||||
legend_lines.append(f"{draw_sym*2}: {op_name}")
|
legend_lines.append(f"{draw_sym*2}: {op_name}")
|
||||||
legend_lines.append(f"{final_gate_sym*2}: {final_gate_name}")
|
if neg_prob > 0:
|
||||||
|
legend_lines.append(f">o: Negate")
|
||||||
|
if final_gate_sym not in self.internal_ops:
|
||||||
|
legend_lines.append(f"{final_gate_sym*2}: {final_gate_name}")
|
||||||
legend_str = "\n".join(legend_lines)
|
legend_str = "\n".join(legend_lines)
|
||||||
|
|
||||||
lines.append(legend_str)
|
lines.append(legend_str)
|
||||||
|
|
@ -393,11 +394,23 @@ class CircuitLogicDataset(ProceduralDataset):
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"expression": expression_for_display,
|
"expression": expression_for_display,
|
||||||
"assignments": assignments,
|
"assignments": assignments,
|
||||||
|
"term_strings": term_strings,
|
||||||
"final_gate": final_gate_name,
|
"final_gate": final_gate_name,
|
||||||
"inputs": inputs_list,
|
"inputs": inputs_list,
|
||||||
"legend": legend_str,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
    """Score a candidate answer against the oracle answer stored in ``entry``.

    Args:
        answer: The candidate answer text, or ``None`` when no answer was given.
        entry: A dataset item; only ``entry["answer"]`` is read.

    Returns:
        ``0.0`` for a missing or empty answer, ``1.0`` for an exact match,
        ``len(oracle) / len(answer)`` when the answer matches only after
        stripping surrounding whitespace, and ``0.01`` for anything else.
    """
    # None and "" are both falsy, so one truthiness check covers both cases
    if not answer:
        return 0.0

    oracle_answer = entry["answer"]
    if answer == oracle_answer:
        return 1.0
    if answer.strip() == oracle_answer:
        # Correct value padded with whitespace: credit shrinks as padding grows
        return len(oracle_answer) / len(answer)

    # Non-empty answer that does not match the oracle
    return 0.01
|
||||||
|
|
||||||
|
|
||||||
register_dataset("circuit_logic", CircuitLogicDataset, CircuitLogicConfig)
|
register_dataset("circuit_logic", CircuitLogicDataset, CircuitLogicConfig)
|
||||||
|
|
|
||||||
224
tests/test_circuit_logic.py
Normal file
224
tests/test_circuit_logic.py
Normal file
|
|
@ -0,0 +1,224 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reasoning_gym.logic import CircuitLogicConfig, CircuitLogicDataset
|
||||||
|
|
||||||
|
|
||||||
|
def test_circuit_logic_config_validation():
    """Test that invalid configs raise appropriate errors"""
    # Each entry is one set of constructor arguments that validation must reject
    rejected_kwargs = (
        {"min_inputs": 3, "max_inputs": 2},
        {"num_terms": 0},
        {"neg_prob": -0.1},
        {"neg_prob": 1.1},
    )

    for kwargs in rejected_kwargs:
        with pytest.raises(AssertionError):
            config = CircuitLogicConfig(**kwargs)
            config.validate()
|
||||||
|
|
||||||
|
|
||||||
|
def test_circuit_logic_deterministic():
    """Test that dataset generates same items with same seed"""
    shared_config = CircuitLogicConfig(seed=42, size=10)
    dataset1 = CircuitLogicDataset(shared_config)
    dataset2 = CircuitLogicDataset(shared_config)

    # Compare the two datasets position by position
    idx = 0
    while idx < len(dataset1):
        assert dataset1[idx] == dataset2[idx]
        idx += 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_circuit_logic_items():
    """Test basic properties of generated items"""
    config = CircuitLogicConfig(num_terms=3, min_inputs=2, max_inputs=3, neg_prob=0.3, size=50, seed=42)
    dataset = CircuitLogicDataset(config)

    for idx in range(len(dataset)):
        sample = dataset[idx]
        assert isinstance(sample, dict)
        for top_level_key in ("question", "answer", "metadata"):
            assert top_level_key in sample

        # Verify metadata contents
        meta = sample["metadata"]
        for meta_key in ("expression", "assignments", "final_gate", "inputs"):
            assert meta_key in meta

        # Verify answer is binary
        assert sample["answer"] in ("0", "1")

        # Verify assignments are binary
        for bit in meta["assignments"].values():
            assert bit in (0, 1)

        # Verify final gate is valid
        assert meta["final_gate"] in ("OR", "NOR", "XOR", "AND")

        # Verify inputs list matches assignments
        assert set(meta["inputs"]) == set(meta["assignments"].keys())
|
||||||
|
|
||||||
|
|
||||||
|
def test_circuit_logic_expression_validity():
    """Test that generated expressions follow logical circuit rules"""
    # Negation is disabled (neg_prob=0.0) for simpler testing
    config = CircuitLogicConfig(
        num_terms=2,
        min_inputs=2,
        max_inputs=2,
        neg_prob=0.0,
        size=20,
        seed=42,
    )
    dataset = CircuitLogicDataset(config)
    operator_symbols = ("&", "↑", "⊕", "+", "↓")

    for idx in range(len(dataset)):
        meta = dataset[idx]["metadata"]

        # Expression should contain valid operators
        expression = meta["expression"]
        assert any(symbol in expression for symbol in operator_symbols)

        # Input names should be valid Excel-style names
        for name in meta["inputs"]:
            assert name.isalpha()
            assert name.isupper()
|
||||||
|
|
||||||
|
|
||||||
|
def test_circuit_logic_answer_verification():
    """Test that answers match logical evaluation of circuits"""
    config = CircuitLogicConfig(num_terms=2, min_inputs=2, max_inputs=2, size=20, seed=42)
    dataset = CircuitLogicDataset(config)

    def literal_value(literal: str, assignments: dict) -> int:
        """Look up one literal; a trailing apostrophe inverts the assigned bit"""
        if literal.endswith("'"):
            return 1 - assignments[literal[:-1]]
        return assignments[literal]

    def evaluate_term(term: str, assignments: dict) -> int:
        """Evaluate a single term with given assignments"""
        # Probe operators in a fixed order (NAND, AND, XOR); the first symbol found wins
        for symbol in ("↑", "&", "⊕"):
            if symbol in term:
                break
        else:
            raise ValueError(f"Unknown operator in term: {term}")

        values = [literal_value(part, assignments) for part in term.split(symbol)]
        if symbol == "↑":  # NAND
            return 0 if all(v == 1 for v in values) else 1
        if symbol == "&":  # AND
            return 1 if all(v == 1 for v in values) else 0
        return sum(values) % 2  # XOR

    def evaluate_final_gate(gate_type: str, term_values: list) -> int:
        """Evaluate the final gate with given term values"""
        if gate_type == "AND":
            return 1 if all(v == 1 for v in term_values) else 0
        if gate_type == "OR":
            return 1 if any(v == 1 for v in term_values) else 0
        if gate_type == "XOR":
            return sum(term_values) % 2
        if gate_type == "NOR":
            return 0 if any(v == 1 for v in term_values) else 1
        raise ValueError(f"Unknown gate type: {gate_type}")

    for i in range(len(dataset)):
        item = dataset[i]
        metadata = item["metadata"]
        assignments = metadata["assignments"]
        final_gate = metadata["final_gate"]
        term_strings = metadata["term_strings"]

        # First evaluate each term
        term_values = [evaluate_term(term, assignments) for term in term_strings]

        # Then combine terms with final gate
        expected = evaluate_final_gate(final_gate, term_values)

        # Compare with actual result
        result = int(item["answer"])
        assert (
            result == expected
        ), f"Item {i}: Expected {expected} but got {result} for terms {term_strings} with assignments {assignments} and final gate {final_gate}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_circuit_logic_ascii_diagram():
    """Test properties of the ASCII circuit diagram"""
    config = CircuitLogicConfig(num_terms=2, min_inputs=2, max_inputs=2, size=10, seed=42)
    dataset = CircuitLogicDataset(config)
    header = "Below is a randomly generated logic circuit."

    for idx in range(len(dataset)):
        sample = dataset[idx]

        # Split question to get diagram: it starts two lines after the header
        # and runs up to the next empty line
        question_lines = sample["question"].split("\n")
        first_row = question_lines.index(header) + 2
        past_last_row = question_lines.index("", first_row)
        diagram = question_lines[first_row:past_last_row]

        # Basic diagram validation
        assert len(diagram) > 0
        assert all(len(row) > 0 for row in diagram)

        # Check for required circuit elements
        diagram_str = "\n".join(diagram)
        assert "OUT" in diagram_str
        assert any(gate in diagram_str for gate in ("&", "↑", "⊕"))

        # Verify input labels
        for name in sample["metadata"]["inputs"]:
            assert f"{name}:" in diagram_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_circuit_logic_scoring():
    """Test the answer scoring mechanism"""
    config = CircuitLogicConfig(size=5, seed=42)
    dataset = CircuitLogicDataset(config)

    item = dataset[0]

    # Correct answer should score 1.0
    assert dataset.score_answer(item["answer"], item) == 1.0

    # Wrong answer should score lower
    wrong_answer = "1" if item["answer"] == "0" else "0"
    assert dataset.score_answer(wrong_answer, item) < 1.0

    # Correct value padded with whitespace earns partial credit, strictly between
    # a miss and an exact match
    padded_score = dataset.score_answer(item["answer"] + " ", item)
    assert dataset.score_answer(wrong_answer, item) < padded_score < 1.0

    # None or empty answer should score 0.0
    assert dataset.score_answer(None, item) == 0.0
    assert dataset.score_answer("", item) == 0.0  # Empty string should score 0.0 like None
|
||||||
|
|
||||||
|
|
||||||
|
def test_circuit_logic_iteration():
    """Test that iteration works correctly"""
    config = CircuitLogicConfig(size=5, seed=42)
    dataset = CircuitLogicDataset(config)

    # Test manual iteration
    collected = []
    for entry in dataset:
        collected.append(entry)
    assert len(collected) == config.size

    # Test list conversion
    collected = list(dataset)
    assert len(collected) == config.size

    # Test multiple iterations yield same items
    first_pass = list(dataset)
    second_pass = list(dataset)
    assert first_pass == second_pass
|
||||||
Loading…
Add table
Add a link
Reference in a new issue