atropos/atroposlib/envs/reward_fns/chandas_meter_reward.py

131 lines
3.5 KiB
Python

import logging
import re
from typing import Any, List
try:
from pkg_resources import resource_filename
except Exception: # pragma: no cover - optional dependency
resource_filename = None
from .registry import registry
from .reward_function import RewardFunction
logger = logging.getLogger(__name__)
# Basic mapping from IAST characters/digraphs to SLP1
_IAST_TO_SLP1 = [
("kh", "K"),
("gh", "G"),
("ch", "C"),
("jh", "J"),
("ṭh", "W"),
("ḍh", "Q"),
("th", "T"),
("dh", "D"),
("ph", "P"),
("bh", "B"),
("ai", "E"),
("au", "O"),
("ā", "A"),
("ī", "I"),
("ū", "U"),
("", "f"),
("", "F"),
("", "x"),
("", "X"),
("", "N"),
("ñ", "Y"),
("", "w"),
("", "q"),
("", "R"),
("ś", "S"),
("", "z"),
("", "M"),
("", "M"),
("", "H"),
]
_SINGLE_CHAR_MAP = {
"a": "a",
"i": "i",
"u": "u",
"e": "e",
"o": "o",
"k": "k",
"g": "g",
"c": "c",
"j": "j",
"t": "t",
"d": "d",
"n": "n",
"p": "p",
"b": "b",
"m": "m",
"y": "y",
"r": "r",
"l": "l",
"v": "v",
"s": "s",
"h": "h",
}
_DIGRAPH_RE = re.compile("|".join(re.escape(d[0]) for d in _IAST_TO_SLP1), re.UNICODE)
def iast_to_slp1(text: str) -> str:
"""Convert a string from IAST to SLP1."""
def _replace(match: re.Match) -> str:
for iast, slp in _IAST_TO_SLP1:
if match.group(0) == iast:
return slp
return match.group(0)
text = _DIGRAPH_RE.sub(_replace, text)
return "".join(_SINGLE_CHAR_MAP.get(ch, ch) for ch in text)
@registry.register
class ChandasMeterReward(RewardFunction):
"""Reward based on how closely a poem matches a target Sanskrit meter."""
def __init__(self, meter: str = "tristubh", weight: float = 1.0, **kwargs):
super().__init__(weight=weight, **kwargs)
self.meter = meter
try:
from chandas import Classifier # type: ignore
if resource_filename is not None:
data_path = resource_filename("chandas", "data/data.json")
self.classifier = Classifier.from_json_file(data_path)
else:
self.classifier = Classifier.from_default_location()
except Exception as e: # pragma: no cover - optional dependency
logger.error("Failed to load chandas Classifier: %s", e)
self.classifier = None
def _score_text(self, text: str) -> float:
if not self.classifier:
return 0.0
try:
slp_text = iast_to_slp1(text)
result = self.classifier.classify(slp_text)
if not result:
return 0.0
predicted = getattr(result, "name", str(result)).lower()
raw_score = getattr(
result, "score", 1.0 if predicted == self.meter.lower() else 0.0
)
if predicted != self.meter.lower():
raw_score = 1.0 - raw_score if raw_score <= 1.0 else 0.0
return max(0.0, min(1.0, float(raw_score)))
except Exception as e: # pragma: no cover - runtime safeguard
logger.error("Error scoring text with chandas: %s", e)
return 0.0
def compute(self, completions: List[Any], **kwargs) -> List[float]:
rewards: List[float] = []
for completion in completions:
text = self.get_content(completion)
rewards.append(self._score_text(text))
return rewards