fix template

This commit is contained in:
Zafir Stojanovski 2025-02-16 19:51:24 +01:00
parent 95f179f34e
commit b47b6f94c9
2 changed files with 60 additions and 20 deletions

View file

@ -19,6 +19,23 @@ class TextTransformation(StrEnum):
RANDOMCASE = "randomcase"
QUESTION_TEMPLATE = """Your task is to sort words in ascending or descending order using ASCII/Unicode ordering.
Example:
- Input: Sort these words in ascending order (using ASCII/Unicode ordering) and return them as a comma-separated list: freely, idea, indemnify, last, END, solving
- Output: END, freely, idea, indemnify, last, solving
- Explanation:
- Uppercase letters come before lowercase letters, hence why "END" comes first.
- "freely" comes before "idea" because "f" comes before "i".
- "idea" comes before "indemnify" because even though they both start with "i", "d" comes before "n".
- "indemnify" comes before "last" because "i" comes before "l".
- "last" comes before "solving" because "l" comes before "s".
- Finally, the output is provided as a comma separated list of the sorted words.
Now, sort these words in {direction} order (using ASCII/Unicode ordering) and return them as a comma-separated list: {words}
"""
@dataclass
class WordSortingConfig:
"""Configuration for word sorting task generation"""
@ -94,7 +111,7 @@ class WordSortingDataset(ProceduralDataset):
answer = asc_words if is_ascending else desc_words
return {
"question": f"Sort these words in {direction} order (using ASCII/Unicode ordering) and return them as a comma-separated list:\n{', '.join(transformed_words)}",
"question": QUESTION_TEMPLATE.format(direction=direction, words=", ".join(transformed_words)),
"answer": ", ".join(answer),
"metadata": {
"original_words": original_words,
@ -106,26 +123,17 @@ class WordSortingDataset(ProceduralDataset):
}
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
"""Determine if the solution provided solves this task.
oracle_answer = entry["metadata"]["sorted_words"]
if answer is not None and len(answer) > 0:
parsed_answer = [word.strip() for word in re.split(r",\s*", answer)]
if parsed_answer == oracle_answer:
return 1.0
elif sorted(parsed_answer) == oracle_answer:
return 0.2
else:
return 0.01
The function awards 1.0 for a correct answer.
Args:
answer (Optional[str]): The user's answer.
entry (Dict[str, any]): The original dataset entry containing the correct answer.
Returns:
float: The computed score between 0.0 and 1.0.
"""
if answer == None:
return 0.0
s_answer = answer.strip().replace(" ", "")
if not s_answer == entry["answer"].strip().replace(" ", ""):
return 0.01
else:
return 1.0
return 0.0
register_dataset("word_sorting", WordSortingDataset, WordSortingConfig)