diff --git a/reasoning_gym/algorithmic/pool_matrix.py b/reasoning_gym/algorithmic/pool_matrix.py
index 41b9117d..002c0c0c 100644
--- a/reasoning_gym/algorithmic/pool_matrix.py
+++ b/reasoning_gym/algorithmic/pool_matrix.py
@@ -102,8 +102,8 @@ class PoolMatrixDataset(ProceduralDataset):
 
         reward = 0.0
         try:
-            oracle_answer = np.array(entry["answer"])
-            answer = np.array(answer)
+            oracle_answer = np.loadtxt(entry["answer"].splitlines(), dtype=np.float32)
+            answer = np.loadtxt(answer.splitlines(), dtype=np.float32)
             if oracle_answer.shape == answer.shape and np.allclose(oracle_answer, answer):
                 reward = 1.0
             elif oracle_answer.shape == answer.shape:
diff --git a/tests/test_pool_matrix.py b/tests/test_pool_matrix.py
index aa3fe6b6..c110967f 100644
--- a/tests/test_pool_matrix.py
+++ b/tests/test_pool_matrix.py
@@ -136,3 +136,30 @@ def test_pool_matrix_answer():
         ]
     )
     assert np.allclose(dataset._average_pool(matrix, 2), np.array([[3.5, 5.5], [11.5, 13.5]]))
+
+
+def test_pool_matrix_score_answer():
+    """Test score_answer on the oracle answer, a well-formed wrong matrix, and malformed input."""
+    config = PoolMatrixConfig(seed=42, size=100)
+    dataset = PoolMatrixDataset(config)
+    for entry in dataset:
+        assert dataset.score_answer(entry["answer"], entry=entry) == 1
+        assert 0.0 < dataset.score_answer("1 2.0\n3.0 4", entry=entry) <= 0.1
+        assert dataset.score_answer("one two three", entry=entry) == 0.0
+        assert dataset.score_answer("", entry=entry) == 0.0
+        assert dataset.score_answer(None, entry=entry) == 0.0
+
+
+def test_pool_matrix_int_answer():
+    """Test that whole-number answers written without a decimal point still score 1.0."""
+    config = PoolMatrixConfig(seed=42, size=10)
+    dataset = PoolMatrixDataset(config)
+    for entry in dataset:
+        # ndmin=2 keeps scalar, single-row and single-column answers two-dimensional,
+        # so the row-wise formatting below never iterates over a bare scalar.
+        matrix = np.loadtxt(entry["answer"].splitlines(), ndmin=2)
+        is_integer = np.equal(np.mod(matrix, 1), 0)
+        if is_integer.all():
+            matrix = matrix.astype(np.int32)
+        int_answer = "\n".join(" ".join(str(x) for x in row) for row in matrix)
+        assert dataset.score_answer(answer=int_answer, entry=entry) == 1.0