diff --git a/pyproject.toml b/pyproject.toml index 36d44f5d..d19be6bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "tabulate==0.9.0", "pyyaml>=6.0.2", "arckit==0.1.0", + "zss>=1.2.0", ] classifiers = [ "Programming Language :: Python :: 3", diff --git a/reasoning_gym/code/codeio.py b/reasoning_gym/code/codeio.py index 10f5ee48..258d5e03 100644 --- a/reasoning_gym/code/codeio.py +++ b/reasoning_gym/code/codeio.py @@ -4,6 +4,8 @@ from dataclasses import dataclass from random import Random from typing import Any, Optional +import zss + from ..data import get_data_file_path from ..factory import ProceduralDataset, register_dataset @@ -118,8 +120,85 @@ class CodeIODataset(ProceduralDataset): "metadata": {"input_data": input_data, "output_data": output_data}, } + def _json_to_tree(self, data, label="root"): + """Recursively convert a JSON dictionary to a ZSS tree.""" + if isinstance(data, dict): + node = zss.Node(label) + for key, value in sorted(data.items()): + node.addkid(self._json_to_tree(value, key)) + return node + elif isinstance(data, list): + node = zss.Node(label) + for idx, item in enumerate(data): + node.addkid(self._json_to_tree(item, f"item_{idx}")) + return node + else: + return zss.Node(f"{label}:{data}") + + def _compute_json_similarity(self, json1, json2): + """Compute a similarity score in [0, 1] between two JSON dictionaries using tree edit distance.""" + tree1 = self._json_to_tree(json1) + tree2 = self._json_to_tree(json2) + + def _str_edit_distance(str1, str2): + """Compute Levenshtein edit distance between two strings.""" + m, n = len(str1), len(str2) + prev = list(range(n + 1)) + curr = [0] * (n + 1) + for i in range(1, m + 1): + curr[0] = i + for j in range(1, n + 1): + if str1[i - 1] == str2[j - 1]: + curr[j] = prev[j - 1] + else: + curr[j] = 1 + min(prev[j], curr[j - 1], prev[j - 1]) + prev, curr = curr, prev + return prev[n] + + def _tree_node_edit_distance(text1: str, text2: str): + """Compute edit distance between two tree nodes based on their types.""" + if ":" not in text1 or ":" not in text2: + return _str_edit_distance(text1, text2) + + key1, value1 = text1.split(":", 1) + key2, value2 = text2.split(":", 1) + + key_dist = _str_edit_distance(key1, key2) if key1 != key2 else 0 + value_dist = _str_edit_distance(value1, value2) if value1 != value2 else 0 + + if value1 != value2: + # Numeric, allowing decimals + if value1.replace(".", "").isnumeric() and value2.replace(".", "").isnumeric(): + try: + # TODO: Consider a more sophisticated distance metric for numeric values? + abs1, abs2 = abs(float(value1)), abs(float(value2)) + divisor = max(min(abs1, abs2), 10e-5) + value_dist += (abs1 - abs2) / divisor + except ValueError: + # Fall back on string edit distance + pass + elif value1.isnumeric() or value2.isnumeric(): + # Penalise severely if the answer is numeric when it shouldn't be, or vice versa + value_dist += max(len(text1), len(text2)) + + return key_dist + value_dist + + edit_distance = zss.simple_distance(tree1, tree2, label_dist=_tree_node_edit_distance) + max_size = max(len(json.dumps(json1)), len(json.dumps(json2))) + + similarity_score = 1 - (edit_distance / (0.2 * max_size)) + return max(0, similarity_score) + + def _score_answer_json(self, answer_json: dict, oracle_json: dict, max_score: float) -> float: + """If the answer is valid JSON, compute a similarity score between the answer and the oracle JSON.""" + if answer_json == oracle_json: + return max_score + else: + similarity = self._compute_json_similarity(answer_json, oracle_json) + # 0.01 minimum reward, since it produced a valid JSON output + return max(similarity * max_score, 0.01) + def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: - # TODO: this scoring could definitely be refined oracle_answer = entry["answer"].strip() reward = 0.0 if answer is not None and len(answer) > 0: @@ -132,27 +211,19 @@ class CodeIODataset(ProceduralDataset): ans_first_open, ans_last_close = answer.index("{"), answer.rindex("}") extra_chars = len(answer[:ans_first_open]) + len(answer[ans_last_close + 1 :]) + # 0.5 is arbitrary here, but the answers are very short so it seems harsh to penalize too much + # e.g. if oracle is {"steps": "3"} and answer is "The correct answer is: {"steps": "3"}" + max_score = max(len(oracle_answer) / (len(oracle_answer) + 0.5 * extra_chars), 0.2) + try: answer_dict = json.loads(answer[ans_first_open : ans_last_close + 1]) oracle_dict = json.loads(oracle_answer) - if answer_dict == oracle_dict: - # 0.5 is arbitrary here, but the answers are very short so it seems harsh to penalize too much - # e.g. if oracle is {"steps": "3"} and answer is "The correct answer is: {"steps": "3"}" - reward = max(len(oracle_answer) / (len(oracle_answer) + 0.5 * extra_chars), 0.2) - elif answer_dict.keys() == oracle_dict.keys(): - # Wrong answer, but at least the right format - reward = 0.1 - else: - # At least we got a JSON object, I guess? - reward = 0.01 + return self._score_answer_json(answer_dict, oracle_dict, max_score) except json.JSONDecodeError: if oracle_answer in answer: reward = len(oracle_answer) / len(answer) else: reward = 0.00 - elif oracle_answer in answer: - # max() to avoid penalising too heavily, since correct answers are short here - reward = max(len(oracle_answer) / len(answer), 0.2) else: reward = 0.00