formatting

This commit is contained in:
Andreas Koepf 2025-02-16 16:18:39 +01:00
parent d2d4b3a644
commit 6bf2dfa36c
4 changed files with 7 additions and 20 deletions

View file

@ -102,7 +102,10 @@ class SentenceReorderingDataset(ProceduralDataset):
goal_words = expected_answer.split() goal_words = expected_answer.split()
answer_words = answer.split() answer_words = answer.split()
if len(goal_words) == len(answer_words): if len(goal_words) == len(answer_words):
credit = [1 if goal_word.lower() == answer_word.lower() else 0 for goal_word, answer_word in zip(goal_words, answer_words)] credit = [
1 if goal_word.lower() == answer_word.lower() else 0
for goal_word, answer_word in zip(goal_words, answer_words)
]
reward = sum(credit) / len(credit) reward = sum(credit) / len(credit)
else: else:
reward = 0.05 reward = 0.05
@ -111,14 +114,4 @@ class SentenceReorderingDataset(ProceduralDataset):
return reward return reward
register_dataset("sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig) register_dataset("sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig)

View file

@ -8,7 +8,6 @@ from typing import Dict, List, Optional, Set, Tuple
from ..data import get_data_file_path from ..data import get_data_file_path
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Transform the word ladder '{start}' to '{end}' by changing one letter at a time. QUESTION_TEMPLATE = """Transform the word ladder '{start}' to '{end}' by changing one letter at a time.
Provide your answer as a comma-separated sequence of uppercase letters without spaces. Provide your answer as a comma-separated sequence of uppercase letters without spaces.
Each step must be a valid English word.""" Each step must be a valid English word."""

View file

@ -229,12 +229,7 @@ class BasicArithmeticDataset(ProceduralDataset):
if self.config.format_style == "simple": if self.config.format_style == "simple":
return f"{answer_instruction} Calculate {expression} =" return f"{answer_instruction} Calculate {expression} ="
else: else:
templates = [ templates = ["What is {0} =", "Solve {0}=", "Compute {0} =", "Evaluate: {0} ="]
"What is {0} =",
"Solve {0}=",
"Compute {0} =",
"Evaluate: {0} ="
]
template = rng.choice(templates).format(expression) template = rng.choice(templates).format(expression)
return f"{answer_instruction} {template}" return f"{answer_instruction} {template}"

View file

@ -37,7 +37,7 @@ def test_getitem(dataset, config):
assert "metadata" in item assert "metadata" in item
assert item["metadata"]["word_count"] >= config.min_words_in_sentence assert item["metadata"]["word_count"] >= config.min_words_in_sentence
assert item["metadata"]["word_count"] <= config.max_words_in_sentence assert item["metadata"]["word_count"] <= config.max_words_in_sentence
assert len(item['answer'].split()) == item['metadata']['word_count'] assert len(item["answer"].split()) == item["metadata"]["word_count"]
def test_key_error_in_getitem(dataset): def test_key_error_in_getitem(dataset):