lint, seed & size for figlet

This commit is contained in:
Andreas Koepf 2025-01-30 00:58:34 +01:00
parent 25505e3a75
commit fc775eda7e
6 changed files with 2406 additions and 2357 deletions

View file

@ -9,9 +9,9 @@ The goal is to generate virtually infinite data with adjustable complexity.
```
git clone https://github.com/open-thought/reasoning-gym.git
```
2. Create a virtual environment(Here we use conda)
2. Create a virtual environment (here we use conda)
```
conda create --name reasoning_gym python=3.12 -y
conda create --name reasoning_gym python=3.11 -y
conda activate reasoning_gym
```
3. Link project and install dependencies

View file

@ -17,7 +17,6 @@ from openrlhf.trainer import PPOTrainer
from openrlhf.trainer.ppo_utils.experience_maker import Experience, NaiveExperienceMaker, Samples
from openrlhf.utils import blending_datasets, get_strategy, get_tokenizer
from openrlhf.utils.logging_utils import init_logger
from torch.utils.data import Dataset
from transformers.trainer import get_scheduler
@ -180,7 +179,7 @@ class AlgorithmicRewardExperienceMaker(NaiveExperienceMaker):
value = None
# determine outcome reward
completions = sequences[:, -action_mask.size(1):].cpu()
completions = sequences[:, -action_mask.size(1) :].cpu()
completions = self.tokenizer.batch_decode(completions, skip_special_tokens=True)
returns = [
self.dataset.score_answer(extract_answer(c, tag_name="answer"), entry=m)
@ -717,7 +716,7 @@ if __name__ == "__main__":
args = parser.parse_args()
if args.advantage_estimator not in ["gae"]:
args.critic_pretrain = None
args.critic_pretrain = None
elif args.critic_pretrain is None:
args.critic_pretrain = args.pretrain ## temp

View file

@ -1,11 +1,12 @@
from dataclasses import dataclass
import random
import re
import pyfiglet
from typing import List, Optional, Tuple, Dict
from random import Random
from typing import Dict, Optional
import pyfiglet
from ..factory import ProceduralDataset, register_dataset
from ..data.static import wordle_words
from ..factory import ProceduralDataset, register_dataset
@dataclass
class FigletFontConfig:
@ -14,6 +15,9 @@ class FigletFontConfig:
static_word: Optional[str] = None
static_font: Optional[str] = None
space_letters: bool = True
seed: Optional[int] = None
size: int = 500
class FigletFontDataset(ProceduralDataset):
"""Generates FigletFont tasks"""
@ -23,7 +27,7 @@ class FigletFontDataset(ProceduralDataset):
"What word does this say?\n\n{figlet_render}",
"Please read the following figlet font:\n\n{figlet_render}",
]
super().__init__(config=config)
super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict:
"""Generate a single FigletFont task
@ -34,34 +38,83 @@ class FigletFontDataset(ProceduralDataset):
- answer: str, the figlet encoded word
- metadata: dict with generation parameters
"""
rng = Random(self.seed + idx)
word = self.config.static_word if self.config.static_word is not None else random.choice(wordle_words).upper()
if(self.config.space_letters):
render_word = ' '.join(word)
word = self.config.static_word if self.config.static_word is not None else rng.choice(wordle_words).upper()
if self.config.space_letters:
render_word = " ".join(word)
else:
render_word = word
# These ones are funky and probably aren't good for train/testing
bad_fonts = [
'pyramid', 'runyc', 'assalt_m', 'term', 'tengwar', 'heart_right', 'faces_of', 'heroboti', 'hieroglyphs', 'rainbow_',
'notie_ca', 'ghost', 'rampage_', 'atc_____', 'pacos_pe', 'mad_nurs', 'icl-1900', 'joust___', 'dcs_bfmo', 'letter_w',
'flyn_sh', 'fun_face', 'morse2', 'tecrvs__', 'ntgreek', 'tsalagi', 'etcrvs__', 'faces_of', 'future_8', 'efti_robot',
'danc4', 'p_s_h_m_', 'smkeyboard', 'konto', 'odel_lak', 'courb', 'jerusalem', 'nfi1____', 'keyboard', 'konto_slant'
'rot13', 'mirror', 'katakana', 'cards', 'eftichess', 'heart_left', 'trashman', 'morse', 'eftipiti', 'smtengwar', 'e__fist_',
'mike', 'bear', 'hills___', 'rotated', 'wow', 'eftipiti', 'relief2'
"pyramid",
"runyc",
"assalt_m",
"term",
"tengwar",
"heart_right",
"faces_of",
"heroboti",
"hieroglyphs",
"rainbow_",
"notie_ca",
"ghost",
"rampage_",
"atc_____",
"pacos_pe",
"mad_nurs",
"icl-1900",
"joust___",
"dcs_bfmo",
"letter_w",
"flyn_sh",
"fun_face",
"morse2",
"tecrvs__",
"ntgreek",
"tsalagi",
"etcrvs__",
"faces_of",
"future_8",
"efti_robot",
"danc4",
"p_s_h_m_",
"smkeyboard",
"konto",
"odel_lak",
"courb",
"jerusalem",
"nfi1____",
"keyboard",
"konto_slant" "rot13",
"mirror",
"katakana",
"cards",
"eftichess",
"heart_left",
"trashman",
"morse",
"eftipiti",
"smtengwar",
"e__fist_",
"mike",
"bear",
"hills___",
"rotated",
"wow",
"eftipiti",
"relief2",
]
all_fonts = pyfiglet.FigletFont.getFonts()
ok_fonts = list(filter(lambda x: x not in bad_fonts, all_fonts))
chosen_font = self.config.static_font if self.config.static_font is not None else random.choice(ok_fonts)
chosen_font = self.config.static_font if self.config.static_font is not None else rng.choice(ok_fonts)
figlet_render = pyfiglet.figlet_format(render_word, font=chosen_font)
return {
"question": random.choice(self._prompt_templates).format(figlet_render=figlet_render),
"question": rng.choice(self._prompt_templates).format(figlet_render=figlet_render),
"answer": word,
"metadata": {
"font": chosen_font,
"space_letters": self.config.space_letters
},
"metadata": {"font": chosen_font, "space_letters": self.config.space_letters},
}
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
@ -83,20 +136,21 @@ class FigletFontDataset(ProceduralDataset):
return 0.0 # No answer given
# Normalize case
answer = answer.replace(' ', '').strip().lower()
answer = answer.replace(" ", "").strip().lower()
correct_word = correct_word.strip().lower()
if answer == correct_word:
return 1.0 # Correct!
return 1.0 # Correct!
# Calculate similarity
correct_count = sum(1 for a, b in zip(answer, correct_word) if a == b)
max_length = max(len(correct_word), len(answer))
# Compute a partial score
score = min(correct_count * 0.1, 1.0)
return score
# Register the dataset
register_dataset("FigletFont", FigletFontDataset, FigletFontConfig)
register_dataset("figlet_font", FigletFontDataset, FigletFontConfig)

View file

@ -1,9 +1,10 @@
import re
from dataclasses import dataclass
from random import Random
import re
from typing import Dict, List, Optional
from magiccube.cube import Cube, CubeMove, CubeMoveType
from magiccube.solver.basic.basic_solver import BasicSolver
from typing import List, Optional, Dict
from ..factory import ProceduralDataset, register_dataset

File diff suppressed because it is too large Load diff

View file

@ -5,8 +5,7 @@ from reasoning_gym.cognition.figlet_fonts import FigletFontConfig, FigletFontDat
def test_figlet():
"""Test basic properties and solution of generated items"""
config = FigletFontConfig(
)
config = FigletFontConfig(size=40)
dataset = FigletFontDataset(config)
for item in dataset:
@ -19,15 +18,12 @@ def test_figlet():
assert "font" in item["metadata"]
# Test the scoring
assert dataset.score_answer(answer=item['answer'], entry=item) == 1.0
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
def test_static_figlet():
"""Test basic properties and solution of generated items"""
config = FigletFontConfig(
static_word="TESTY",
static_font="caligraphy",
space_letters=False
)
config = FigletFontConfig(static_word="TESTY", static_font="caligraphy", space_letters=False, size=15)
dataset = FigletFontDataset(config)
# Test partial scoring
@ -35,4 +31,3 @@ def test_static_figlet():
assert dataset.score_answer(answer="TESTY", entry=item) == 1.0
assert dataset.score_answer(answer="WESTY", entry=item) == 0.4
assert dataset.score_answer(answer=None, entry=item) == 0
break