mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-05-02 17:45:58 +00:00
Tolerant scoring for CodeI/O based on edit distances (#277)
* add zss dep * codeio edit distance-based scoring * edit distance tweaks
This commit is contained in:
parent
dfc28c94d6
commit
35c32cd5e7
2 changed files with 86 additions and 14 deletions
|
|
@ -22,6 +22,7 @@ dependencies = [
|
||||||
"tabulate==0.9.0",
|
"tabulate==0.9.0",
|
||||||
"pyyaml>=6.0.2",
|
"pyyaml>=6.0.2",
|
||||||
"arckit==0.1.0",
|
"arckit==0.1.0",
|
||||||
|
"zss>=1.2.0",
|
||||||
]
|
]
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,8 @@ from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import zss
|
||||||
|
|
||||||
from ..data import get_data_file_path
|
from ..data import get_data_file_path
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
@ -118,8 +120,85 @@ class CodeIODataset(ProceduralDataset):
|
||||||
"metadata": {"input_data": input_data, "output_data": output_data},
|
"metadata": {"input_data": input_data, "output_data": output_data},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _json_to_tree(self, data, label="root"):
    """Recursively convert a JSON dictionary to a ZSS tree."""
    # Leaf values become a single node whose label encodes key and value.
    if not isinstance(data, (dict, list)):
        return zss.Node(f"{label}:{data}")

    # Containers become an inner node with recursively built children.
    # Dict entries are visited in sorted key order so the tree shape is
    # canonical; list items keep positional order via "item_<idx>" labels.
    if isinstance(data, dict):
        children = sorted(data.items())
    else:
        children = [(f"item_{idx}", item) for idx, item in enumerate(data)]

    root = zss.Node(label)
    for child_label, child_value in children:
        root.addkid(self._json_to_tree(child_value, child_label))
    return root
|
||||||
|
|
||||||
|
def _compute_json_similarity(self, json1, json2):
    """Compute a similarity score in [0, 1] between two JSON dictionaries using tree edit distance.

    Both inputs are converted to ZSS trees and compared with a tree edit
    distance whose node-relabel cost is string- and numeric-aware. The raw
    distance is normalised by the serialised size of the larger document.
    """
    tree1 = self._json_to_tree(json1)
    tree2 = self._json_to_tree(json2)

    def _str_edit_distance(str1, str2):
        """Compute Levenshtein edit distance between two strings (two-row DP)."""
        m, n = len(str1), len(str2)
        prev = list(range(n + 1))
        curr = [0] * (n + 1)
        for i in range(1, m + 1):
            curr[0] = i
            for j in range(1, n + 1):
                if str1[i - 1] == str2[j - 1]:
                    curr[j] = prev[j - 1]
                else:
                    curr[j] = 1 + min(prev[j], curr[j - 1], prev[j - 1])
            prev, curr = curr, prev
        return prev[n]

    def _tree_node_edit_distance(text1: str, text2: str):
        """Compute edit distance between two tree nodes based on their types.

        Leaf labels have the form "key:value"; inner-node labels contain no
        colon and fall back to plain string edit distance.
        """
        if ":" not in text1 or ":" not in text2:
            return _str_edit_distance(text1, text2)

        key1, value1 = text1.split(":", 1)
        key2, value2 = text2.split(":", 1)

        key_dist = _str_edit_distance(key1, key2) if key1 != key2 else 0
        value_dist = _str_edit_distance(value1, value2) if value1 != value2 else 0

        if value1 != value2:
            # Numeric, allowing decimals
            if value1.replace(".", "").isnumeric() and value2.replace(".", "").isnumeric():
                try:
                    # TODO: Consider a more sophisticated distance metric for numeric values?
                    abs1, abs2 = abs(float(value1)), abs(float(value2))
                    divisor = max(min(abs1, abs2), 10e-5)
                    # abs() keeps the relabel cost non-negative regardless of
                    # which value is larger; a negative cost would corrupt the
                    # tree edit distance and could push similarity above 1.
                    value_dist += abs(abs1 - abs2) / divisor
                except ValueError:
                    # isnumeric() admits strings float() rejects (e.g. "1.2.3",
                    # non-ASCII digits) — fall back on string edit distance
                    pass
            elif value1.isnumeric() or value2.isnumeric():
                # Penalise severely if the answer is numeric when it shouldn't be, or vice versa
                value_dist += max(len(text1), len(text2))

        return key_dist + value_dist

    edit_distance = zss.simple_distance(tree1, tree2, label_dist=_tree_node_edit_distance)
    max_size = max(len(json.dumps(json1)), len(json.dumps(json2)))

    # Normalise: distances beyond 20% of the serialised document size count
    # as zero similarity. Clamp to [0, 1] so the documented range always holds.
    similarity_score = 1 - (edit_distance / (0.2 * max_size))
    return max(0.0, min(1.0, similarity_score))
|
||||||
|
|
||||||
|
def _score_answer_json(self, answer_json: dict, oracle_json: dict, max_score: float) -> float:
    """If the answer is valid JSON, compute a similarity score between the answer and the oracle JSON."""
    # An exact structural match earns the full available score.
    if answer_json == oracle_json:
        return max_score

    # Otherwise scale the available score by tree-edit similarity, keeping a
    # 0.01 minimum reward, since it produced a valid JSON output.
    similarity = self._compute_json_similarity(answer_json, oracle_json)
    return max(similarity * max_score, 0.01)
|
||||||
|
|
||||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||||
# TODO: this scoring could definitely be refined
|
|
||||||
oracle_answer = entry["answer"].strip()
|
oracle_answer = entry["answer"].strip()
|
||||||
reward = 0.0
|
reward = 0.0
|
||||||
if answer is not None and len(answer) > 0:
|
if answer is not None and len(answer) > 0:
|
||||||
|
|
@ -132,27 +211,19 @@ class CodeIODataset(ProceduralDataset):
|
||||||
ans_first_open, ans_last_close = answer.index("{"), answer.rindex("}")
|
ans_first_open, ans_last_close = answer.index("{"), answer.rindex("}")
|
||||||
extra_chars = len(answer[:ans_first_open]) + len(answer[ans_last_close + 1 :])
|
extra_chars = len(answer[:ans_first_open]) + len(answer[ans_last_close + 1 :])
|
||||||
|
|
||||||
|
# 0.5 is arbitrary here, but the answers are very short so it seems harsh to penalize too much
|
||||||
|
# e.g. if oracle is {"steps": "3"} and answer is "The correct answer is: {"steps": "3"}"
|
||||||
|
max_score = max(len(oracle_answer) / (len(oracle_answer) + 0.5 * extra_chars), 0.2)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
answer_dict = json.loads(answer[ans_first_open : ans_last_close + 1])
|
answer_dict = json.loads(answer[ans_first_open : ans_last_close + 1])
|
||||||
oracle_dict = json.loads(oracle_answer)
|
oracle_dict = json.loads(oracle_answer)
|
||||||
if answer_dict == oracle_dict:
|
return self._score_answer_json(answer_dict, oracle_dict, max_score)
|
||||||
# 0.5 is arbitrary here, but the answers are very short so it seems harsh to penalize too much
|
|
||||||
# e.g. if oracle is {"steps": "3"} and answer is "The correct answer is: {"steps": "3"}"
|
|
||||||
reward = max(len(oracle_answer) / (len(oracle_answer) + 0.5 * extra_chars), 0.2)
|
|
||||||
elif answer_dict.keys() == oracle_dict.keys():
|
|
||||||
# Wrong answer, but at least the right format
|
|
||||||
reward = 0.1
|
|
||||||
else:
|
|
||||||
# At least we got a JSON object, I guess?
|
|
||||||
reward = 0.01
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
if oracle_answer in answer:
|
if oracle_answer in answer:
|
||||||
reward = len(oracle_answer) / len(answer)
|
reward = len(oracle_answer) / len(answer)
|
||||||
else:
|
else:
|
||||||
reward = 0.00
|
reward = 0.00
|
||||||
elif oracle_answer in answer:
|
|
||||||
# max() to avoid penalising too heavily, since correct answers are short here
|
|
||||||
reward = max(len(oracle_answer) / len(answer), 0.2)
|
|
||||||
else:
|
else:
|
||||||
reward = 0.00
|
reward = 0.00
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue