diff --git a/atroposlib/envs/reward_fns/chandas_meter_reward.py b/atroposlib/envs/reward_fns/chandas_meter_reward.py
new file mode 100644
index 00000000..bb73100c
--- /dev/null
+++ b/atroposlib/envs/reward_fns/chandas_meter_reward.py
@@ -0,0 +1,131 @@
+import logging
+import re
+from typing import Any, List
+
+try:
+    from pkg_resources import resource_filename
+except Exception:  # pragma: no cover - optional dependency
+    resource_filename = None
+
+from .registry import registry
+from .reward_function import RewardFunction
+
+logger = logging.getLogger(__name__)
+
+# IAST -> SLP1 substitutions. Multi-character sequences (aspirates, diphthongs)
+# and diacritic characters are replaced first via a single regex alternation;
+# list order matters because the alternation tries entries left to right, so
+# e.g. "ṭh" must precede "ṭ".
+_IAST_TO_SLP1 = [
+    ("kh", "K"),
+    ("gh", "G"),
+    ("ch", "C"),
+    ("jh", "J"),
+    ("ṭh", "W"),
+    ("ḍh", "Q"),
+    ("th", "T"),
+    ("dh", "D"),
+    ("ph", "P"),
+    ("bh", "B"),
+    ("ai", "E"),
+    ("au", "O"),
+    ("ā", "A"),
+    ("ī", "I"),
+    ("ū", "U"),
+    ("ṛ", "f"),
+    ("ṝ", "F"),
+    ("ḷ", "x"),
+    ("ḹ", "X"),
+    ("ṅ", "N"),
+    ("ñ", "Y"),
+    ("ṭ", "w"),
+    ("ḍ", "q"),
+    ("ṇ", "R"),
+    ("ś", "S"),
+    ("ṣ", "z"),
+    ("ṃ", "M"),
+    ("ṁ", "M"),
+    ("ḥ", "H"),
+]
+
+# Plain ASCII characters whose SLP1 value equals their IAST value. Anything not
+# covered here or by _IAST_TO_SLP1 passes through unchanged.
+_SINGLE_CHAR_MAP = {
+    "a": "a",
+    "i": "i",
+    "u": "u",
+    "e": "e",
+    "o": "o",
+    "k": "k",
+    "g": "g",
+    "c": "c",
+    "j": "j",
+    "t": "t",
+    "d": "d",
+    "n": "n",
+    "p": "p",
+    "b": "b",
+    "m": "m",
+    "y": "y",
+    "r": "r",
+    "l": "l",
+    "v": "v",
+    "s": "s",
+    "h": "h",
+}
+
+# O(1) lookup table for the regex callback; the ordered list above is kept to
+# build the alternation with the correct precedence.
+_IAST_DIGRAPH_MAP = dict(_IAST_TO_SLP1)
+
+_DIGRAPH_RE = re.compile("|".join(re.escape(d[0]) for d in _IAST_TO_SLP1), re.UNICODE)
+
+
+def iast_to_slp1(text: str) -> str:
+    """Convert *text* from IAST transliteration to SLP1 encoding."""
+    text = _DIGRAPH_RE.sub(lambda m: _IAST_DIGRAPH_MAP[m.group(0)], text)
+    return "".join(_SINGLE_CHAR_MAP.get(ch, ch) for ch in text)
+
+
+@registry.register
+class ChandasMeterReward(RewardFunction):
+    """Reward based on how closely a poem matches a target Sanskrit meter."""
+
+    def __init__(self, meter: str = "tristubh", weight: float = 1.0, **kwargs):
+        """Load the chandas classifier for *meter*.
+
+        If the optional ``chandas`` dependency is missing or fails to load,
+        the reward degrades gracefully to always scoring 0.0.
+        """
+        super().__init__(weight=weight, **kwargs)
+        self.meter = meter
+        try:
+            from chandas import Classifier  # type: ignore
+
+            if resource_filename is not None:
+                data_path = resource_filename("chandas", "data/data.json")
+                self.classifier = Classifier.from_json_file(data_path)
+            else:
+                self.classifier = Classifier.from_default_location()
+        except Exception as e:  # pragma: no cover - optional dependency
+            logger.error("Failed to load chandas Classifier: %s", e)
+            self.classifier = None
+
+    def _score_text(self, text: str) -> float:
+        """Score one poem in [0, 1] against the configured target meter."""
+        if not self.classifier:
+            return 0.0
+        try:
+            result = self.classifier.classify(iast_to_slp1(text))
+            if not result:
+                return 0.0
+            target = self.meter.lower()
+            predicted = getattr(result, "name", str(result)).lower()
+            confidence = getattr(result, "score", None)
+            if confidence is None:
+                # No confidence available: fall back to a binary match signal.
+                # (Previously a missing score defaulted to 0.0 and the
+                # mismatch branch flipped it to 1.0, rewarding wrong meters.)
+                return 1.0 if predicted == target else 0.0
+            confidence = float(confidence)
+            if predicted == target:
+                raw_score = confidence
+            else:
+                # Wrong meter: reward the classifier's uncertainty instead.
+                raw_score = 1.0 - confidence if confidence <= 1.0 else 0.0
+            return max(0.0, min(1.0, raw_score))
+        except Exception as e:  # pragma: no cover - runtime safeguard
+            logger.error("Error scoring text with chandas: %s", e)
+            return 0.0
+
+    def compute(self, completions: List[Any], **kwargs) -> List[float]:
+        """Return one meter-adherence score per completion."""
+        rewards: List[float] = []
+        for completion in completions:
+            text = self.get_content(completion)
+            rewards.append(self._score_text(text))
+        return rewards
diff --git a/environments/community/README.md b/environments/community/README.md
index 9f1defce..7ef1d0ee 100644
--- a/environments/community/README.md
+++ b/environments/community/README.md
@@ -2179,6 +2179,98 @@ python test_stl_env.py

 ---

+### 25. Sanskrit Poetry Environment (`sanskrit_poetry/`)
+
+**Contributors**: KhoomeiK
+**PR**: [#71](https://github.com/NousResearch/atropos/pull/71)
+**Integration Status**: ✅ Integrated
+
+**Description**: A specialized reinforcement learning environment for generating Sanskrit poetry that adheres to traditional metrical patterns.
This environment trains language models to compose authentic Sanskrit verse using the chandas (meter) classification system, combining linguistic knowledge with poetic creativity. + +**Core Features**: + +**Metrical Poetry Generation**: +- **Chandas Meter Validation**: Uses the `chandas` classifier to verify adherence to traditional Sanskrit meters +- **IAST Transliteration**: Supports International Alphabet of Sanskrit Transliteration for accurate representation +- **SLP1 Conversion**: Automatic conversion from IAST to SLP1 encoding for meter analysis +- **Multiple Meter Support**: Configurable target meters including tristubh, anushtubh, and others + +**Reward System Integration**: +- **Registry-Based Rewards**: Leverages Atropos reward function registry for modular scoring +- **ChandasMeterReward**: Custom reward function that scores poetry based on metrical accuracy +- **Weighted Scoring**: Configurable reward weights for different aspects of poetic quality +- **Real-Time Feedback**: Immediate scoring during training for rapid learning + +**Environment Configuration**: +- **Flexible Meter Selection**: Easy configuration of target Sanskrit meters +- **Temperature Control**: Adjustable creativity vs accuracy balance (default 0.7) +- **Token Limits**: Configurable maximum poem length (default 256 tokens) +- **System Prompts**: Customizable instructions for different poetic styles + +**Technical Implementation**: +- **Pydantic Configuration**: Type-safe environment configuration with validation +- **Async Processing**: Non-blocking completion generation for efficient training +- **Trajectory Collection**: Comprehensive conversation tracking for RL training +- **Tokenization Support**: Integration with Atropos tokenization utilities + +**Sanskrit Linguistic Features**: +- **Character Mapping**: Comprehensive IAST to SLP1 character conversion +- **Digraph Handling**: Proper processing of Sanskrit consonant clusters +- **Unicode Support**: Full support for Sanskrit 
diacritical marks +- **Meter Classification**: Integration with scholarly meter analysis tools + +**Training Workflow**: +- **Prompt Generation**: Automatic creation of meter-specific composition prompts +- **Multi-Sample Generation**: Parallel generation of diverse poetic attempts +- **Metrical Scoring**: Real-time evaluation of generated verses against target meters +- **Iterative Improvement**: RL-based refinement of poetic capabilities + +**Research Applications**: +- **Computational Linguistics**: Study of AI understanding of prosodic patterns +- **Cultural Preservation**: Digital preservation and generation of traditional verse forms +- **Cross-Lingual Poetry**: Exploration of metrical patterns across languages +- **Educational Tools**: Interactive learning systems for Sanskrit prosody + +**External Dependencies**: +- **Chandas Package**: Must be built from [source](https://github.com/sanskrit/chandas) for meter classification +- **Sanskrit Corpus**: Access to traditional texts for training data (optional) +- **Unicode Libraries**: Proper handling of Sanskrit character encoding + +**Configuration Examples**: +```python +# Tristubh meter (11 syllables per quarter) +config = SanskritPoetryEnvConfig( + meter="tristubh", + temperature=0.7, + max_tokens=256 +) + +# Anushtubh meter (8 syllables per quarter) +config = SanskritPoetryEnvConfig( + meter="anushtubh", + temperature=0.8, + max_tokens=128 +) +``` + +**Evaluation Metrics**: +- **Metrical Accuracy**: Percentage of verses matching target meter +- **Linguistic Quality**: Grammatical correctness and vocabulary usage +- **Poetic Coherence**: Thematic consistency and aesthetic appeal +- **Training Efficiency**: Convergence speed and sample efficiency + +**Future Enhancements**: +- **Multi-Meter Compositions**: Training on mixed metrical patterns +- **Semantic Constraints**: Content-aware poetry generation with thematic guidance +- **Historical Styles**: Emulation of specific periods or authors +- **Interactive 
Composition**: Real-time collaborative poetry creation
+
+**Educational Value**: This environment serves as an excellent introduction to computational prosody, Sanskrit linguistics, and the intersection of AI with traditional literary forms. It demonstrates how modern ML techniques can be applied to preserve and extend classical cultural knowledge.
+
+**Requirements**: pydantic, chandas (from source), atroposlib
+
+---
+
 ## Support

 For questions or issues with community environments:
diff --git a/environments/community/sanskrit_poetry/sanskrit_poetry_env.py b/environments/community/sanskrit_poetry/sanskrit_poetry_env.py
new file mode 100644
index 00000000..64c8b4fc
--- /dev/null
+++ b/environments/community/sanskrit_poetry/sanskrit_poetry_env.py
@@ -0,0 +1,120 @@
+import logging
+from typing import List, Optional, Tuple
+
+from pydantic import Field
+
+from atroposlib.envs.base import (
+    APIServerConfig,
+    BaseEnv,
+    BaseEnvConfig,
+    ScoredDataGroup,
+)
+from atroposlib.envs.reward_fns.registry import registry
+from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
+
+logger = logging.getLogger(__name__)
+
+
+class SanskritPoetryEnvConfig(BaseEnvConfig):
+    """Configuration for the Sanskrit poetry environment."""
+
+    # Target meter forwarded to the chandas_meter reward function.
+    meter: str = Field("tristubh", description="Desired Sanskrit meter")
+    system_prompt: Optional[str] = Field(
+        "You are a Sanskrit poet. Respond only with the poem in IAST.",
+        description="System prompt for the model",
+    )
+    # Sampling parameters used for completion requests.
+    temperature: float = 0.7
+    top_p: float = 0.9
+    max_tokens: int = 256
+
+
+class SanskritPoetryEnv(BaseEnv):
+    """RL environment that rewards Sanskrit poems matching a target meter."""
+
+    env_config_cls = SanskritPoetryEnvConfig
+
+    def __init__(
+        self,
+        config: SanskritPoetryEnvConfig,
+        server_configs: List[APIServerConfig],
+        slurm: bool = True,
+        testing: bool = False,
+    ):
+        super().__init__(config, server_configs, slurm, testing)
+        # Create reward function using registry for easy configuration.
+        self.reward_fn = registry.create(
+            {"type": "chandas_meter", "params": {"meter": config.meter}}
+        )
+        self.iter = 0
+
+    @classmethod
+    def config_init(cls) -> Tuple[SanskritPoetryEnvConfig, List[APIServerConfig]]:
+        """Return default environment and inference-server configurations."""
+        env_config = SanskritPoetryEnvConfig(
+            group_size=8,
+            use_wandb=True,
+            rollout_server_url="http://localhost:8000",
+            total_steps=1000,
+            batch_size=32,
+            steps_per_eval=50,
+            max_token_length=512,
+            wandb_name="sanskrit_poetry",
+        )
+        server_configs = [
+            APIServerConfig(
+                base_url="http://localhost:9001",
+                api_key="x",
+                num_requests_for_eval=64,
+                model_name="Qwen/Qwen3-1.7B",
+                server_type="trl",
+            )
+        ]
+        return env_config, server_configs
+
+    async def setup(self):
+        self.iter = 0
+
+    async def get_next_item(self):
+        """Build the next composition prompt as a hashable item tuple."""
+        prompt = (
+            f"Compose a four line Sanskrit poem in the {self.config.meter} meter. "
+            "Use IAST transliteration only."
+        )
+        user_msg = {"role": "user", "content": prompt}
+        # frozenset makes the message dict hashable for dedup/bookkeeping.
+        return (tuple([frozenset(user_msg.items())]), None, None)
+
+    async def collect_trajectories(self, item):
+        """Sample a group of poems for one prompt and return them for scoring."""
+        user_content = dict(item[0][0])["content"]
+        messages = []
+        if self.config.system_prompt:
+            messages.append({"role": "system", "content": self.config.system_prompt})
+        messages.append({"role": "user", "content": user_content})
+        # add_generation_prompt=True appends the assistant-turn header so the
+        # model starts a fresh reply rather than continuing the user message.
+        prompt = self.tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        completions = await self.server.completion(
+            prompt=prompt,
+            n=self.config.group_size,
+            max_tokens=self.config.max_tokens,
+            temperature=self.config.temperature,
+            top_p=self.config.top_p,
+        )
+        trajectories = []
+        for completion in completions.choices:
+            # Support both text-completion and chat-completion response shapes.
+            completion_text = (
+                completion.text
+                if hasattr(completion, "text")
+                else completion.message.content
+            )
+            msg_seq = []
+            if self.config.system_prompt:
+                msg_seq.append({"role": "system", "content": self.config.system_prompt})
+            msg_seq.append({"role": "user", "content": user_content})
+            msg_seq.append({"role": "assistant", "content": completion_text})
+            trajectories.append(msg_seq)
+        return trajectories, []
+
+    async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
+        """Tokenize each trajectory and attach its meter-adherence reward."""
+        scored = ScoredDataGroup()
+        scored["tokens"] = []
+        scored["masks"] = []
+        scored["scores"] = []
+        for traj in rollout_group_data:
+            # NOTE(review): assumes RewardFunction instances are callable
+            # (delegating to compute); verify against the registry base class.
+            reward = self.reward_fn([traj[-1]["content"]])[0]
+            out_dict = tokenize_for_trainer(self.tokenizer, traj)
+            scored["tokens"].append(out_dict["tokens"])
+            scored["masks"].append(out_dict["masks"])
+            scored["scores"].append(reward)
+        return scored