minor formatting changes

This commit is contained in:
Andreas Koepf 2025-02-17 18:20:18 +01:00
parent 8989bfca6c
commit 203394314e
2 changed files with 18 additions and 9 deletions

View file

@ -65,7 +65,7 @@ class BasicArithmeticDataset(ProceduralDataset):
def __init__(self, config: BasicArithmeticDatasetConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
self.added_instruction = (
"Ensure to report the answer as an integer. Please do not add commas to the integer answers reported."
" Ensure to report the answer as an integer. Do not add commas to the integer answers reported."
)
def __getitem__(self, idx: int) -> dict[str, Any]:
@ -226,15 +226,14 @@ class BasicArithmeticDataset(ProceduralDataset):
return expression, result
def _format_question(self, rng: Random, expression: str) -> str:
"""Format the expression with clear answer positioning"""
# answer_instruction = "Put your final answer after '=' without additional text."
"""Format the the question with the arithmetic expression"""
if self.config.format_style == "simple":
return f"Calculate {expression}. "
return f"Calculate {expression}."
else:
templates = ["What is {0}. ", "Solve {0}. ", "Compute {0}. ", "Evaluate: {0}. "]
template = rng.choice(templates).format(expression)
return f"{template}"
templates = ["What is {0}?", "Solve {0}.", "Compute {0}.", "Evaluate: {0}."]
template = rng.choice(templates)
return template.format(expression)
# Register the dataset

View file

@ -1,5 +1,3 @@
from random import Random
import pytest
from reasoning_gym.arithmetic.basic_arithmetic import (
@ -66,6 +64,18 @@ def test_arithmetic_dataset_format_styles():
dataset = BasicArithmeticDataset(config)
assert all(item["question"].strip().endswith(".") for item in dataset)
config = BasicArithmeticDatasetConfig(
size=10,
seed=42,
format_style="natural",
min_terms=2,
max_terms=3, # Keep expressions simple for testing
min_digits=1,
max_digits=2,
)
dataset = BasicArithmeticDataset(config)
assert all(item["question"].strip().endswith(".") for item in dataset)
def test_arithmetic_dataset_iteration():
"""Test that iteration respects dataset size"""