Mirror of https://github.com/open-thought/reasoning-gym.git (synced 2026-04-30 17:40:45 +00:00)

add ArcAgiDataset class, fix score_entry() metadata params

This commit is contained in: parent 2ad0965fdc, commit 4e49806d22
20 changed files with 194 additions and 93 deletions
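The core API change in this commit: score_answer() now receives the full dataset entry (the dict returned by __getitem__) instead of the bare metadata dict, and looks up entry["metadata"] itself. A minimal sketch of the new calling convention, shown here with the ArcAgiDataset added below; the seed value is illustrative.

    from reasoning_gym.arc.arc_agi import ArcAgiConfig, ArcAgiDataset

    # Build a dataset instance and pull one entry (question/answer/metadata dict).
    ds = ArcAgiDataset(ArcAgiConfig(seed=99))
    item = ds[0]

    # New convention: pass the whole entry; score_answer() reads entry["metadata"] internally.
    print(ds.score_answer(answer=item["answer"], entry=item))  # expected 1.0

    # Old convention (removed by this commit):
    # ds.score_answer(answer=item["answer"], metadata=item["metadata"])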
@@ -21,6 +21,7 @@ dependencies = [
     "pytz>=2024.1",
     "tabulate==0.9.0",
     "pyyaml>=6.0.2",
+    "arckit==0.1.0",
 ]
 classifiers = [
     "Programming Language :: Python :: 3",
@@ -127,11 +127,12 @@ class ComplexArithmeticDataset(ProceduralDataset):
         return student_result

-    def score_answer(self, answer: str, metadata: dict) -> float:
+    def score_answer(self, answer: Optional[str], entry: dict) -> float:
         """Score the answer using exponential distance-based scoring."""
         if answer is None:
             return 0.0

+        metadata = entry["metadata"]
         try:
             student_result = self.parse_string_to_complex(answer)
             expected_result = complex(*metadata["result"])
@@ -235,9 +235,10 @@ class IntermediateIntegrationDataset(ProceduralDataset):
             },
         }

-    def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
+    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
         """Determine if the solution provided solves the problem"""
         reward = 0.0
+        metadata = entry["metadata"]
         if answer is not None:
             try:
                 var = metadata["variable"]
@@ -138,8 +138,9 @@ class PolynomialMultiplicationDataset(ProceduralDataset):

         return polynomial_expr

-    def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
+    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
         reward = 0.0
+        metadata = entry["metadata"]
         if answer is not None:
             try:
                 predicted_poly = sp.parse_expr(answer)
@@ -80,9 +80,10 @@ class SimpleIntegrationDataset(ProceduralDataset):
             },
         }

-    def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
+    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
         """Determine if the solution provided solves the problem"""
         reward = 0.0
+        metadata = entry["metadata"]
         if answer is not None:
             try:
                 var = metadata["variable"]
@@ -81,7 +81,7 @@ class PalindromeDataset(ProceduralDataset):
         """Return the palindrome string from the letter set."""
         return "".join(letters)

-    def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
+    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
         """Determine if the solution provided is a valid palindrome.

         The answer is expected to be a single string
@@ -98,6 +98,7 @@ class PalindromeDataset(ProceduralDataset):
         if answer == "":
             return 0.01

+        metadata = entry["metadata"]
         answer = answer.strip().lower()
         expected_letters = metadata["letters"]
@@ -1,4 +1,5 @@
 from .arc_1d import Arc1DConfig, Arc1DDataset
+from .arc_agi import ArcAgiConfig, ArcAgiDataset
 from .rearc import ReArcConfig, ReArcDataset

-__all__ = ["Arc1DConfig", "Arc1DDataset", "ReArcDataset", "ReArcConfig"]
+__all__ = ["Arc1DConfig", "Arc1DDataset", "ArcAgiConfig", "ArcAgiDataset", "ReArcDataset", "ReArcConfig"]
reasoning_gym/arc/arc_agi.py (new file, 110 lines)
@@ -0,0 +1,110 @@
+from dataclasses import dataclass, field
+from random import Random
+from typing import Any, Optional
+
+import arckit
+
+from reasoning_gym.arc.board_format import (
+    ARC_PROMPT_TEMPLATE,
+    BoardFormattingOptions,
+    format_board,
+    format_board_pair,
+    parse_board,
+)
+from reasoning_gym.dataset import ProceduralDataset
+from reasoning_gym.factory import register_dataset
+
+
+@dataclass
+class ArcAgiConfig:
+    use_train: bool = True
+    use_eval: bool = True
+    board_format_opts: BoardFormattingOptions = field(default_factory=lambda: BoardFormattingOptions())
+    seed: Optional[int] = None
+    size: int = 500
+
+    def validate(self):
+        assert self.size > 0, "Size of dataset must be positive."
+
+
+class ArcAgiDataset(ProceduralDataset):
+    def __init__(self, config: ArcAgiConfig):
+        super().__init__(config=config, seed=config.seed, size=config.size)
+        self.board_format_opts = config.board_format_opts
+        self._prompt_templates = ARC_PROMPT_TEMPLATE
+
+        self._tasks = {}
+        train_set, eval_set = arckit.load_data()
+        if config.use_train:
+            for x in train_set:
+                self._tasks[x.id] = x.to_dict()
+        if config.use_eval:
+            for x in eval_set:
+                self._tasks[x.id] = x.to_dict()
+        self._task_ids = list(self._tasks.keys())
+
+    def __getitem__(self, idx: int) -> dict:
+        """
+        Generate a single ARC-AGI-1 task
+        """
+        rng = Random(self.seed + idx)
+
+        task_id = rng.choice(self._task_ids)
+        task = self._tasks[task_id]
+
+        train = task["train"]
+        test = task["test"][0]
+
+        examples = [
+            format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts) for i, p in enumerate(train)
+        ]
+        examples = "".join(examples)
+        test_input = format_board(test["input"], self.board_format_opts)
+        test_output = format_board(test["output"], self.board_format_opts)
+
+        input_prompt = self._prompt_templates.format(examples=examples, input_grid=test_input)
+
+        def totuple(board: list[list[int]]) -> tuple[tuple[int, ...], ...]:
+            return tuple(tuple(r) for r in board)
+
+        return {
+            "question": input_prompt,
+            "answer": test_output,
+            "metadata": {
+                "input": totuple(test["input"]),
+                "output": totuple(test["output"]),
+                "task_id": task_id,
+            },
+        }
+
+    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
+        reward = 0.0
+        metadata = entry["metadata"]
+        if answer is not None:
+            try:
+                answer_board = parse_board(answer, self.board_format_opts)
+                if answer_board == metadata["output"]:
+                    reward = 1.0
+                else:
+                    reward = 0.05
+            except:
+                reward = 0.01
+        return reward
+
+
+register_dataset("arc_agi", ArcAgiDataset, ArcAgiConfig)
+
+
+if __name__ == "__main__":
+    cfg = ArcAgiConfig(seed=99)
+    test = ArcAgiDataset(cfg)
+
+    x = test[1]
+
+    a = """1 6 7
+6 7 6
+2 2 6"""
+
+    print("q:", x["question"])
+    print("a:", x["answer"])
+    print("score:", test.score_answer(answer=a, entry=x))
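Because the module registers itself via register_dataset("arc_agi", ...), the new dataset can also be built through the factory. A hedged sketch, assuming reasoning_gym exposes create_dataset() as the factory counterpart of register_dataset() and forwards keyword arguments to the config (neither is shown in this diff):

    import reasoning_gym

    # Assumption: create_dataset("arc_agi", ...) constructs ArcAgiDataset via the registry.
    ds = reasoning_gym.create_dataset("arc_agi", seed=99, size=10)
    item = ds[0]
    print(item["question"])
    print(ds.score_answer(answer=item["answer"], entry=item))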
@@ -1,6 +1,16 @@
 from dataclasses import dataclass, field
 from typing import List, Tuple

+ARC_PROMPT_TEMPLATE = """Find the common rule that maps an input grid to an output grid, given the examples below.
+
+{examples}
+Below is a test input grid. Predict the corresponding output grid by applying the rule you found.
+Your final answer should just be the text output grid itself.
+
+Input:
+{input_grid}
+"""
+

 @dataclass
 class BoardFormattingOptions:
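The prompt template previously duplicated as _REARC_PROMPT_TEMPLATES in rearc.py now lives in board_format.py and is shared by both ARC datasets. A small sketch of how the datasets fill it; the example string and the 2x2 grid are illustrative stand-ins for what format_board_pair() and the task data produce.

    from reasoning_gym.arc.board_format import ARC_PROMPT_TEMPLATE, BoardFormattingOptions, format_board

    opts = BoardFormattingOptions()
    # In the datasets, `examples` is built with format_board_pair(); a literal stands in here.
    examples = "Example 1:\n\nInput:\n0 1\n1 0\nOutput:\n1 0\n0 1\n\n"
    prompt = ARC_PROMPT_TEMPLATE.format(
        examples=examples,
        input_grid=format_board([[1, 2], [3, 4]], opts),
    )
    print(prompt)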
@@ -10,26 +20,6 @@ class BoardFormattingOptions:
     array_brackets: bool = False


-def format_arc_task(
-    input_grid: Tuple[Tuple[int, ...], ...], output_grid: Tuple[Tuple[int, ...], ...], options: BoardFormattingOptions
-) -> str:
-    """
-    Format an ARC task as a string
-    """
-
-    buffer = []
-    if options.task_identifier:
-        buffer.append(f"ARC Task: {options.task_identifier}")
-
-    buffer.append("\nInput Grid:")
-    buffer.append(format_board(input_grid, options))
-
-    buffer.append("\n\nOutput Grid:")
-    buffer.append(format_board(output_grid, options))
-
-    return "\n".join(buffer)
-
-
 def format_board(
     board: List[List[int]], formatting_options: BoardFormattingOptions, with_board_shape: bool = False
 ) -> str:
@@ -3,17 +3,7 @@ from random import Random
 from typing import Any, Callable, Dict, Optional

 from ..factory import ProceduralDataset, register_dataset
-from .board_format import BoardFormattingOptions, format_board, format_board_pair, parse_board
+from .board_format import ARC_PROMPT_TEMPLATE, BoardFormattingOptions, format_board, format_board_pair, parse_board

-_REARC_PROMPT_TEMPLATES = """Find the common rule that maps an input grid to an output grid, given the examples below.
-
-{examples}
-Below is a test input grid. Predict the corresponding output grid by applying the rule you found.
-Your final answer should just be the text output grid itself.
-
-Input:
-{input_grid}
-"""
-

 @dataclass
@@ -37,7 +27,7 @@ class ReArcDataset(ProceduralDataset):
     def __init__(self, config: ReArcConfig):
         super().__init__(config=config, seed=config.seed, size=config.size)
         self.board_format_opts = config.board_format_opts
-        self._prompt_templates = _REARC_PROMPT_TEMPLATES
+        self._prompt_templates = ARC_PROMPT_TEMPLATE
         self.diff_lb = config.diff_lb
         self.diff_ub = config.diff_ub
@@ -89,10 +79,11 @@ class ReArcDataset(ProceduralDataset):
         rng_difficulty = self.get_rng_difficulty(rng)
         pso_difficulty = self.get_pso_difficulty(task)
         input_prompt = self.format_rearc_input(rng, task, generator)
+        answer = format_board(task["output"], self.board_format_opts)

         return {
             "question": input_prompt,
-            "answer": task["output"],
+            "answer": answer,
             "metadata": {
                 "input": task["input"],
                 "output": task["output"],
@@ -104,12 +95,13 @@ class ReArcDataset(ProceduralDataset):
             },
         }

-    def score_answer(self, answer: str, metadata: Dict[str, Any]) -> float:
+    def score_answer(self, answer: str, entry: Dict[str, Any]) -> float:
         reward = 0.0
+        metadata = entry["metadata"]
         if answer is not None:
             try:
-                formatted_answer = parse_board(answer, self.board_format_opts)
-                if formatted_answer == metadata["output"]:
+                answer_board = parse_board(answer, self.board_format_opts)
+                if answer_board == metadata["output"]:
                     reward = 1.0
                 else:
                     reward = 0.05
@@ -159,9 +159,10 @@ class CountdownDataset(ProceduralDataset):

         raise ValueError(f"Failed to generate valid expression after {max_attempts} attempts")

-    def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
+    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
         """Determine if the solution provided solves the problem"""
         reward = 0.0
+        metadata = entry["metadata"]
         if answer is not None:
             try:
                 user_answer = int(parse_expr(answer))
@@ -368,7 +368,7 @@ class HanoiDataset(ProceduralDataset):
         to_peg = int(match.group(3))
         return disk, from_peg, to_peg

-    def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
+    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
         """
         Score the user's solution for the Tower of Hanoi puzzle.
@@ -398,6 +398,7 @@ class HanoiDataset(ProceduralDataset):
             return 0.0

         # Build the initial peg state from metadata.
+        metadata = entry["metadata"]
         num_disks = metadata["num_disks"]
         num_pegs = metadata["num_pegs"]
         start_peg = metadata["start_peg"]
@@ -52,30 +52,30 @@ def test_complex_arithmetic_scoring():
     dataset = ComplexArithmeticDataset(config)

     # Test case with answer 3 + 2i
-    metadata = {"result": (3.0, 2.0)}
+    entry = {"metadata": {"result": (3.0, 2.0)}}

     # Test exact matches (should get score of 1.0)
-    assert dataset.score_answer("3 + 2i", metadata) == 1.0
-    assert dataset.score_answer("3+2i", metadata) == 1.0
-    assert dataset.score_answer("3.0 + 2.0i", metadata) == 1.0
+    assert dataset.score_answer("3 + 2i", entry) == 1.0
+    assert dataset.score_answer("3+2i", entry) == 1.0
+    assert dataset.score_answer("3.0 + 2.0i", entry) == 1.0

     # Test answers with small errors (should get high but < 1.0 scores)
-    print(dataset.score_answer("3.1 + 2i", metadata))
-    assert 0.9 < dataset.score_answer("3.1 + 2i", metadata) < 1.0
-    assert 0.9 < dataset.score_answer("3 + 2.1i", metadata) < 1.0
-    assert 0.7 < dataset.score_answer("3.1 + 2.1i", metadata) < 0.95
+    print(dataset.score_answer("3.1 + 2i", entry))
+    assert 0.9 < dataset.score_answer("3.1 + 2i", entry) < 1.0
+    assert 0.9 < dataset.score_answer("3 + 2.1i", entry) < 1.0
+    assert 0.7 < dataset.score_answer("3.1 + 2.1i", entry) < 0.95

     # Test answers with moderate errors (should get medium scores)
-    assert 0.3 < dataset.score_answer("4 + 2i", metadata) < 0.4
-    assert 0.3 < dataset.score_answer("3 + 3i", metadata) < 0.4
+    assert 0.3 < dataset.score_answer("4 + 2i", entry) < 0.4
+    assert 0.3 < dataset.score_answer("3 + 3i", entry) < 0.4

     # Test answers with large errors (should get very low scores)
-    assert dataset.score_answer("10 + 10i", metadata) < 0.01
+    assert dataset.score_answer("10 + 10i", entry) < 0.01

     # Test invalid answers (should get 0.0)
-    assert dataset.score_answer("invalid", metadata) == 0.0
-    assert dataset.score_answer(None, metadata) == 0.0
-    assert dataset.score_answer("inf + 2i", metadata) == 0.0
+    assert dataset.score_answer("invalid", entry) == 0.0
+    assert dataset.score_answer(None, entry) == 0.0
+    assert dataset.score_answer("inf + 2i", entry) == 0.0


 def test_complex_arithmetic_division_by_zero():
@@ -66,13 +66,13 @@ def test_countdown_game_items():
         expr = item["metadata"]["expression"]

         # check score
-        assert dataset.score_answer(answer=expr, metadata=item["metadata"]) == 1.0  # correct answer
-        assert dataset.score_answer(answer="45+2", metadata=item["metadata"]) == 0.05  # wrong answer but an attempt
+        assert dataset.score_answer(answer=expr, entry=item) == 1.0  # correct answer
+        assert dataset.score_answer(answer="45+2", entry=item) == 0.05  # wrong answer but an attempt
         assert (
-            dataset.score_answer(answer="a wrong solution", metadata=item["metadata"]) == 0.01
+            dataset.score_answer(answer="a wrong solution", entry=item) == 0.01
         )  # wrong answer but incorrectly formatted
-        assert dataset.score_answer(answer="", metadata=item["metadata"]) == 0.01  # wrong answer but empty string
-        assert dataset.score_answer(answer=None, metadata=item["metadata"]) == 0.0  # no answer
+        assert dataset.score_answer(answer="", entry=item) == 0.01  # wrong answer but empty string
+        assert dataset.score_answer(answer=None, entry=item) == 0.0  # no answer

         try:
             result = eval(expr)  # Safe here since we control expression generation
@@ -100,7 +100,7 @@ def test_verify_answer():
     dataset = IntermediateIntegrationDataset(config)
     for i in range(len(dataset)):
         item = dataset[i]
-        score = dataset.score_answer(item["answer"], item["metadata"])
+        score = dataset.score_answer(answer=item["answer"], entry=item)
         assert score == 1.0
@@ -140,5 +140,6 @@ def test_score_answer_cases():
     ]

     for answer, metadata, expected in test_cases:
-        score = dataset.score_answer(answer, metadata)
+        dummy_entry = {"metadata": metadata}
+        score = dataset.score_answer(answer, entry=dummy_entry)
         assert score == expected, f"Failed case: {answer} | Expected {expected}, got {score}"
@@ -72,21 +72,20 @@ def test_score_answer():

     for item in dataset:
         correct_answer = item["answer"]
-        metadata = item["metadata"]

         # Correct answer should score 1.0
-        assert dataset.score_answer(correct_answer, metadata) == 1.0
+        assert dataset.score_answer(correct_answer, entry=item) == 1.0

         # Incorrect answer (palindrome, but not correct one) should score 0.05
         pal_letters = "racecar" if "racecar" != correct_answer else "aba"
-        assert dataset.score_answer(pal_letters, metadata) == 0.05
+        assert dataset.score_answer(pal_letters, entry=item) == 0.05

         # Incorrect answer (not palindrome) should score 0.02
         wrong_letters = "abcd" if "abcd" != correct_answer else "efgh"
-        assert dataset.score_answer(wrong_letters, metadata) == 0.02
+        assert dataset.score_answer(wrong_letters, entry=item) == 0.02

         # Empty String input should score 0.01
-        assert dataset.score_answer("", metadata) == 0.01
+        assert dataset.score_answer("", entry=item) == 0.01

         # Empty input should score 0.0
-        assert dataset.score_answer(None, metadata) == 0.0
+        assert dataset.score_answer(None, entry=item) == 0.0
@@ -137,10 +137,10 @@ def test_score_function():
         seed=42,
     )

-    assert ds.score_answer(None, ds[0]["metadata"]) == 0.00
-    assert ds.score_answer("6*x**4 + 9*x**3 - 6*x**2 - 39*x - 45", ds[0]["metadata"]) == 1
-    assert ds.score_answer("Not a polynomial", ds[0]["metadata"]) == 0.01
-    assert ds.score_answer("x**4", ds[0]["metadata"]) == 0.05
+    assert ds.score_answer(None, ds[0]) == 0.00
+    assert ds.score_answer("6*x**4 + 9*x**3 - 6*x**2 - 39*x - 45", ds[0]) == 1
+    assert ds.score_answer("Not a polynomial", ds[0]) == 0.01
+    assert ds.score_answer("x**4", ds[0]) == 0.05


 def test_multivariate_score_function():
@@ -160,7 +160,7 @@ def test_multivariate_score_function():
         seed=42,
     )

-    assert ds.score_answer(None, ds[0]["metadata"]) == 0.00
-    assert ds.score_answer("-27*a**3*c - 27*a**3 + 144*a*c + 144*a", ds[0]["metadata"]) == 1
-    assert ds.score_answer("Not a polynomial", ds[0]["metadata"]) == 0.01
-    assert ds.score_answer("x**4", ds[0]["metadata"]) == 0.05
+    assert ds.score_answer(None, ds[0]) == 0.00
+    assert ds.score_answer("-27*a**3*c - 27*a**3 + 144*a*c + 144*a", ds[0]) == 1
+    assert ds.score_answer("Not a polynomial", ds[0]) == 0.01
+    assert ds.score_answer("x**4", ds[0]) == 0.05
@@ -54,7 +54,7 @@ def test_rearc_solution_validation():
     for item in dataset:
         # Test correct solution
         correct = format_board(item["metadata"]["output"], dataset.board_format_opts)
-        assert dataset.score_answer(correct, item["metadata"]) == 1.0
+        assert dataset.score_answer(correct, entry=item) == 1.0

         # Test invalid format
         invalid_grid = """
@@ -63,10 +63,10 @@ def test_rearc_solution_validation():
 7 8 7
 0 0 0
 """
-        assert dataset.score_answer(invalid_grid, item["metadata"]) == 0.05
+        assert dataset.score_answer(invalid_grid, entry=item) == 0.05

         # Test empty answer
-        assert dataset.score_answer(None, item["metadata"]) == 0.0
+        assert dataset.score_answer(None, entry=item) == 0.0


 def test_rearc_scoring_edge_cases():
@@ -77,11 +77,11 @@ def test_rearc_scoring_edge_cases():
     for item in dataset:
         # Partial match
         partial = format_board([[0, 0], [0, 0]], dataset.board_format_opts)
-        assert 0.0 < dataset.score_answer(partial, item["metadata"]) < 1.0
+        assert 0.0 < dataset.score_answer(partial, entry=item) < 1.0

         # Malformed answer
-        assert dataset.score_answer("[[invalid", item["metadata"]) == 0.01
+        assert dataset.score_answer("[[invalid", entry=item) == 0.01

         # Case sensitivity
         answer = format_board(item["metadata"]["output"], dataset.board_format_opts).lower()
-        assert dataset.score_answer(answer, item["metadata"]) == 1.0
+        assert dataset.score_answer(answer, entry=item) == 1.0
@@ -73,7 +73,7 @@ def test_verify_answer():
     dataset = SimpleIntegrationDataset(config)
     for i in range(len(dataset)):
         item = dataset[i]
-        score = dataset.score_answer(item["answer"], item["metadata"])
+        score = dataset.score_answer(item["answer"], item)
         assert score == 1.0
@@ -113,5 +113,6 @@ def test_score_answer_cases():
     ]

     for answer, metadata, expected in test_cases:
-        score = dataset.score_answer(answer, metadata)
+        dummy_entry = {"metadata": metadata}
+        score = dataset.score_answer(answer=answer, entry=dummy_entry)
         assert score == expected, f"Failed case: {answer} | Expected {expected}, got {score}"
@@ -245,27 +245,26 @@ def test_score_answer():
     dataset = HanoiDataset(config)
     # Pick one instance from the dataset for testing.
     item = dataset[0]
-    metadata = item["metadata"]
     correct_answer = item["answer"]

     # 1. Correct answer should yield full reward.
-    score_correct = dataset.score_answer(answer=correct_answer, metadata=metadata)
+    score_correct = dataset.score_answer(answer=correct_answer, entry=item)
     assert score_correct == 1.0, f"Correct answer score {score_correct} is not 1.0."

     # 2. A badly formatted answer should yield minimal reward (0.01).
-    score_bad_format = dataset.score_answer(answer="a wrong solution", metadata=metadata)
+    score_bad_format = dataset.score_answer(answer="a wrong solution", entry=item)
     assert score_bad_format == 0.01, f"Badly formatted answer score {score_bad_format} is not 0.01."

     # 3. An answer that is validly formatted but unsolved.
     # For example, remove the last move from the correct answer.
     unfinished_answer = correct_answer[:-1]
-    score_unsolved = dataset.score_answer(answer=unfinished_answer, metadata=metadata)
+    score_unsolved = dataset.score_answer(answer=unfinished_answer, entry=item)
     assert score_unsolved == 0.05, f"Unsolved answer score {score_unsolved} is not 0.05."

     # 4. An empty answer should yield 0.01.
-    score_empty = dataset.score_answer(answer="", metadata=metadata)
+    score_empty = dataset.score_answer(answer="", entry=item)
     assert score_empty == 0.01, f"Empty answer score {score_empty} is not 0.01."

     # 5. A None answer should yield 0.0.
-    score_none = dataset.score_answer(answer=None, metadata=metadata)
+    score_none = dataset.score_answer(answer=None, entry=item)
     assert score_none == 0.0, f"None answer score {score_none} is not 0.0."