mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
minor formatting changes
This commit is contained in:
parent
8989bfca6c
commit
203394314e
2 changed files with 18 additions and 9 deletions
|
|
@ -65,7 +65,7 @@ class BasicArithmeticDataset(ProceduralDataset):
|
|||
def __init__(self, config: BasicArithmeticDatasetConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
self.added_instruction = (
|
||||
"Ensure to report the answer as an integer. Please do not add commas to the integer answers reported."
|
||||
" Ensure to report the answer as an integer. Do not add commas to the integer answers reported."
|
||||
)
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, Any]:
|
||||
|
|
@ -226,15 +226,14 @@ class BasicArithmeticDataset(ProceduralDataset):
|
|||
return expression, result
|
||||
|
||||
def _format_question(self, rng: Random, expression: str) -> str:
|
||||
"""Format the expression with clear answer positioning"""
|
||||
# answer_instruction = "Put your final answer after '=' without additional text."
|
||||
"""Format the the question with the arithmetic expression"""
|
||||
|
||||
if self.config.format_style == "simple":
|
||||
return f"Calculate {expression}."
|
||||
else:
|
||||
templates = ["What is {0}. ", "Solve {0}. ", "Compute {0}. ", "Evaluate: {0}. "]
|
||||
template = rng.choice(templates).format(expression)
|
||||
return f"{template}"
|
||||
templates = ["What is {0}?", "Solve {0}.", "Compute {0}.", "Evaluate: {0}."]
|
||||
template = rng.choice(templates)
|
||||
return template.format(expression)
|
||||
|
||||
|
||||
# Register the dataset
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
from random import Random
|
||||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.arithmetic.basic_arithmetic import (
|
||||
|
|
@ -66,6 +64,18 @@ def test_arithmetic_dataset_format_styles():
|
|||
dataset = BasicArithmeticDataset(config)
|
||||
assert all(item["question"].strip().endswith(".") for item in dataset)
|
||||
|
||||
config = BasicArithmeticDatasetConfig(
|
||||
size=10,
|
||||
seed=42,
|
||||
format_style="natural",
|
||||
min_terms=2,
|
||||
max_terms=3, # Keep expressions simple for testing
|
||||
min_digits=1,
|
||||
max_digits=2,
|
||||
)
|
||||
dataset = BasicArithmeticDataset(config)
|
||||
assert all(item["question"].strip().endswith(".") for item in dataset)
|
||||
|
||||
|
||||
def test_arithmetic_dataset_iteration():
|
||||
"""Test that iteration respects dataset size"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue