mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Tolerant scoring for CodeI/O based on edit distances (#277)
* add zss dep * codeio edit distance-based scoring * edit distance tweaks
This commit is contained in:
parent
a8e920b552
commit
f490b9f760
2 changed files with 86 additions and 14 deletions
|
|
@ -4,6 +4,8 @@ from dataclasses import dataclass
|
|||
from random import Random
|
||||
from typing import Any, Optional
|
||||
|
||||
import zss
|
||||
|
||||
from ..data import get_data_file_path
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -118,8 +120,85 @@ class CodeIODataset(ProceduralDataset):
|
|||
"metadata": {"input_data": input_data, "output_data": output_data},
|
||||
}
|
||||
|
||||
def _json_to_tree(self, data, label="root"):
    """Recursively convert a JSON dictionary to a ZSS tree.

    Dicts become internal nodes with one child per key (keys visited in
    sorted order so the tree shape is deterministic), lists become internal
    nodes with one child per element, and scalars become leaves whose label
    embeds both the key and the value as "<label>:<value>".
    """
    if isinstance(data, dict):
        parent = zss.Node(label)
        # Sorting keys keeps the tree deterministic regardless of dict order.
        for key in sorted(data):
            parent.addkid(self._json_to_tree(data[key], key))
        return parent
    if isinstance(data, list):
        parent = zss.Node(label)
        for idx, element in enumerate(data):
            parent.addkid(self._json_to_tree(element, f"item_{idx}"))
        return parent
    # Scalar leaf: fold the value into the label for label-based comparison.
    return zss.Node(f"{label}:{data}")
|
||||
|
||||
def _compute_json_similarity(self, json1, json2):
    """Compute a similarity score in [0, 1] between two JSON dictionaries using tree edit distance.

    Both documents are converted to ZSS trees and compared with a custom
    per-node label distance; the resulting edit distance is normalised by
    the serialized length of the larger document.
    """
    tree1 = self._json_to_tree(json1)
    tree2 = self._json_to_tree(json2)

    def _str_edit_distance(str1, str2):
        """Compute Levenshtein edit distance between two strings."""
        # Two-row dynamic-programming variant: O(min-memory) Wagner-Fischer.
        m, n = len(str1), len(str2)
        prev = list(range(n + 1))
        curr = [0] * (n + 1)
        for i in range(1, m + 1):
            curr[0] = i
            for j in range(1, n + 1):
                if str1[i - 1] == str2[j - 1]:
                    curr[j] = prev[j - 1]
                else:
                    curr[j] = 1 + min(prev[j], curr[j - 1], prev[j - 1])
            prev, curr = curr, prev
        return prev[n]

    def _tree_node_edit_distance(text1: str, text2: str):
        """Compute edit distance between two tree nodes based on their types."""
        # Internal nodes carry plain labels (no "key:value" form): compare as strings.
        if ":" not in text1 or ":" not in text2:
            return _str_edit_distance(text1, text2)

        key1, value1 = text1.split(":", 1)
        key2, value2 = text2.split(":", 1)

        key_dist = _str_edit_distance(key1, key2) if key1 != key2 else 0
        value_dist = _str_edit_distance(value1, value2) if value1 != value2 else 0

        if value1 != value2:
            # Numeric, allowing decimals
            if value1.replace(".", "").isnumeric() and value2.replace(".", "").isnumeric():
                try:
                    # TODO: Consider a more sophisticated distance metric for numeric values?
                    abs1, abs2 = abs(float(value1)), abs(float(value2))
                    divisor = max(min(abs1, abs2), 10e-5)
                    # Fix: use the absolute difference. The previous signed
                    # (abs1 - abs2) could be negative when abs1 < abs2,
                    # shrinking the total distance (even below zero) and
                    # breaking the symmetry a distance metric requires.
                    value_dist += abs(abs1 - abs2) / divisor
                except ValueError:
                    # Fall back on string edit distance
                    pass
            elif value1.isnumeric() or value2.isnumeric():
                # Penalise severely if the answer is numeric when it shouldn't be, or vice versa
                value_dist += max(len(text1), len(text2))

        return key_dist + value_dist

    edit_distance = zss.simple_distance(tree1, tree2, label_dist=_tree_node_edit_distance)
    max_size = max(len(json.dumps(json1)), len(json.dumps(json2)))

    # Normalise: an edit distance of 20% (or more) of the larger document's
    # serialized size maps to zero similarity; clamp to keep the score >= 0.
    similarity_score = 1 - (edit_distance / (0.2 * max_size))
    return max(0, similarity_score)
|
||||
|
||||
def _score_answer_json(self, answer_json: dict, oracle_json: dict, max_score: float) -> float:
    """If the answer is valid JSON, compute a similarity score between the answer and the oracle JSON."""
    # Exact structural match earns the full (formatting-adjusted) score.
    if answer_json == oracle_json:
        return max_score
    # Otherwise scale the score by tree-edit similarity, with a floor of
    # 0.01 minimum reward, since it produced a valid JSON output.
    scaled = self._compute_json_similarity(answer_json, oracle_json) * max_score
    return max(scaled, 0.01)
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
# TODO: this scoring could definitely be refined
|
||||
oracle_answer = entry["answer"].strip()
|
||||
reward = 0.0
|
||||
if answer is not None and len(answer) > 0:
|
||||
|
|
@ -132,27 +211,19 @@ class CodeIODataset(ProceduralDataset):
|
|||
ans_first_open, ans_last_close = answer.index("{"), answer.rindex("}")
|
||||
extra_chars = len(answer[:ans_first_open]) + len(answer[ans_last_close + 1 :])
|
||||
|
||||
# 0.5 is arbitrary here, but the answers are very short so it seems harsh to penalize too much
|
||||
# e.g. if oracle is {"steps": "3"} and answer is "The correct answer is: {"steps": "3"}"
|
||||
max_score = max(len(oracle_answer) / (len(oracle_answer) + 0.5 * extra_chars), 0.2)
|
||||
|
||||
try:
|
||||
answer_dict = json.loads(answer[ans_first_open : ans_last_close + 1])
|
||||
oracle_dict = json.loads(oracle_answer)
|
||||
if answer_dict == oracle_dict:
|
||||
# 0.5 is arbitrary here, but the answers are very short so it seems harsh to penalize too much
|
||||
# e.g. if oracle is {"steps": "3"} and answer is "The correct answer is: {"steps": "3"}"
|
||||
reward = max(len(oracle_answer) / (len(oracle_answer) + 0.5 * extra_chars), 0.2)
|
||||
elif answer_dict.keys() == oracle_dict.keys():
|
||||
# Wrong answer, but at least the right format
|
||||
reward = 0.1
|
||||
else:
|
||||
# At least we got a JSON object, I guess?
|
||||
reward = 0.01
|
||||
return self._score_answer_json(answer_dict, oracle_dict, max_score)
|
||||
except json.JSONDecodeError:
|
||||
if oracle_answer in answer:
|
||||
reward = len(oracle_answer) / len(answer)
|
||||
else:
|
||||
reward = 0.00
|
||||
elif oracle_answer in answer:
|
||||
# max() to avoid penalising too heavily, since correct answers are short here
|
||||
reward = max(len(oracle_answer) / len(answer), 0.2)
|
||||
else:
|
||||
reward = 0.00
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue