mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-23 16:55:05 +00:00
Merge branch 'main' into rich/graphcolor
This commit is contained in:
commit
b64d0af2bc
19 changed files with 385 additions and 61 deletions
|
|
@ -76,6 +76,11 @@ class IntermediateIntegrationDataset(ProceduralDataset):
|
||||||
"Calculate the antiderivative: ∫ {integrand} dx",
|
"Calculate the antiderivative: ∫ {integrand} dx",
|
||||||
"Evaluate the indefinite integral: ∫ {integrand} dx",
|
"Evaluate the indefinite integral: ∫ {integrand} dx",
|
||||||
]
|
]
|
||||||
|
self.added_instruction = """
|
||||||
|
In addition, when doing calculation, use the following instructions together with your mathematical ingenuity to solve the integral problems
|
||||||
|
## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2.
|
||||||
|
## 2. Always use * when doing all sorts of multiplcation in your reasoning steps. For example Use [-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C] instead of [-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C].
|
||||||
|
"""
|
||||||
|
|
||||||
def _get_outer_constant(self, rng: random.Random) -> int:
|
def _get_outer_constant(self, rng: random.Random) -> int:
|
||||||
"""Helper to generate signed outer constant from config"""
|
"""Helper to generate signed outer constant from config"""
|
||||||
|
|
@ -222,9 +227,10 @@ class IntermediateIntegrationDataset(ProceduralDataset):
|
||||||
|
|
||||||
answer = sympy.integrate(integrand, x)
|
answer = sympy.integrate(integrand, x)
|
||||||
answer_str = str(answer) + " + C"
|
answer_str = str(answer) + " + C"
|
||||||
|
question = rng.choice(self.prompt_template).format(integrand=integrand) + self.added_instruction
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": rng.choice(self.prompt_template).format(integrand=integrand),
|
"question": question,
|
||||||
"answer": answer_str,
|
"answer": answer_str,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"integrand": str(integrand),
|
"integrand": str(integrand),
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,14 @@ class PolynomialEquationsDataset(ProceduralDataset):
|
||||||
"Determine the real value(s) of {variable} that satisfies: {polynomial_expanded} = 0",
|
"Determine the real value(s) of {variable} that satisfies: {polynomial_expanded} = 0",
|
||||||
"Solve the polynomial equation for real {variable}:\n{polynomial_expanded} = 0",
|
"Solve the polynomial equation for real {variable}:\n{polynomial_expanded} = 0",
|
||||||
]
|
]
|
||||||
|
self.added_instruction = """
|
||||||
|
In solving the equations, please abide by the following instruction:
|
||||||
|
## 1. All answers should be comma-separated. For example "-0.3773, 0.4005" etc.
|
||||||
|
## 2. In cases where your answer is b = 2 + sqrt(4560) / 172 and b = 2 - sqrt(4560) / 172. Since b can be 2 numbers, resolve your answer like this instead, "-0.3773, 0.4005".
|
||||||
|
## 3. If there are no real values of i that satisfy the equation, report your answer as empty string, "".
|
||||||
|
## 4. If there are 2 answers, resolve the answers as comma-separated floats of 2 numbers, if 3 answers, make it comma-separated floats of 3 numbers.
|
||||||
|
## 5. Resolve all numbers as floats in the string of comma-separated numbers. Round the floats higher than 4 decimal place(d.p) down to 4 d.p.
|
||||||
|
"""
|
||||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
|
|
@ -89,19 +97,20 @@ class PolynomialEquationsDataset(ProceduralDataset):
|
||||||
for sol in solutions:
|
for sol in solutions:
|
||||||
if sol.is_real:
|
if sol.is_real:
|
||||||
# Evaluate symbolic solution to a floating approximation
|
# Evaluate symbolic solution to a floating approximation
|
||||||
real_solutions.append(float(sol.evalf()))
|
real_solutions.append(round(float(sol.evalf()), 4))
|
||||||
|
|
||||||
if len(real_solutions) > 0:
|
if len(real_solutions) > 0:
|
||||||
real_solutions.sort()
|
real_solutions.sort()
|
||||||
break
|
break
|
||||||
|
|
||||||
answer_str = ", ".join(str(x) for x in real_solutions)
|
answer_str = ", ".join(str(x) for x in real_solutions)
|
||||||
|
question = (
|
||||||
|
rng.choice(self._prompt_templates).format(variable=variable, polynomial_expanded=polynomial_expanded)
|
||||||
|
+ self.added_instruction
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": rng.choice(self._prompt_templates).format(
|
"question": question,
|
||||||
variable=variable,
|
|
||||||
polynomial_expanded=polynomial_expanded,
|
|
||||||
),
|
|
||||||
"answer": answer_str,
|
"answer": answer_str,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"polynomial_expr": str(polynomial_expanded),
|
"polynomial_expr": str(polynomial_expanded),
|
||||||
|
|
|
||||||
|
|
@ -61,6 +61,11 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
|
||||||
"Simplify this expression: {polynomial_expr}",
|
"Simplify this expression: {polynomial_expr}",
|
||||||
"Calculate the following: {polynomial_expr}",
|
"Calculate the following: {polynomial_expr}",
|
||||||
]
|
]
|
||||||
|
self.added_instruction = """
|
||||||
|
In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems
|
||||||
|
## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2.
|
||||||
|
## 2. Always use * when doing all sorts of multiplcation in your reasoning steps and even in reporting answers.
|
||||||
|
"""
|
||||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
|
|
@ -79,11 +84,10 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
|
||||||
|
|
||||||
polynomial_expr = sp.prod(polynomials)
|
polynomial_expr = sp.prod(polynomials)
|
||||||
product = sp.expand(polynomial_expr)
|
product = sp.expand(polynomial_expr)
|
||||||
|
question = rng.choice(self._prompt_templates).format(polynomial_expr=polynomial_expr) + self.added_instruction
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": rng.choice(self._prompt_templates).format(
|
"question": question,
|
||||||
polynomial_expr=polynomial_expr,
|
|
||||||
),
|
|
||||||
"answer": product,
|
"answer": product,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"polynomial_expr": str(polynomial_expr),
|
"polynomial_expr": str(polynomial_expr),
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,11 @@ class SimpleIntegrationDataset(ProceduralDataset):
|
||||||
"Calculate the antiderivative: ∫ {integrand} dx",
|
"Calculate the antiderivative: ∫ {integrand} dx",
|
||||||
"Evaluate the indefinite integral: ∫ {integrand} dx",
|
"Evaluate the indefinite integral: ∫ {integrand} dx",
|
||||||
]
|
]
|
||||||
|
self.added_instruction = """
|
||||||
|
In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems
|
||||||
|
## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2.
|
||||||
|
## 2. Always use * when doing all sorts of multiplcation in your reasoning steps. For example Use [-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C] instead of [-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C].
|
||||||
|
"""
|
||||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
|
||||||
def _generate_coefficient(self, rng: random.Random) -> Fraction:
|
def _generate_coefficient(self, rng: random.Random) -> Fraction:
|
||||||
|
|
@ -69,9 +74,10 @@ class SimpleIntegrationDataset(ProceduralDataset):
|
||||||
rng = random.Random(self.seed + idx)
|
rng = random.Random(self.seed + idx)
|
||||||
symbol, polynomial = self._generate_polynomial(rng)
|
symbol, polynomial = self._generate_polynomial(rng)
|
||||||
derivative = sympy.diff(polynomial, symbol)
|
derivative = sympy.diff(polynomial, symbol)
|
||||||
|
question = rng.choice(self._prompt_templates).format(integrand=derivative) + self.added_instruction
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": rng.choice(self._prompt_templates).format(integrand=derivative),
|
"question": question,
|
||||||
"answer": str(polynomial) + " + C",
|
"answer": str(polynomial) + " + C",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"integrand": str(derivative),
|
"integrand": str(derivative),
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,6 @@ from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset
|
||||||
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
|
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
|
||||||
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
|
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
|
||||||
from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset
|
from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset
|
||||||
from .string_insertion import StringInsertionConfig, StringInsertionDataset
|
|
||||||
from .string_manipulation import StringManipulationConfig, StringManipulationDataset
|
from .string_manipulation import StringManipulationConfig, StringManipulationDataset
|
||||||
from .word_ladder import WordLadderConfig, WordLadderDataset
|
from .word_ladder import WordLadderConfig, WordLadderDataset
|
||||||
from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset
|
from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,11 @@ class NumberSortingDataset(ProceduralDataset):
|
||||||
|
|
||||||
def __init__(self, config: NumberSortingConfig):
|
def __init__(self, config: NumberSortingConfig):
|
||||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
self.added_instruction = """
|
||||||
|
Please follow the instruction below:
|
||||||
|
## 1. Let all your answers be a list of numbers. Instead of reporting your answer as -69, -13, 1, 7, 11, 43, 59, 61, use ['-69', '-13', '1', '7', '11', '43', '59', '61'] instead
|
||||||
|
## 2. Convert all numbers in the square brackets as strings. For example, ['-69', '-13', '1', '7', '11', '43', '59', '61']
|
||||||
|
"""
|
||||||
|
|
||||||
def _format_number(self, num: float, decimals: int) -> str:
|
def _format_number(self, num: float, decimals: int) -> str:
|
||||||
"""Format number with specified decimal places"""
|
"""Format number with specified decimal places"""
|
||||||
|
|
@ -78,9 +83,10 @@ class NumberSortingDataset(ProceduralDataset):
|
||||||
is_ascending = rng.choice([True, False])
|
is_ascending = rng.choice([True, False])
|
||||||
direction = "ascending" if is_ascending else "descending"
|
direction = "ascending" if is_ascending else "descending"
|
||||||
answer = asc_answer if is_ascending else desc_answer
|
answer = asc_answer if is_ascending else desc_answer
|
||||||
|
question = f"Sort these numbers in {direction} order: {', '.join(number_strs)}" + self.added_instruction
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"Sort these numbers in {direction} order: {', '.join(number_strs)}",
|
"question": question,
|
||||||
"answer": str(answer),
|
"answer": str(answer),
|
||||||
"metadata": {"original_numbers": number_strs, "direction": direction, "sorted_numbers": answer},
|
"metadata": {"original_numbers": number_strs, "direction": direction, "sorted_numbers": answer},
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -58,27 +58,27 @@ class Arc1DDataset(ProceduralDataset):
|
||||||
- metadata: dict with generation parameters
|
- metadata: dict with generation parameters
|
||||||
"""
|
"""
|
||||||
# Create deterministic RNG from base seed and idx
|
# Create deterministic RNG from base seed and idx
|
||||||
item_rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
# Select random task
|
# Select random task
|
||||||
task_name = item_rng.choice(self.task_names)
|
task_name = rng.choice(self.task_names)
|
||||||
task_func, task_kwargs = self.ARC_1D_TASKS[task_name]
|
task_func, task_kwargs = self.ARC_1D_TASKS[task_name]
|
||||||
|
|
||||||
# Generate training examples
|
# Generate training examples
|
||||||
train_examples = []
|
train_examples = []
|
||||||
size = item_rng.randint(self.config.min_size, self.config.max_size)
|
size = rng.randint(self.config.min_size, self.config.max_size)
|
||||||
|
|
||||||
for _ in range(self.config.num_train):
|
for _ in range(self.config.num_train):
|
||||||
example = None
|
example = None
|
||||||
while example is None:
|
while example is None:
|
||||||
example = task_func(item_rng, size, **task_kwargs)
|
example = task_func(rng, size, **task_kwargs)
|
||||||
|
|
||||||
train_examples.append(example)
|
train_examples.append(example)
|
||||||
|
|
||||||
# Generate test example
|
# Generate test example
|
||||||
test_example = None
|
test_example = None
|
||||||
while test_example is None:
|
while test_example is None:
|
||||||
test_example = task_func(item_rng, size, **task_kwargs)
|
test_example = task_func(rng, size, **task_kwargs)
|
||||||
|
|
||||||
# Format question
|
# Format question
|
||||||
question = "Find the common rule that maps an input grid to an output grid, given the examples below.\n\n"
|
question = "Find the common rule that maps an input grid to an output grid, given the examples below.\n\n"
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ Arithmetic tasks for training reasoning capabilities:
|
||||||
|
|
||||||
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
|
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
|
||||||
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
|
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
|
||||||
from .chain_sum import ChainSum, ChainSumConfig
|
from .chain_sum import ChainSumConfig, ChainSumDataset
|
||||||
from .count_bits import CountBitsConfig, CountBitsDataset
|
from .count_bits import CountBitsConfig, CountBitsDataset
|
||||||
from .dice import DiceConfig, DiceDataset
|
from .dice import DiceConfig, DiceDataset
|
||||||
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
|
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
|
||||||
|
|
@ -14,12 +14,13 @@ from .lcm import LCMConfig, LCMDataset
|
||||||
from .leg_counting import LegCountingConfig, LegCountingDataset
|
from .leg_counting import LegCountingConfig, LegCountingDataset
|
||||||
from .power_function import PowerFunctionConfig, PowerFunctionDataset
|
from .power_function import PowerFunctionConfig, PowerFunctionDataset
|
||||||
from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
|
from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
|
||||||
|
from .products import ProductsConfig, ProductsDataset
|
||||||
from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset
|
from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BasicArithmeticDataset",
|
"BasicArithmeticDataset",
|
||||||
"BasicArithmeticDatasetConfig",
|
"BasicArithmeticDatasetConfig",
|
||||||
"ChainSum",
|
"ChainSumDataset",
|
||||||
"ChainSumConfig",
|
"ChainSumConfig",
|
||||||
"CalendarArithmeticConfig",
|
"CalendarArithmeticConfig",
|
||||||
"CalendarArithmeticDataset",
|
"CalendarArithmeticDataset",
|
||||||
|
|
@ -31,8 +32,12 @@ __all__ = [
|
||||||
"LCMDataset",
|
"LCMDataset",
|
||||||
"LegCountingConfig",
|
"LegCountingConfig",
|
||||||
"LegCountingDataset",
|
"LegCountingDataset",
|
||||||
|
"PowerFunctionConfig",
|
||||||
|
"PowerFunctionDataset",
|
||||||
"PrimeFactorizationConfig",
|
"PrimeFactorizationConfig",
|
||||||
"PrimeFactorizationDataset",
|
"PrimeFactorizationDataset",
|
||||||
|
"ProductsDataset",
|
||||||
|
"ProductsConfig",
|
||||||
"GSMSymbolicDatasetConfig",
|
"GSMSymbolicDatasetConfig",
|
||||||
"GSMSymbolicDataset",
|
"GSMSymbolicDataset",
|
||||||
"TimeIntervalsConfig",
|
"TimeIntervalsConfig",
|
||||||
|
|
|
||||||
|
|
@ -78,17 +78,17 @@ class BasicArithmeticDataset(ProceduralDataset):
|
||||||
- metadata: dict with generation parameters
|
- metadata: dict with generation parameters
|
||||||
"""
|
"""
|
||||||
# Create deterministic RNG from base seed and idx
|
# Create deterministic RNG from base seed and idx
|
||||||
item_rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms)
|
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
|
||||||
num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits)
|
num_digits = rng.randint(self.config.min_digits, self.config.max_digits)
|
||||||
|
|
||||||
if self.config.allow_parentheses:
|
if self.config.allow_parentheses:
|
||||||
expression, result = self._generate_complex_task(item_rng, num_terms, num_digits)
|
expression, result = self._generate_complex_task(rng, num_terms, num_digits)
|
||||||
else:
|
else:
|
||||||
expression, result = self._generate_simple_task(item_rng, num_terms, num_digits)
|
expression, result = self._generate_simple_task(rng, num_terms, num_digits)
|
||||||
|
|
||||||
question = self._format_question(item_rng, expression)
|
question = self._format_question(rng, expression)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": question,
|
"question": question,
|
||||||
|
|
|
||||||
|
|
@ -122,9 +122,9 @@ class CalendarArithmeticDataset(ProceduralDataset):
|
||||||
self.tasks = [self.task_handlers[task] for task in self.config.tasks]
|
self.tasks = [self.task_handlers[task] for task in self.config.tasks]
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
item_rng = random.Random(self.seed + idx)
|
rng = random.Random(self.seed + idx)
|
||||||
task = item_rng.choice(self.tasks)
|
task = rng.choice(self.tasks)
|
||||||
question, answer, metadata = task(item_rng)
|
question, answer, metadata = task(rng)
|
||||||
return {
|
return {
|
||||||
"question": question,
|
"question": question,
|
||||||
"answer": str(answer),
|
"answer": str(answer),
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ class ChainSumConfig:
|
||||||
assert 10 ** (self.min_digits - 1) >= 1, "min_digits would result in invalid number range"
|
assert 10 ** (self.min_digits - 1) >= 1, "min_digits would result in invalid number range"
|
||||||
|
|
||||||
|
|
||||||
class ChainSum(ProceduralDataset):
|
class ChainSumDataset(ProceduralDataset):
|
||||||
"""Generates simple arithmetic tasks using only + and - operators"""
|
"""Generates simple arithmetic tasks using only + and - operators"""
|
||||||
|
|
||||||
def __init__(self, config: ChainSumConfig):
|
def __init__(self, config: ChainSumConfig):
|
||||||
|
|
@ -51,16 +51,16 @@ class ChainSum(ProceduralDataset):
|
||||||
- metadata: dict with generation parameters
|
- metadata: dict with generation parameters
|
||||||
"""
|
"""
|
||||||
# Create deterministic RNG from base seed and idx
|
# Create deterministic RNG from base seed and idx
|
||||||
item_rng = random.Random(self.seed + idx)
|
rng = random.Random(self.seed + idx)
|
||||||
|
|
||||||
num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms)
|
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
|
||||||
num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits)
|
num_digits = rng.randint(self.config.min_digits, self.config.max_digits)
|
||||||
|
|
||||||
# Calculate value ranges based on number of digits
|
# Calculate value ranges based on number of digits
|
||||||
min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1) # Special case for 1 digit
|
min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1) # Special case for 1 digit
|
||||||
max_value = (10**num_digits) - 1 # e.g., 999 for 3 digits
|
max_value = (10**num_digits) - 1 # e.g., 999 for 3 digits
|
||||||
|
|
||||||
expression, result = self._generate_task(item_rng, num_terms, min_value, max_value)
|
expression, result = self._generate_task(rng, num_terms, min_value, max_value)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": f"{expression} =",
|
"question": f"{expression} =",
|
||||||
|
|
@ -143,4 +143,4 @@ class ChainSumCurriculum(BaseCurriculum):
|
||||||
|
|
||||||
|
|
||||||
# Register the dataset
|
# Register the dataset
|
||||||
register_dataset("chain_sum", ChainSum, ChainSumConfig)
|
register_dataset("chain_sum", ChainSumDataset, ChainSumConfig)
|
||||||
|
|
|
||||||
130
reasoning_gym/arithmetic/products.py
Normal file
130
reasoning_gym/arithmetic/products.py
Normal file
|
|
@ -0,0 +1,130 @@
|
||||||
|
import random
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
|
||||||
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProductsConfig:
|
||||||
|
"""Configuration for products task generation"""
|
||||||
|
|
||||||
|
min_terms: int = 2
|
||||||
|
max_terms: int = 2
|
||||||
|
min_digits: int = 1
|
||||||
|
max_digits: int = 5
|
||||||
|
seed: Optional[int] = None
|
||||||
|
size: int = 500
|
||||||
|
|
||||||
|
def validate(self) -> None:
|
||||||
|
"""Validate configuration parameters"""
|
||||||
|
assert self.size > 0, "size must be positive"
|
||||||
|
assert self.min_terms > 0, "min_terms must be positive"
|
||||||
|
assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms"
|
||||||
|
assert self.min_digits > 0, "min_digits must be positive"
|
||||||
|
assert self.max_digits >= self.min_digits, "max_digits must be >= min_digits"
|
||||||
|
|
||||||
|
|
||||||
|
class ProductsDataset(ProceduralDataset):
|
||||||
|
"""Generates multiplication tasks with configurable number of terms"""
|
||||||
|
|
||||||
|
def __init__(self, config: ProductsConfig):
|
||||||
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> dict:
|
||||||
|
"""Generate a single multiplication task
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idx: Index of the item to generate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with keys:
|
||||||
|
- question: str, the formatted multiplication expression
|
||||||
|
- answer: str, the ground truth result
|
||||||
|
- metadata: dict with generation parameters
|
||||||
|
"""
|
||||||
|
# Create deterministic RNG from base seed and idx
|
||||||
|
rng = random.Random(self.seed + idx)
|
||||||
|
|
||||||
|
num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
|
||||||
|
num_digits = rng.randint(self.config.min_digits, self.config.max_digits)
|
||||||
|
|
||||||
|
# Calculate value ranges based on number of digits
|
||||||
|
min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1) # Special case for 1 digit
|
||||||
|
max_value = (10**num_digits) - 1 # e.g., 999 for 3 digits
|
||||||
|
|
||||||
|
expression, result = self._generate_task(rng, num_terms, min_value, max_value)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"question": f"{expression} =",
|
||||||
|
"answer": str(result),
|
||||||
|
"metadata": {
|
||||||
|
"difficulty": {
|
||||||
|
"num_terms": num_terms,
|
||||||
|
"num_digits": num_digits,
|
||||||
|
},
|
||||||
|
"expression": expression,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, int]:
|
||||||
|
"""Generate a multiplication task
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rng: Random number generator
|
||||||
|
num_terms: Number of terms in the expression
|
||||||
|
min_value: Minimum value for generated numbers
|
||||||
|
max_value: Maximum value for generated numbers
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (expression string, result integer)
|
||||||
|
"""
|
||||||
|
# Generate random numbers within the specified range
|
||||||
|
constants = [rng.randint(min_value, max_value) for _ in range(num_terms)]
|
||||||
|
|
||||||
|
# Build expression and compute result
|
||||||
|
expression_parts = []
|
||||||
|
result = constants[0]
|
||||||
|
|
||||||
|
expression_parts.append(str(constants[0]))
|
||||||
|
for i in range(1, len(constants)):
|
||||||
|
expression_parts.append("*")
|
||||||
|
expression_parts.append(str(constants[i]))
|
||||||
|
result *= constants[i]
|
||||||
|
|
||||||
|
expression = " ".join(expression_parts)
|
||||||
|
return expression, result
|
||||||
|
|
||||||
|
|
||||||
|
class ProductsCurriculum(BaseCurriculum):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(ProductsCurriculum.__name__, ProductsConfig)
|
||||||
|
|
||||||
|
# Define attributes
|
||||||
|
self._define_attributes(
|
||||||
|
RangeAttributeDefinition(
|
||||||
|
name="num_terms",
|
||||||
|
levels=[2, 3, 4, 5],
|
||||||
|
default_level=0, # Start with 2 terms
|
||||||
|
description="Maximum number of terms in the expression",
|
||||||
|
attr_type=AttributeType.APPEND,
|
||||||
|
min_value=2, # Ensure at least 2 terms
|
||||||
|
lower_field_name="min_terms",
|
||||||
|
upper_field_name="max_terms",
|
||||||
|
),
|
||||||
|
RangeAttributeDefinition(
|
||||||
|
name="num_digits",
|
||||||
|
levels=[1, 2, 3, 4],
|
||||||
|
default_level=0, # Start with 1-digit numbers
|
||||||
|
description="Number of digits in each operand",
|
||||||
|
attr_type=AttributeType.APPEND,
|
||||||
|
min_value=1, # Ensure numbers are at least 1 digit
|
||||||
|
lower_field_name="min_digits",
|
||||||
|
upper_field_name="max_digits",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Register the dataset
|
||||||
|
register_dataset("products", ProductsDataset, ProductsConfig)
|
||||||
|
|
@ -82,14 +82,14 @@ class TimeIntervalsDataset(ProceduralDataset):
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
"""Generate a single time interval calculation task"""
|
"""Generate a single time interval calculation task"""
|
||||||
item_rng = random.Random(self.seed + idx)
|
rng = random.Random(self.seed + idx)
|
||||||
|
|
||||||
# Randomly choose task type from config
|
# Randomly choose task type from config
|
||||||
task_type = item_rng.choice(self.config.task_types)
|
task_type = rng.choice(self.config.task_types)
|
||||||
|
|
||||||
start_time, end_time, format_str, expected_format = self._generate_times(item_rng, task_type)
|
start_time, end_time, format_str, expected_format = self._generate_times(rng, task_type)
|
||||||
|
|
||||||
template = item_rng.choice(self.TEMPLATES)
|
template = rng.choice(self.TEMPLATES)
|
||||||
question = template.format(start=start_time, end=end_time, format=expected_format)
|
question = template.format(start=start_time, end=end_time, format=expected_format)
|
||||||
|
|
||||||
# Calculate the actual difference
|
# Calculate the actual difference
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
|
import json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import cellpylib as cpl
|
import cellpylib as cpl
|
||||||
|
|
||||||
|
|
@ -11,8 +12,8 @@ from ..factory import ProceduralDataset, register_dataset
|
||||||
class GameOfLifeConfig:
|
class GameOfLifeConfig:
|
||||||
"""Configuration for sudoku puzzle generation"""
|
"""Configuration for sudoku puzzle generation"""
|
||||||
|
|
||||||
grid_size_x: int = 20
|
grid_size_x: int = 10
|
||||||
grid_size_y: int = 20
|
grid_size_y: int = 10
|
||||||
filled_cells: int = 100 # actually a max
|
filled_cells: int = 100 # actually a max
|
||||||
simulation_steps: int = 1
|
simulation_steps: int = 1
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
|
|
@ -31,7 +32,7 @@ class GameOfLifeDataset(ProceduralDataset):
|
||||||
|
|
||||||
def __init__(self, config: GameOfLifeConfig):
|
def __init__(self, config: GameOfLifeConfig):
|
||||||
self._prompt_templates = [
|
self._prompt_templates = [
|
||||||
"What will this Game of Life board look like after {simulation_steps} steps of simulation?\n\n{board}"
|
"What will this Game of Life board look like after {simulation_steps} steps of simulation? Reply as array of array representing rows in the grid from top to bottom in JSON format. (An empty 3x3 grid would look like this: [[0,0,0],[0,0,0],[0,0,0]])\n\n{board}."
|
||||||
]
|
]
|
||||||
|
|
||||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
|
@ -59,11 +60,18 @@ class GameOfLifeDataset(ProceduralDataset):
|
||||||
|
|
||||||
# Simulate the result to get the answer
|
# Simulate the result to get the answer
|
||||||
evolved = cpl.evolve2d(
|
evolved = cpl.evolve2d(
|
||||||
board, timesteps=self.config.simulation_steps + 1, apply_rule=cpl.game_of_life_rule, memoize="recursive"
|
board,
|
||||||
|
timesteps=self.config.simulation_steps + 1,
|
||||||
|
apply_rule=cpl.game_of_life_rule,
|
||||||
|
memoize="recursive",
|
||||||
)
|
)
|
||||||
|
|
||||||
board_str = str(board[0])
|
rows = [json.dumps(board[0, i].tolist(), separators=(",", ":")) for i in range(board.shape[1])]
|
||||||
result_str = str(evolved[-1])
|
board_str = "[" + ", \n ".join(rows) + "]"
|
||||||
|
|
||||||
|
final_step = evolved[-1]
|
||||||
|
final_step_list = final_step.tolist()
|
||||||
|
result_str = json.dumps(final_step_list, separators=(",", ":"))
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"question": rng.choice(self._prompt_templates).format(
|
"question": rng.choice(self._prompt_templates).format(
|
||||||
|
|
@ -93,10 +101,17 @@ class GameOfLifeDataset(ProceduralDataset):
|
||||||
|
|
||||||
if answer == None:
|
if answer == None:
|
||||||
return 0.0
|
return 0.0
|
||||||
if answer.replace("\n", "") != entry["answer"].replace("\n", ""):
|
|
||||||
|
try:
|
||||||
|
ans_arr = json.loads(answer)
|
||||||
|
correct_arr = json.loads(entry["answer"])
|
||||||
|
|
||||||
|
if correct_arr != ans_arr:
|
||||||
return 0.01
|
return 0.01
|
||||||
else:
|
else:
|
||||||
return 1.0 # Yay
|
return 1.0 # Yay
|
||||||
|
except Exception as e:
|
||||||
|
return 0.01
|
||||||
|
|
||||||
|
|
||||||
register_dataset("game_of_life", GameOfLifeDataset, GameOfLifeConfig)
|
register_dataset("game_of_life", GameOfLifeDataset, GameOfLifeConfig)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.arithmetic import ChainSum, ChainSumConfig
|
from reasoning_gym.arithmetic import ChainSumConfig, ChainSumDataset
|
||||||
from reasoning_gym.arithmetic.chain_sum import ChainSumCurriculum
|
from reasoning_gym.arithmetic.chain_sum import ChainSumCurriculum
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -18,8 +18,8 @@ def test_chain_sum_config_validation():
|
||||||
def test_chain_sum_deterministic():
|
def test_chain_sum_deterministic():
|
||||||
"""Test that dataset generates same items with same seed"""
|
"""Test that dataset generates same items with same seed"""
|
||||||
config = ChainSumConfig(seed=42, size=10)
|
config = ChainSumConfig(seed=42, size=10)
|
||||||
dataset1 = ChainSum(config)
|
dataset1 = ChainSumDataset(config)
|
||||||
dataset2 = ChainSum(config)
|
dataset2 = ChainSumDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset1)):
|
for i in range(len(dataset1)):
|
||||||
assert dataset1[i] == dataset2[i]
|
assert dataset1[i] == dataset2[i]
|
||||||
|
|
@ -28,7 +28,7 @@ def test_chain_sum_deterministic():
|
||||||
def test_chain_sum_items():
|
def test_chain_sum_items():
|
||||||
"""Test basic properties of generated items"""
|
"""Test basic properties of generated items"""
|
||||||
config = ChainSumConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42)
|
config = ChainSumConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42)
|
||||||
dataset = ChainSum(config)
|
dataset = ChainSumDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
|
|
@ -57,7 +57,7 @@ def test_chain_sum_number_ranges():
|
||||||
size=50,
|
size=50,
|
||||||
seed=42,
|
seed=42,
|
||||||
)
|
)
|
||||||
dataset = ChainSum(config)
|
dataset = ChainSumDataset(config)
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
|
|
@ -71,7 +71,7 @@ def test_chain_sum_number_ranges():
|
||||||
|
|
||||||
# Test 1-digit numbers
|
# Test 1-digit numbers
|
||||||
config = ChainSumConfig(min_terms=2, max_terms=2, min_digits=1, max_digits=1, size=50, seed=42)
|
config = ChainSumConfig(min_terms=2, max_terms=2, min_digits=1, max_digits=1, size=50, seed=42)
|
||||||
dataset = ChainSum(config)
|
dataset = ChainSumDataset(config)
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
item = dataset[i]
|
item = dataset[i]
|
||||||
expression = item["metadata"]["expression"]
|
expression = item["metadata"]["expression"]
|
||||||
|
|
@ -88,7 +88,7 @@ def test_chain_sum_negation():
|
||||||
config = ChainSumConfig(
|
config = ChainSumConfig(
|
||||||
min_terms=2, max_terms=2, min_digits=2, max_digits=2, size=100, seed=42, allow_negation=True
|
min_terms=2, max_terms=2, min_digits=2, max_digits=2, size=100, seed=42, allow_negation=True
|
||||||
)
|
)
|
||||||
dataset = ChainSum(config)
|
dataset = ChainSumDataset(config)
|
||||||
|
|
||||||
# Track if we see both positive and negative numbers
|
# Track if we see both positive and negative numbers
|
||||||
has_positive = False
|
has_positive = False
|
||||||
|
|
@ -112,7 +112,7 @@ def test_chain_sum_negation():
|
||||||
def test_chain_sum_iteration():
|
def test_chain_sum_iteration():
|
||||||
"""Test that iteration respects dataset size"""
|
"""Test that iteration respects dataset size"""
|
||||||
config = ChainSumConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing
|
config = ChainSumConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing
|
||||||
dataset = ChainSum(config)
|
dataset = ChainSumDataset(config)
|
||||||
|
|
||||||
# Test manual iteration
|
# Test manual iteration
|
||||||
items = []
|
items = []
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.arithmetic.chain_sum import ChainSum, ChainSumConfig
|
from reasoning_gym.arithmetic.chain_sum import ChainSumConfig, ChainSumDataset
|
||||||
from reasoning_gym.arithmetic.leg_counting import LegCountingConfig
|
from reasoning_gym.arithmetic.leg_counting import LegCountingConfig
|
||||||
from reasoning_gym.coaching import Coach, GroupedScores
|
from reasoning_gym.coaching import Coach, GroupedScores
|
||||||
from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec
|
from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec
|
||||||
|
|
@ -14,7 +14,7 @@ from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSp
|
||||||
def test_coach_with_chain_sum():
|
def test_coach_with_chain_sum():
|
||||||
# Create a small ChainSum dataset
|
# Create a small ChainSum dataset
|
||||||
config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42)
|
config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42)
|
||||||
dataset = ChainSum(config)
|
dataset = ChainSumDataset(config)
|
||||||
coach = Coach(dataset)
|
coach = Coach(dataset)
|
||||||
|
|
||||||
# Simulate an agent working on tasks
|
# Simulate an agent working on tasks
|
||||||
|
|
@ -208,7 +208,7 @@ def test_coach_score_logging(tmp_path):
|
||||||
|
|
||||||
# Create dataset and coach with logging
|
# Create dataset and coach with logging
|
||||||
config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42)
|
config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42)
|
||||||
dataset = ChainSum(config)
|
dataset = ChainSumDataset(config)
|
||||||
coach = Coach(dataset, score_log=log_file)
|
coach = Coach(dataset, score_log=log_file)
|
||||||
|
|
||||||
# Score a few answers
|
# Score a few answers
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ def test_game_of_life():
|
||||||
"""Test basic properties and solution of generated items"""
|
"""Test basic properties and solution of generated items"""
|
||||||
|
|
||||||
# Easy
|
# Easy
|
||||||
config = GameOfLifeConfig(seed=42, size=1, grid_size_x=20, grid_size_y=20, filled_cells=10, simulation_steps=1)
|
config = GameOfLifeConfig(seed=42, size=10, grid_size_x=20, grid_size_y=20, filled_cells=200, simulation_steps=1)
|
||||||
dataset = GameOfLifeDataset(config)
|
dataset = GameOfLifeDataset(config)
|
||||||
|
|
||||||
for item in dataset:
|
for item in dataset:
|
||||||
|
|
|
||||||
|
|
@ -112,7 +112,7 @@ def test_polynomial_solutions_evaluation():
|
||||||
evaluated_value = poly_expr.subs(x, solution)
|
evaluated_value = poly_expr.subs(x, solution)
|
||||||
|
|
||||||
# Ensure the evaluated value is close to zero (numerical stability threshold)
|
# Ensure the evaluated value is close to zero (numerical stability threshold)
|
||||||
assert abs(evaluated_value) < 1e-6, (
|
assert abs(evaluated_value) < 1e-5, (
|
||||||
f"Solution {solution} does not satisfy the polynomial {poly_str}. "
|
f"Solution {solution} does not satisfy the polynomial {poly_str}. "
|
||||||
f"Evaluated value: {evaluated_value}"
|
f"Evaluated value: {evaluated_value}"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
144
tests/test_products.py
Normal file
144
tests/test_products.py
Normal file
|
|
@ -0,0 +1,144 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reasoning_gym.arithmetic import ProductsConfig, ProductsDataset
|
||||||
|
from reasoning_gym.arithmetic.products import ProductsCurriculum
|
||||||
|
|
||||||
|
|
||||||
|
def test_products_config_validation():
|
||||||
|
"""Test that invalid configs raise appropriate errors"""
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = ProductsConfig(min_terms=0)
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = ProductsConfig(min_terms=3, max_terms=2)
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
|
||||||
|
def test_products_deterministic():
|
||||||
|
"""Test that dataset generates same items with same seed"""
|
||||||
|
config = ProductsConfig(seed=42, size=10)
|
||||||
|
dataset1 = ProductsDataset(config)
|
||||||
|
dataset2 = ProductsDataset(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset1)):
|
||||||
|
assert dataset1[i] == dataset2[i]
|
||||||
|
|
||||||
|
|
||||||
|
def test_products_items():
|
||||||
|
"""Test basic properties of generated items"""
|
||||||
|
config = ProductsConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42)
|
||||||
|
dataset = ProductsDataset(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item = dataset[i]
|
||||||
|
assert isinstance(item, dict)
|
||||||
|
assert "question" in item
|
||||||
|
assert "answer" in item
|
||||||
|
assert "metadata" in item
|
||||||
|
|
||||||
|
# Verify only * is used
|
||||||
|
expression = item["metadata"]["expression"]
|
||||||
|
assert all(op in ["*", " "] or op.isdigit() for op in expression)
|
||||||
|
|
||||||
|
# Verify the answer matches the expression
|
||||||
|
answer = eval(expression) # Safe here as we control the expression
|
||||||
|
assert str(answer) == item["answer"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_products_number_ranges():
|
||||||
|
"""Test that generated numbers respect digit constraints"""
|
||||||
|
# Test 3-digit numbers
|
||||||
|
config = ProductsConfig(
|
||||||
|
min_terms=2,
|
||||||
|
max_terms=2, # Fix to 2 terms for easier testing
|
||||||
|
min_digits=3, # Should generate numbers >= 100
|
||||||
|
max_digits=3, # Should generate numbers <= 999
|
||||||
|
size=50,
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
dataset = ProductsDataset(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item = dataset[i]
|
||||||
|
expression = item["metadata"]["expression"]
|
||||||
|
numbers = [int(n) for n in expression.split() if n.isdigit()]
|
||||||
|
for num in numbers:
|
||||||
|
assert 100 <= num <= 999, f"Number {num} outside valid range for 3 digits"
|
||||||
|
|
||||||
|
# Test 1-digit numbers
|
||||||
|
config = ProductsConfig(min_terms=2, max_terms=2, min_digits=1, max_digits=1, size=50, seed=42)
|
||||||
|
dataset = ProductsDataset(config)
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item = dataset[i]
|
||||||
|
expression = item["metadata"]["expression"]
|
||||||
|
numbers = [int(n) for n in expression.split() if n.isdigit()]
|
||||||
|
for num in numbers:
|
||||||
|
assert 0 <= num <= 9, f"Number {num} outside valid range for 1 digit"
|
||||||
|
|
||||||
|
|
||||||
|
def test_products_iteration():
|
||||||
|
"""Test that iteration respects dataset size"""
|
||||||
|
config = ProductsConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing
|
||||||
|
dataset = ProductsDataset(config)
|
||||||
|
|
||||||
|
# Test manual iteration
|
||||||
|
items = []
|
||||||
|
for item in dataset:
|
||||||
|
items.append(item)
|
||||||
|
assert len(items) == config.size, "Iterator should yield exactly size items"
|
||||||
|
|
||||||
|
# Test list conversion
|
||||||
|
items = list(dataset)
|
||||||
|
assert len(items) == config.size, "Iterator should yield exactly size items"
|
||||||
|
|
||||||
|
# Test multiple iterations
|
||||||
|
first_items = list(dataset)
|
||||||
|
second_items = list(dataset)
|
||||||
|
assert first_items == second_items, "Multiple iterations should yield same items"
|
||||||
|
|
||||||
|
|
||||||
|
def test_products_scoring():
|
||||||
|
"""Test that scoring works correctly"""
|
||||||
|
config = ProductsConfig(min_terms=2, max_terms=2, size=10, seed=42)
|
||||||
|
dataset = ProductsDataset(config)
|
||||||
|
|
||||||
|
# Test scoring with exact match
|
||||||
|
item = dataset[0]
|
||||||
|
assert dataset.score_answer(item["answer"], item) == 1.0, "Exact match should score 1.0"
|
||||||
|
|
||||||
|
# Test scoring with wrong answer
|
||||||
|
assert dataset.score_answer("wrong", item) == 0.01, "Wrong answer should score 0.01"
|
||||||
|
|
||||||
|
# Test scoring with partial match (answer contained in response)
|
||||||
|
assert dataset.score_answer(f"The answer is {item['answer']}", item) == 0.5, "Partial match should score 0.5"
|
||||||
|
|
||||||
|
# Test scoring with None
|
||||||
|
assert dataset.score_answer(None, item) == 0.0, "None should score 0.0"
|
||||||
|
|
||||||
|
|
||||||
|
def test_products_curriculum():
|
||||||
|
curriculum = ProductsCurriculum()
|
||||||
|
|
||||||
|
base_value = {"size": 150, "seed": 1}
|
||||||
|
|
||||||
|
base_cfg: ProductsConfig = curriculum.generate_configuration(base_value)
|
||||||
|
assert base_cfg.seed == 1
|
||||||
|
assert base_cfg.size == 150
|
||||||
|
assert base_cfg.min_digits == 1 and base_cfg.max_digits == 1
|
||||||
|
assert base_cfg.min_terms == 2 and base_cfg.max_terms == 2
|
||||||
|
|
||||||
|
# test incrementing attribute levels for num_terms & num_digits attributes
|
||||||
|
curriculum.increment_attr_level("num_terms")
|
||||||
|
curriculum.increment_attr_level("num_digits")
|
||||||
|
|
||||||
|
increased_cfg = curriculum.generate_configuration(base_value)
|
||||||
|
assert increased_cfg.min_digits == 1 and increased_cfg.max_digits == 2
|
||||||
|
assert increased_cfg.min_terms == 2 and increased_cfg.max_terms == 3
|
||||||
|
|
||||||
|
# test decrementing attribute level for num_digits again
|
||||||
|
curriculum.decrement_attr_level("num_digits")
|
||||||
|
|
||||||
|
partially_decreased_cfg = curriculum.generate_configuration(base_value)
|
||||||
|
assert partially_decreased_cfg.min_digits == 1 and partially_decreased_cfg.max_digits == 1
|
||||||
|
assert partially_decreased_cfg.min_terms == 2 and partially_decreased_cfg.max_terms == 3
|
||||||
Loading…
Add table
Add a link
Reference in a new issue