From 425ae24f3b22a783f753568d17e7e6efc8d60a59 Mon Sep 17 00:00:00 2001
From: joesharratt1229
Date: Fri, 21 Feb 2025 17:57:41 +0000
Subject: [PATCH] added emoji dataset

---
 reasoning_gym/games/__init__.py      |   3 +
 reasoning_gym/games/emoji_mystery.py | 268 +++++++++++++++++++++++++++
 2 files changed, 271 insertions(+)
 create mode 100644 reasoning_gym/games/emoji_mystery.py

diff --git a/reasoning_gym/games/__init__.py b/reasoning_gym/games/__init__.py
index dd1ed898..db1db169 100644
--- a/reasoning_gym/games/__init__.py
+++ b/reasoning_gym/games/__init__.py
@@ -7,6 +7,7 @@ Game tasks for training reasoning capabilities:
 """
 
 from .countdown import CountdownConfig, CountdownDataset
+from .emoji_mystery import EmojiMysteryConfig, EmojiMysteryDataset
 from .game_of_life import GameOfLifeConfig, GameOfLifeDataset
 from .knight_swap import KnightSwapConfig, KnightSwapDataset
 from .maze import MazeConfig, MazeDataset
@@ -20,6 +21,8 @@ from .tsumego import TsumegoConfig, TsumegoDataset
 __all__ = [
     "CountdownConfig",
     "CountdownDataset",
+    "EmojiMysteryConfig",
+    "EmojiMysteryDataset",
     "MiniSudokuConfig",
     "MiniSudokuDataset",
     "SudokuConfig",
diff --git a/reasoning_gym/games/emoji_mystery.py b/reasoning_gym/games/emoji_mystery.py
new file mode 100644
index 00000000..508a338b
--- /dev/null
+++ b/reasoning_gym/games/emoji_mystery.py
@@ -0,0 +1,268 @@
+import re
+from dataclasses import dataclass
+from random import Random
+from typing import Any, Dict, List, Optional, Tuple
+
+from ..data import read_data_file
+from ..factory import ProceduralDataset, register_dataset
+
+# Carrier emojis for the hidden sentence. Each entry must be a single code
+# point with no variation selector of its own: decoding skips exactly one
+# leading character and treats every selector after it as payload.
+_EMOJIS = [
+    "😀",
+    "😃",
+    "😄",
+    "😁",
+    "😆",
+    "😅",
+    "🤣",
+    "😂",
+    "🙂",
+    "🙃",
+    "😉",
+    "😊",
+    "😇",
+    "🥰",
+    "😍",
+    "🤩",
+    "😘",
+    "😗",
+    "😚",
+    "😙",
+    "🥲",
+    "😋",
+    "😛",
+    "😜",
+    "🤪",
+    "😝",
+    "🤑",
+    "🤗",
+    "🤭",
+    "🤫",
+    "🤔",
+    "🤐",
+    "🤨",
+    "😐",
+    "😑",
+    "😶",
+    "😏",
+    "😒",
+    "🙄",
+    "😬",
+    "😮",
+    "😯",
+    "😲",
+    "😳",
+    "🥺",
+    "😦",
+    "😧",
+    "😨",
+    "😰",
+    "😥",
+    "😢",
+    "😭",
+    "😱",
+    "😖",
+    "😣",
+    "😞",
+    "😓",
+    "😩",
+    "😫",
+    "🥱",
+    "😤",
+    "😡",
+    "😠",
+    "🤬",
+    "😈",
+    "👿",
+    "💀",
+    "\u2620",  # skull and crossbones, written as an escape so no selector can sneak in
+    "💩",
+    "🤡",
+    "👹",
+    "👺",
+    "👻",
+    "👽",
+    "👾",
+    "🤖",
+    "😺",
+    "😸",
+    "😹",
+    "😻",
+    "😼",
+    "😽",
+    "🙀",
+    "😿",
+    "😾",
+    "🙈",
+    "🙉",
+    "🙊",
+    "💋",
+    "💌",
+    "💘",
+    "💝",
+    "💖",
+    "💗",
+    "💓",
+    "💞",
+    "💕",
+    "💟",
+    "\u2763",  # heart exclamation, written as an escape so no selector can sneak in
+    "💔",
+    "\u2764",  # red heart WITHOUT U+FE0F: that selector would decode as a stray 0x0F byte
+    "🧡",
+    "💛",
+    "💚",
+    "💙",
+    "💜",
+    "🤎",
+    "🖤",
+    "🤍",
+]
+
+
+# Reference decoder embedded verbatim in every question as the hint.
+hint_function = """def variance_selector_to_byte(variation_selector):
+    variation_selector_codepoint = ord(variation_selector)
+    if 0xFE00 <= variation_selector_codepoint <= 0xFE0F:
+        return variation_selector_codepoint - 0xFE00
+    elif 0xE0100 <= variation_selector_codepoint <= 0xE01EF:
+        return variation_selector_codepoint - 0xE0100 + 16
+    else:
+        return None
+def decode(encoded_sentence):
+    decoded_bytes = []
+    variation_selectors_part = encoded_sentence[1:]
+    for char in variation_selectors_part:
+        byte_val = variance_selector_to_byte(char)
+        if byte_val is not None:
+            decoded_bytes.append(byte_val)
+    return bytes(decoded_bytes).decode('utf-8')"""
+
+
+QUESTION_TEMPLATE = "\n".join(
+    [
+        "The following emoji is encoded with a sentence",
+        "Decode the following sentence from the emoji: {sentence}",
+        "Here is a hint: {hint_function}",
+    ]
+)
+
+
+@dataclass
+class EmojiMysteryConfig:
+    """Configuration for Emoji Mystery task generation.
+
+    Attributes:
+        size: Dataset size handed to ProceduralDataset.
+        seed: Seed handed to ProceduralDataset; combined with the item index per item.
+        min_words_in_sentence: Fewest words a candidate sentence may have.
+        max_words_in_sentence: Most words a candidate sentence may have.
+    """
+
+    size: int = 1000
+    seed: Optional[int] = None
+    min_words_in_sentence: int = 3
+    max_words_in_sentence: int = 35
+
+    def validate(self):
+        """Check the configuration, raising AssertionError on bad values."""
+        assert self.min_words_in_sentence > 0, "min_words_in_sentence must be positive"
+        assert (
+            self.max_words_in_sentence >= self.min_words_in_sentence
+        ), "max_words_in_sentence must be >= min_words_in_sentence"
+        assert self.size > 0, "size must be positive"
+
+
+class EmojiMysteryDataset(ProceduralDataset):
+    """Sentences hidden in the variation selectors that follow an emoji.
+
+    Each UTF-8 byte of the secret sentence becomes one variation selector
+    appended to a carrier emoji; the task is to recover the sentence.
+    """
+
+    def __init__(self, config: EmojiMysteryConfig):
+        super().__init__(config=config, seed=config.seed, size=config.size)
+        text = read_data_file("in_the_year_2889.txt")
+        self.emojis = _EMOJIS
+        # Keep only sentences whose word count lies inside the configured range.
+        self.sentences = [
+            sentence.strip()
+            for sentence in re.findall(r"[^.!?]+[.!?]", text)
+            if self.config.min_words_in_sentence
+            <= len(re.findall(r"\b\w+\b", sentence))
+            <= self.config.max_words_in_sentence
+        ]
+
+    def __getitem__(self, idx: int) -> Dict[str, Any]:
+        """Build item ``idx``; the same seed and index give the same item."""
+        rng = Random(self.seed + idx)
+        secret_emoji = rng.choice(self.emojis)
+        secret_sentence = rng.choice(self.sentences).strip().replace("\n", " ")
+        encoded_sentence = self.encode(secret_sentence, secret_emoji)
+        question = QUESTION_TEMPLATE.format(sentence=encoded_sentence, hint_function=hint_function)
+        return {"question": question, "answer": secret_sentence, "metadata": {"emoji": secret_emoji}}
+
+    def variance_selector_to_byte(self, variation_selector: str) -> Optional[int]:
+        """Return the byte a variation selector stands for, or None for any other character."""
+        variation_selector_codepoint = ord(variation_selector)
+        if 0xFE00 <= variation_selector_codepoint <= 0xFE0F:
+            return variation_selector_codepoint - 0xFE00
+        if 0xE0100 <= variation_selector_codepoint <= 0xE01EF:
+            return variation_selector_codepoint - 0xE0100 + 16
+        return None
+
+    def decode(self, encoded_sentence: str) -> str:
+        """Recover the sentence hidden after the first character of ``encoded_sentence``.
+
+        Characters that are not variation selectors are skipped.
+        """
+        # Mirrors the hint given to the solver: drop one carrier character.
+        decoded_bytes = [
+            byte_val
+            for byte_val in map(self.variance_selector_to_byte, encoded_sentence[1:])
+            if byte_val is not None
+        ]
+        return bytes(decoded_bytes).decode("utf-8")
+
+    def byte_to_variance_selector(self, byte: int) -> str:
+        """Map a byte value (0-255) to the variation selector that encodes it."""
+        if byte < 16:
+            return chr(0xFE00 + byte)
+        return chr(0xE0100 + (byte - 16))
+
+    def encode(self, sentence: str, base: str) -> str:
+        """Hide ``sentence`` in variation selectors appended to ``base``.
+
+        Variation selectors already present in ``base`` are removed first;
+        left in place they would be decoded as extra payload bytes.
+        """
+        carrier = "".join(ch for ch in base if self.variance_selector_to_byte(ch) is None)
+        payload = "".join(self.byte_to_variance_selector(b) for b in sentence.encode("utf-8"))
+        return carrier + payload
+
+    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
+        """Score ``answer`` against ``entry["answer"]``.
+
+        Returns 1.0 for an exact match, the fraction of matching positions
+        when only the lengths agree, 0.01 for any other non-None answer and
+        0.0 when ``answer`` is None.
+        """
+        if answer is None:
+            return 0.0
+        try:
+            expected = entry["answer"]
+            if answer == expected:
+                return 1.0
+            if len(answer) != len(expected):
+                return 0.01
+            matches = sum(a == b for a, b in zip(answer, expected))
+            return matches / len(expected)
+        except Exception:
+            # Best-effort scoring: a malformed entry or answer earns the token
+            # reward, while KeyboardInterrupt and SystemExit still propagate.
+            return 0.01
+
+
+register_dataset("emoji_mystery", EmojiMysteryDataset, EmojiMysteryConfig)