mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
lint, seed & size for figlet
This commit is contained in:
parent
25505e3a75
commit
fc775eda7e
6 changed files with 2406 additions and 2357 deletions
|
|
@ -9,9 +9,9 @@ The goal is to generate virtually infinite data with adjustable complexity.
|
|||
```
|
||||
git clone https://github.com/open-thought/reasoning-gym.git
|
||||
```
|
||||
2. Create a virtual environment(Here we use conda)
|
||||
2. Create a virtual environment (here we use conda)
|
||||
```
|
||||
conda create --name reasoning_gym python=3.12 -y
|
||||
conda create --name reasoning_gym python=3.11 -y
|
||||
conda activate reasoning_gym
|
||||
```
|
||||
3. Link project and install dependencies
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ from openrlhf.trainer import PPOTrainer
|
|||
from openrlhf.trainer.ppo_utils.experience_maker import Experience, NaiveExperienceMaker, Samples
|
||||
from openrlhf.utils import blending_datasets, get_strategy, get_tokenizer
|
||||
from openrlhf.utils.logging_utils import init_logger
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
from transformers.trainer import get_scheduler
|
||||
|
||||
|
|
@ -180,7 +179,7 @@ class AlgorithmicRewardExperienceMaker(NaiveExperienceMaker):
|
|||
value = None
|
||||
|
||||
# determine outcome reward
|
||||
completions = sequences[:, -action_mask.size(1):].cpu()
|
||||
completions = sequences[:, -action_mask.size(1) :].cpu()
|
||||
completions = self.tokenizer.batch_decode(completions, skip_special_tokens=True)
|
||||
returns = [
|
||||
self.dataset.score_answer(extract_answer(c, tag_name="answer"), entry=m)
|
||||
|
|
@ -717,7 +716,7 @@ if __name__ == "__main__":
|
|||
args = parser.parse_args()
|
||||
|
||||
if args.advantage_estimator not in ["gae"]:
|
||||
args.critic_pretrain = None
|
||||
args.critic_pretrain = None
|
||||
elif args.critic_pretrain is None:
|
||||
args.critic_pretrain = args.pretrain ## temp
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
from dataclasses import dataclass
|
||||
import random
|
||||
import re
|
||||
import pyfiglet
|
||||
from typing import List, Optional, Tuple, Dict
|
||||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
|
||||
import pyfiglet
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
from ..data.static import wordle_words
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class FigletFontConfig:
|
||||
|
|
@ -14,6 +15,9 @@ class FigletFontConfig:
|
|||
static_word: Optional[str] = None
|
||||
static_font: Optional[str] = None
|
||||
space_letters: bool = True
|
||||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
|
||||
|
||||
class FigletFontDataset(ProceduralDataset):
|
||||
"""Generates FigletFont tasks"""
|
||||
|
|
@ -23,7 +27,7 @@ class FigletFontDataset(ProceduralDataset):
|
|||
"What word does this say?\n\n{figlet_render}",
|
||||
"Please read the following figlet font:\n\n{figlet_render}",
|
||||
]
|
||||
super().__init__(config=config)
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single FigletFont task
|
||||
|
|
@ -34,34 +38,83 @@ class FigletFontDataset(ProceduralDataset):
|
|||
- answer: str, the figlet encoded word
|
||||
- metadata: dict with generation parameters
|
||||
"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
word = self.config.static_word if self.config.static_word is not None else random.choice(wordle_words).upper()
|
||||
if(self.config.space_letters):
|
||||
render_word = ' '.join(word)
|
||||
word = self.config.static_word if self.config.static_word is not None else rng.choice(wordle_words).upper()
|
||||
if self.config.space_letters:
|
||||
render_word = " ".join(word)
|
||||
else:
|
||||
render_word = word
|
||||
|
||||
# These ones are funky and probably aren't good for train/testing
|
||||
bad_fonts = [
|
||||
'pyramid', 'runyc', 'assalt_m', 'term', 'tengwar', 'heart_right', 'faces_of', 'heroboti', 'hieroglyphs', 'rainbow_',
|
||||
'notie_ca', 'ghost', 'rampage_', 'atc_____', 'pacos_pe', 'mad_nurs', 'icl-1900', 'joust___', 'dcs_bfmo', 'letter_w',
|
||||
'flyn_sh', 'fun_face', 'morse2', 'tecrvs__', 'ntgreek', 'tsalagi', 'etcrvs__', 'faces_of', 'future_8', 'efti_robot',
|
||||
'danc4', 'p_s_h_m_', 'smkeyboard', 'konto', 'odel_lak', 'courb', 'jerusalem', 'nfi1____', 'keyboard', 'konto_slant'
|
||||
'rot13', 'mirror', 'katakana', 'cards', 'eftichess', 'heart_left', 'trashman', 'morse', 'eftipiti', 'smtengwar', 'e__fist_',
|
||||
'mike', 'bear', 'hills___', 'rotated', 'wow', 'eftipiti', 'relief2'
|
||||
"pyramid",
|
||||
"runyc",
|
||||
"assalt_m",
|
||||
"term",
|
||||
"tengwar",
|
||||
"heart_right",
|
||||
"faces_of",
|
||||
"heroboti",
|
||||
"hieroglyphs",
|
||||
"rainbow_",
|
||||
"notie_ca",
|
||||
"ghost",
|
||||
"rampage_",
|
||||
"atc_____",
|
||||
"pacos_pe",
|
||||
"mad_nurs",
|
||||
"icl-1900",
|
||||
"joust___",
|
||||
"dcs_bfmo",
|
||||
"letter_w",
|
||||
"flyn_sh",
|
||||
"fun_face",
|
||||
"morse2",
|
||||
"tecrvs__",
|
||||
"ntgreek",
|
||||
"tsalagi",
|
||||
"etcrvs__",
|
||||
"faces_of",
|
||||
"future_8",
|
||||
"efti_robot",
|
||||
"danc4",
|
||||
"p_s_h_m_",
|
||||
"smkeyboard",
|
||||
"konto",
|
||||
"odel_lak",
|
||||
"courb",
|
||||
"jerusalem",
|
||||
"nfi1____",
|
||||
"keyboard",
|
||||
"konto_slant" "rot13",
|
||||
"mirror",
|
||||
"katakana",
|
||||
"cards",
|
||||
"eftichess",
|
||||
"heart_left",
|
||||
"trashman",
|
||||
"morse",
|
||||
"eftipiti",
|
||||
"smtengwar",
|
||||
"e__fist_",
|
||||
"mike",
|
||||
"bear",
|
||||
"hills___",
|
||||
"rotated",
|
||||
"wow",
|
||||
"eftipiti",
|
||||
"relief2",
|
||||
]
|
||||
all_fonts = pyfiglet.FigletFont.getFonts()
|
||||
ok_fonts = list(filter(lambda x: x not in bad_fonts, all_fonts))
|
||||
chosen_font = self.config.static_font if self.config.static_font is not None else random.choice(ok_fonts)
|
||||
chosen_font = self.config.static_font if self.config.static_font is not None else rng.choice(ok_fonts)
|
||||
figlet_render = pyfiglet.figlet_format(render_word, font=chosen_font)
|
||||
|
||||
return {
|
||||
"question": random.choice(self._prompt_templates).format(figlet_render=figlet_render),
|
||||
"question": rng.choice(self._prompt_templates).format(figlet_render=figlet_render),
|
||||
"answer": word,
|
||||
"metadata": {
|
||||
"font": chosen_font,
|
||||
"space_letters": self.config.space_letters
|
||||
},
|
||||
"metadata": {"font": chosen_font, "space_letters": self.config.space_letters},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
|
|
@ -83,20 +136,21 @@ class FigletFontDataset(ProceduralDataset):
|
|||
return 0.0 # No answer given
|
||||
|
||||
# Normalize case
|
||||
answer = answer.replace(' ', '').strip().lower()
|
||||
answer = answer.replace(" ", "").strip().lower()
|
||||
correct_word = correct_word.strip().lower()
|
||||
|
||||
if answer == correct_word:
|
||||
return 1.0 # Correct!
|
||||
return 1.0 # Correct!
|
||||
|
||||
# Calculate similarity
|
||||
correct_count = sum(1 for a, b in zip(answer, correct_word) if a == b)
|
||||
max_length = max(len(correct_word), len(answer))
|
||||
|
||||
|
||||
# Compute a partial score
|
||||
score = min(correct_count * 0.1, 1.0)
|
||||
|
||||
return score
|
||||
|
||||
|
||||
# Register the dataset
|
||||
register_dataset("FigletFont", FigletFontDataset, FigletFontConfig)
|
||||
register_dataset("figlet_font", FigletFontDataset, FigletFontConfig)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
import re
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
import re
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from magiccube.cube import Cube, CubeMove, CubeMoveType
|
||||
from magiccube.solver.basic.basic_solver import BasicSolver
|
||||
from typing import List, Optional, Dict
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -5,8 +5,7 @@ from reasoning_gym.cognition.figlet_fonts import FigletFontConfig, FigletFontDat
|
|||
|
||||
def test_figlet():
|
||||
"""Test basic properties and solution of generated items"""
|
||||
config = FigletFontConfig(
|
||||
)
|
||||
config = FigletFontConfig(size=40)
|
||||
dataset = FigletFontDataset(config)
|
||||
|
||||
for item in dataset:
|
||||
|
|
@ -19,15 +18,12 @@ def test_figlet():
|
|||
assert "font" in item["metadata"]
|
||||
|
||||
# Test the scoring
|
||||
assert dataset.score_answer(answer=item['answer'], entry=item) == 1.0
|
||||
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
|
||||
|
||||
|
||||
def test_static_figlet():
|
||||
"""Test basic properties and solution of generated items"""
|
||||
config = FigletFontConfig(
|
||||
static_word="TESTY",
|
||||
static_font="caligraphy",
|
||||
space_letters=False
|
||||
)
|
||||
config = FigletFontConfig(static_word="TESTY", static_font="caligraphy", space_letters=False, size=15)
|
||||
dataset = FigletFontDataset(config)
|
||||
|
||||
# Test partial scoring
|
||||
|
|
@ -35,4 +31,3 @@ def test_static_figlet():
|
|||
assert dataset.score_answer(answer="TESTY", entry=item) == 1.0
|
||||
assert dataset.score_answer(answer="WESTY", entry=item) == 0.4
|
||||
assert dataset.score_answer(answer=None, entry=item) == 0
|
||||
break
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue