mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
feat: Add scoring method & unit tests for circuit logic dataset
This commit is contained in:
parent
719369bce6
commit
63bd662acf
3 changed files with 246 additions and 8 deletions
|
|
@ -175,7 +175,7 @@ class CircuitLogicDataset(ProceduralDataset):
|
|||
term_str = op[1].join(parts)
|
||||
term_strings.append(term_str)
|
||||
|
||||
expression_for_display = final_gate_sym.join(term_strings)
|
||||
expression_for_display = final_gate_sym.join(f"({t})" for t in term_strings)
|
||||
# use || separator internally that doesn't clash with other symbols...
|
||||
separator = "||"
|
||||
expression_for_internal = separator.join(term_strings)
|
||||
|
|
@ -351,19 +351,17 @@ class CircuitLogicDataset(ProceduralDataset):
|
|||
term_val = 0
|
||||
term_values.append(term_val)
|
||||
|
||||
# Evaluate final gate based on term values
|
||||
if final_gate_name == "OR":
|
||||
final_result = 1 if any(v == 1 for v in term_values) else 0
|
||||
elif final_gate_name == "NOR":
|
||||
final_result = 0 if any(v == 1 for v in term_values) else 1
|
||||
elif final_gate_name == "XOR":
|
||||
tmp = 0
|
||||
for v in term_values:
|
||||
tmp ^= v
|
||||
final_result = tmp
|
||||
final_result = sum(term_values) % 2
|
||||
elif final_gate_name == "AND":
|
||||
final_result = 1 if all(v == 1 for v in term_values) else 0
|
||||
else:
|
||||
final_result = 0
|
||||
raise ValueError(f"Unknown gate type: {final_gate_name}")
|
||||
|
||||
lines = []
|
||||
lines.append("Below is a randomly generated logic circuit.\n")
|
||||
|
|
@ -373,7 +371,10 @@ class CircuitLogicDataset(ProceduralDataset):
|
|||
legend_lines.append("Legend for gates:")
|
||||
for op_name, _, draw_sym in self.internal_ops:
|
||||
legend_lines.append(f"{draw_sym*2}: {op_name}")
|
||||
legend_lines.append(f"{final_gate_sym*2}: {final_gate_name}")
|
||||
if neg_prob > 0:
|
||||
legend_lines.append(f">o: Negate")
|
||||
if final_gate_sym not in self.internal_ops:
|
||||
legend_lines.append(f"{final_gate_sym*2}: {final_gate_name}")
|
||||
legend_str = "\n".join(legend_lines)
|
||||
|
||||
lines.append(legend_str)
|
||||
|
|
@ -393,11 +394,23 @@ class CircuitLogicDataset(ProceduralDataset):
|
|||
"metadata": {
|
||||
"expression": expression_for_display,
|
||||
"assignments": assignments,
|
||||
"term_strings": term_strings,
|
||||
"final_gate": final_gate_name,
|
||||
"inputs": inputs_list,
|
||||
"legend": legend_str,
|
||||
},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
|
||||
if answer is None or len(answer) == 0:
|
||||
return 0.0
|
||||
|
||||
oracle_answer = entry["answer"]
|
||||
if oracle_answer == answer:
|
||||
return 1.0
|
||||
elif oracle_answer == answer.strip():
|
||||
return len(oracle_answer) / len(answer)
|
||||
|
||||
return 0.01
|
||||
|
||||
|
||||
register_dataset("circuit_logic", CircuitLogicDataset, CircuitLogicConfig)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue