import logging
import re
from typing import Any, List

try:
    from pkg_resources import resource_filename
except Exception:  # pragma: no cover - optional dependency
    resource_filename = None

from .registry import registry
from .reward_function import RewardFunction

logger = logging.getLogger(__name__)

# Ordered mapping from IAST characters/digraphs to SLP1.  Digraphs (aspirates,
# diphthongs) are listed before the single diacritic characters so the
# alternation regex below prefers the longer match (e.g. "kh" -> "K" must win
# over "k" passing through unchanged).
_IAST_TO_SLP1 = [
    ("kh", "K"),
    ("gh", "G"),
    ("ch", "C"),
    ("jh", "J"),
    ("ṭh", "W"),
    ("ḍh", "Q"),
    ("th", "T"),
    ("dh", "D"),
    ("ph", "P"),
    ("bh", "B"),
    ("ai", "E"),
    ("au", "O"),
    ("ā", "A"),
    ("ī", "I"),
    ("ū", "U"),
    ("ṛ", "f"),
    ("ṝ", "F"),
    ("ḷ", "x"),
    ("ḹ", "X"),
    ("ṅ", "N"),
    ("ñ", "Y"),
    ("ṭ", "w"),
    ("ḍ", "q"),
    ("ṇ", "R"),
    ("ś", "S"),
    ("ṣ", "z"),
    ("ṃ", "M"),
    ("ṁ", "M"),
    ("ḥ", "H"),
]

# Characters whose IAST and SLP1 spellings coincide; the map is an identity
# but is kept explicit so the supported alphabet is documented in one place.
_SINGLE_CHAR_MAP = {
    "a": "a",
    "i": "i",
    "u": "u",
    "e": "e",
    "o": "o",
    "k": "k",
    "g": "g",
    "c": "c",
    "j": "j",
    "t": "t",
    "d": "d",
    "n": "n",
    "p": "p",
    "b": "b",
    "m": "m",
    "y": "y",
    "r": "r",
    "l": "l",
    "v": "v",
    "s": "s",
    "h": "h",
}

# Dict view of the ordered table for O(1) lookup during substitution.
_IAST_LOOKUP = dict(_IAST_TO_SLP1)

_DIGRAPH_RE = re.compile("|".join(re.escape(d[0]) for d in _IAST_TO_SLP1), re.UNICODE)


def iast_to_slp1(text: str) -> str:
    """Convert a string from IAST transliteration to SLP1.

    Multi-character sequences and diacritics are replaced in a single regex
    pass; any character not covered by the tables is passed through unchanged
    (IAST and SLP1 agree on the basic Latin letters).
    """

    def _replace(match: re.Match) -> str:
        # Dict lookup instead of a linear scan over the mapping table.
        token = match.group(0)
        return _IAST_LOOKUP.get(token, token)

    text = _DIGRAPH_RE.sub(_replace, text)
    return "".join(_SINGLE_CHAR_MAP.get(ch, ch) for ch in text)


@registry.register
class ChandasMeterReward(RewardFunction):
    """Reward based on how closely a poem matches a target Sanskrit meter.

    Uses the optional ``chandas`` package to classify the meter of each
    completion.  If the classifier cannot be loaded, every reward is 0.0.
    """

    def __init__(self, meter: str = "tristubh", weight: float = 1.0, **kwargs):
        super().__init__(weight=weight, **kwargs)
        self.meter = meter
        try:
            from chandas import Classifier  # type: ignore

            if resource_filename is not None:
                # Load the bundled meter database shipped with the package.
                data_path = resource_filename("chandas", "data/data.json")
                self.classifier = Classifier.from_json_file(data_path)
            else:
                self.classifier = Classifier.from_default_location()
        except Exception as e:  # pragma: no cover - optional dependency
            logger.error("Failed to load chandas Classifier: %s", e)
            self.classifier = None

    def _score_text(self, text: str) -> float:
        """Return a score in [0.0, 1.0] for how well *text* fits ``self.meter``.

        The text is transliterated to SLP1 before classification.  When the
        classifier exposes a confidence ``score``, a correct prediction earns
        that confidence and an incorrect one earns its complement; when only a
        label is available, the score is a hard 1.0/0.0 match.
        """
        if not self.classifier:
            return 0.0
        try:
            slp_text = iast_to_slp1(text)
            result = self.classifier.classify(slp_text)
            if not result:
                return 0.0
            target = self.meter.lower()
            predicted = getattr(result, "name", str(result)).lower()
            confidence = getattr(result, "score", None)
            if confidence is None:
                # Label-only classifier: exact match or nothing.  (Previously
                # a wrong prediction defaulted to 0.0 and was then inverted to
                # a full reward of 1.0 - rewarding the wrong meter.)
                return 1.0 if predicted == target else 0.0
            raw_score = float(confidence)
            if predicted != target:
                # Wrong meter: credit the complement of the classifier's
                # confidence, so a shaky wrong prediction loses less reward.
                raw_score = 1.0 - raw_score if raw_score <= 1.0 else 0.0
            return max(0.0, min(1.0, raw_score))
        except Exception as e:  # pragma: no cover - runtime safeguard
            logger.error("Error scoring text with chandas: %s", e)
            return 0.0

    def compute(self, completions: List[Any], **kwargs) -> List[float]:
        """Score each completion with :meth:`_score_text`."""
        return [self._score_text(self.get_content(c)) for c in completions]
import logging
from typing import List, Optional, Tuple

from pydantic import Field

from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataGroup
from atroposlib.envs.reward_fns.registry import registry
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

logger = logging.getLogger(__name__)


class SanskritPoetryEnvConfig(BaseEnvConfig):
    """Configuration for :class:`SanskritPoetryEnv`."""

    meter: str = Field("tristubh", description="Desired Sanskrit meter")
    system_prompt: Optional[str] = Field(
        "You are a Sanskrit poet. Respond only with the poem in IAST.",
        description="System prompt for the model",
    )
    # Sampling parameters forwarded to the completion server.
    temperature: float = 0.7
    top_p: float = 0.9
    max_tokens: int = 256


class SanskritPoetryEnv(BaseEnv):
    """Environment that rewards Sanskrit poems for matching a target meter."""

    env_config_cls = SanskritPoetryEnvConfig

    def __init__(
        self,
        config: SanskritPoetryEnvConfig,
        server_configs: List[APIServerConfig],
        slurm: bool = True,
        testing: bool = False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        # Create reward function through the registry so the meter is
        # configurable without touching code.
        self.reward_fn = registry.create(
            {"type": "chandas_meter", "params": {"meter": config.meter}}
        )
        self.iter = 0

    @classmethod
    def config_init(cls) -> Tuple[SanskritPoetryEnvConfig, List[APIServerConfig]]:
        """Return the default env config and server configs for local runs."""
        env_config = SanskritPoetryEnvConfig(
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=32,
            steps_per_eval=50,
            max_token_length=512,
            wandb_name="sanskrit_poetry",
        )
        server_configs = [
            APIServerConfig(
                base_url="http://localhost:9001",
                api_key="x",
                num_requests_for_eval=64,
                model_name="Qwen/Qwen3-1.7B",
                server_type="trl",
            )
        ]
        return env_config, server_configs

    async def setup(self):
        """Reset the iteration counter before rollouts begin."""
        self.iter = 0

    async def get_next_item(self):
        """Build the next prompt item.

        The user message is encoded as a frozenset of dict items so the item
        tuple is hashable for the base environment's bookkeeping.
        """
        prompt = (
            f"Compose a four line Sanskrit poem in the {self.config.meter} meter. "
            "Use IAST transliteration only."
        )
        user_msg = {"role": "user", "content": prompt}
        return (tuple([frozenset(user_msg.items())]), None, None)

    def _build_messages(self, user_content: str) -> List[dict]:
        """Return the chat prefix: optional system prompt plus user message."""
        messages = []
        if self.config.system_prompt:
            messages.append({"role": "system", "content": self.config.system_prompt})
        messages.append({"role": "user", "content": user_content})
        return messages

    async def collect_trajectories(self, item):
        """Sample ``group_size`` completions for the prompt encoded in *item*.

        Returns the list of full message sequences (prefix + assistant reply)
        and an empty backlog list.
        """
        user_content = dict(item[0][0])["content"]
        prompt = self.tokenizer.apply_chat_template(
            self._build_messages(user_content), tokenize=False
        )
        completions = await self.server.completion(
            prompt=prompt,
            n=self.config.group_size,
            max_tokens=self.config.max_tokens,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
        )
        trajectories = []
        for choice in completions.choices:
            # Text-completion responses expose `.text`; chat responses nest
            # the content under `.message`.
            completion_text = (
                choice.text if hasattr(choice, "text") else choice.message.content
            )
            msg_seq = self._build_messages(user_content)
            msg_seq.append({"role": "assistant", "content": completion_text})
            trajectories.append(msg_seq)
        return trajectories, []

    async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
        """Tokenize each trajectory and attach its meter reward.

        Returns ``None`` when there is nothing to score, matching the
        ``Optional`` return contract.
        """
        if not rollout_group_data:
            return None
        scored = ScoredDataGroup()
        scored["tokens"] = []
        scored["masks"] = []
        scored["scores"] = []
        for traj in rollout_group_data:
            # Reward only the assistant's final reply (the poem itself).
            reward = self.reward_fn([traj[-1]["content"]])[0]
            out_dict = tokenize_for_trainer(self.tokenizer, traj)
            scored["tokens"].append(out_dict["tokens"])
            scored["masks"].append(out_dict["masks"])
            scored["scores"].append(reward)
        return scored