formatting

This commit is contained in:
Andreas Koepf 2025-02-16 16:18:39 +01:00
parent d2d4b3a644
commit 6bf2dfa36c
4 changed files with 7 additions and 20 deletions

View file

@ -102,23 +102,16 @@ class SentenceReorderingDataset(ProceduralDataset):
goal_words = expected_answer.split()
answer_words = answer.split()
if len(goal_words) == len(answer_words):
credit = [1 if goal_word.lower() == answer_word.lower() else 0 for goal_word, answer_word in zip(goal_words, answer_words)]
credit = [
1 if goal_word.lower() == answer_word.lower() else 0
for goal_word, answer_word in zip(goal_words, answer_words)
]
reward = sum(credit) / len(credit)
else:
reward = 0.05
except:
reward = 0.01
return reward
register_dataset("sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig)

View file

@ -8,7 +8,6 @@ from typing import Dict, List, Optional, Set, Tuple
from ..data import get_data_file_path
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Transform the word ladder '{start}' to '{end}' by changing one letter at a time.
Provide your answer as a comma-separated sequence of uppercase letters without spaces.
Each step must be a valid English word."""

View file

@ -225,16 +225,11 @@ class BasicArithmeticDataset(ProceduralDataset):
def _format_question(self, rng: Random, expression: str) -> str:
"""Format the expression with clear answer positioning"""
answer_instruction = "Put your final answer after '=' without additional text."
if self.config.format_style == "simple":
return f"{answer_instruction} Calculate {expression} ="
else:
templates = [
"What is {0} =",
"Solve {0}=",
"Compute {0} =",
"Evaluate: {0} ="
]
templates = ["What is {0} =", "Solve {0}=", "Compute {0} =", "Evaluate: {0} ="]
template = rng.choice(templates).format(expression)
return f"{answer_instruction} {template}"

View file

@ -37,7 +37,7 @@ def test_getitem(dataset, config):
assert "metadata" in item
assert item["metadata"]["word_count"] >= config.min_words_in_sentence
assert item["metadata"]["word_count"] <= config.max_words_in_sentence
assert len(item['answer'].split()) == item['metadata']['word_count']
assert len(item["answer"].split()) == item["metadata"]["word_count"]
def test_key_error_in_getitem(dataset):