Merge branch 'main' into rich/graphcolor

This commit is contained in:
Andreas Köpf 2025-02-14 07:09:38 +01:00 committed by GitHub
commit b64d0af2bc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 385 additions and 61 deletions

View file

@ -76,6 +76,11 @@ class IntermediateIntegrationDataset(ProceduralDataset):
"Calculate the antiderivative: ∫ {integrand} dx", "Calculate the antiderivative: ∫ {integrand} dx",
"Evaluate the indefinite integral: ∫ {integrand} dx", "Evaluate the indefinite integral: ∫ {integrand} dx",
] ]
self.added_instruction = """
In addition, when doing calculation, use the following instructions together with your mathematical ingenuity to solve the integral problems
## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2.
## 2. Always use * when doing all sorts of multiplcation in your reasoning steps. For example Use [-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C] instead of [-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C].
"""
def _get_outer_constant(self, rng: random.Random) -> int: def _get_outer_constant(self, rng: random.Random) -> int:
"""Helper to generate signed outer constant from config""" """Helper to generate signed outer constant from config"""
@ -222,9 +227,10 @@ class IntermediateIntegrationDataset(ProceduralDataset):
answer = sympy.integrate(integrand, x) answer = sympy.integrate(integrand, x)
answer_str = str(answer) + " + C" answer_str = str(answer) + " + C"
question = rng.choice(self.prompt_template).format(integrand=integrand) + self.added_instruction
return { return {
"question": rng.choice(self.prompt_template).format(integrand=integrand), "question": question,
"answer": answer_str, "answer": answer_str,
"metadata": { "metadata": {
"integrand": str(integrand), "integrand": str(integrand),

View file

@ -62,6 +62,14 @@ class PolynomialEquationsDataset(ProceduralDataset):
"Determine the real value(s) of {variable} that satisfies: {polynomial_expanded} = 0", "Determine the real value(s) of {variable} that satisfies: {polynomial_expanded} = 0",
"Solve the polynomial equation for real {variable}:\n{polynomial_expanded} = 0", "Solve the polynomial equation for real {variable}:\n{polynomial_expanded} = 0",
] ]
self.added_instruction = """
In solving the equations, please abide by the following instruction:
## 1. All answers should be comma-separated. For example "-0.3773, 0.4005" etc.
## 2. In cases where your answer is b = 2 + sqrt(4560) / 172 and b = 2 - sqrt(4560) / 172. Since b can be 2 numbers, resolve your answer like this instead, "-0.3773, 0.4005".
## 3. If there are no real values of i that satisfy the equation, report your answer as empty string, "".
## 4. If there are 2 answers, resolve the answers as comma-separated floats of 2 numbers, if 3 answers, make it comma-separated floats of 3 numbers.
## 5. Resolve all numbers as floats in the string of comma-separated numbers. Round the floats higher than 4 decimal place(d.p) down to 4 d.p.
"""
super().__init__(config=config, seed=config.seed, size=config.size) super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
@ -89,19 +97,20 @@ class PolynomialEquationsDataset(ProceduralDataset):
for sol in solutions: for sol in solutions:
if sol.is_real: if sol.is_real:
# Evaluate symbolic solution to a floating approximation # Evaluate symbolic solution to a floating approximation
real_solutions.append(float(sol.evalf())) real_solutions.append(round(float(sol.evalf()), 4))
if len(real_solutions) > 0: if len(real_solutions) > 0:
real_solutions.sort() real_solutions.sort()
break break
answer_str = ", ".join(str(x) for x in real_solutions) answer_str = ", ".join(str(x) for x in real_solutions)
question = (
rng.choice(self._prompt_templates).format(variable=variable, polynomial_expanded=polynomial_expanded)
+ self.added_instruction
)
return { return {
"question": rng.choice(self._prompt_templates).format( "question": question,
variable=variable,
polynomial_expanded=polynomial_expanded,
),
"answer": answer_str, "answer": answer_str,
"metadata": { "metadata": {
"polynomial_expr": str(polynomial_expanded), "polynomial_expr": str(polynomial_expanded),

View file

@ -61,6 +61,11 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
"Simplify this expression: {polynomial_expr}", "Simplify this expression: {polynomial_expr}",
"Calculate the following: {polynomial_expr}", "Calculate the following: {polynomial_expr}",
] ]
self.added_instruction = """
In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems
## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2.
## 2. Always use * when doing all sorts of multiplcation in your reasoning steps and even in reporting answers.
"""
super().__init__(config=config, seed=config.seed, size=config.size) super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
@ -79,11 +84,10 @@ class PolynomialMultiplicationDataset(ProceduralDataset):
polynomial_expr = sp.prod(polynomials) polynomial_expr = sp.prod(polynomials)
product = sp.expand(polynomial_expr) product = sp.expand(polynomial_expr)
question = rng.choice(self._prompt_templates).format(polynomial_expr=polynomial_expr) + self.added_instruction
return { return {
"question": rng.choice(self._prompt_templates).format( "question": question,
polynomial_expr=polynomial_expr,
),
"answer": product, "answer": product,
"metadata": { "metadata": {
"polynomial_expr": str(polynomial_expr), "polynomial_expr": str(polynomial_expr),

View file

@ -41,6 +41,11 @@ class SimpleIntegrationDataset(ProceduralDataset):
"Calculate the antiderivative: ∫ {integrand} dx", "Calculate the antiderivative: ∫ {integrand} dx",
"Evaluate the indefinite integral: ∫ {integrand} dx", "Evaluate the indefinite integral: ∫ {integrand} dx",
] ]
self.added_instruction = """
In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems
## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2.
## 2. Always use * when doing all sorts of multiplcation in your reasoning steps. For example Use [-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C] instead of [-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C].
"""
super().__init__(config=config, seed=config.seed, size=config.size) super().__init__(config=config, seed=config.seed, size=config.size)
def _generate_coefficient(self, rng: random.Random) -> Fraction: def _generate_coefficient(self, rng: random.Random) -> Fraction:
@ -69,9 +74,10 @@ class SimpleIntegrationDataset(ProceduralDataset):
rng = random.Random(self.seed + idx) rng = random.Random(self.seed + idx)
symbol, polynomial = self._generate_polynomial(rng) symbol, polynomial = self._generate_polynomial(rng)
derivative = sympy.diff(polynomial, symbol) derivative = sympy.diff(polynomial, symbol)
question = rng.choice(self._prompt_templates).format(integrand=derivative) + self.added_instruction
return { return {
"question": rng.choice(self._prompt_templates).format(integrand=derivative), "question": question,
"answer": str(polynomial) + " + C", "answer": str(polynomial) + " + C",
"metadata": { "metadata": {
"integrand": str(derivative), "integrand": str(derivative),

View file

@ -26,7 +26,6 @@ from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset
from .string_insertion import StringInsertionConfig, StringInsertionDataset
from .string_manipulation import StringManipulationConfig, StringManipulationDataset from .string_manipulation import StringManipulationConfig, StringManipulationDataset
from .word_ladder import WordLadderConfig, WordLadderDataset from .word_ladder import WordLadderConfig, WordLadderDataset
from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset

View file

@ -34,6 +34,11 @@ class NumberSortingDataset(ProceduralDataset):
def __init__(self, config: NumberSortingConfig): def __init__(self, config: NumberSortingConfig):
super().__init__(config=config, seed=config.seed, size=config.size) super().__init__(config=config, seed=config.seed, size=config.size)
self.added_instruction = """
Please follow the instruction below:
## 1. Let all your answers be a list of numbers. Instead of reporting your answer as -69, -13, 1, 7, 11, 43, 59, 61, use ['-69', '-13', '1', '7', '11', '43', '59', '61'] instead
## 2. Convert all numbers in the square brackets as strings. For example, ['-69', '-13', '1', '7', '11', '43', '59', '61']
"""
def _format_number(self, num: float, decimals: int) -> str: def _format_number(self, num: float, decimals: int) -> str:
"""Format number with specified decimal places""" """Format number with specified decimal places"""
@ -78,9 +83,10 @@ class NumberSortingDataset(ProceduralDataset):
is_ascending = rng.choice([True, False]) is_ascending = rng.choice([True, False])
direction = "ascending" if is_ascending else "descending" direction = "ascending" if is_ascending else "descending"
answer = asc_answer if is_ascending else desc_answer answer = asc_answer if is_ascending else desc_answer
question = f"Sort these numbers in {direction} order: {', '.join(number_strs)}" + self.added_instruction
return { return {
"question": f"Sort these numbers in {direction} order: {', '.join(number_strs)}", "question": question,
"answer": str(answer), "answer": str(answer),
"metadata": {"original_numbers": number_strs, "direction": direction, "sorted_numbers": answer}, "metadata": {"original_numbers": number_strs, "direction": direction, "sorted_numbers": answer},
} }

View file

@ -58,27 +58,27 @@ class Arc1DDataset(ProceduralDataset):
- metadata: dict with generation parameters - metadata: dict with generation parameters
""" """
# Create deterministic RNG from base seed and idx # Create deterministic RNG from base seed and idx
item_rng = Random(self.seed + idx) rng = Random(self.seed + idx)
# Select random task # Select random task
task_name = item_rng.choice(self.task_names) task_name = rng.choice(self.task_names)
task_func, task_kwargs = self.ARC_1D_TASKS[task_name] task_func, task_kwargs = self.ARC_1D_TASKS[task_name]
# Generate training examples # Generate training examples
train_examples = [] train_examples = []
size = item_rng.randint(self.config.min_size, self.config.max_size) size = rng.randint(self.config.min_size, self.config.max_size)
for _ in range(self.config.num_train): for _ in range(self.config.num_train):
example = None example = None
while example is None: while example is None:
example = task_func(item_rng, size, **task_kwargs) example = task_func(rng, size, **task_kwargs)
train_examples.append(example) train_examples.append(example)
# Generate test example # Generate test example
test_example = None test_example = None
while test_example is None: while test_example is None:
test_example = task_func(item_rng, size, **task_kwargs) test_example = task_func(rng, size, **task_kwargs)
# Format question # Format question
question = "Find the common rule that maps an input grid to an output grid, given the examples below.\n\n" question = "Find the common rule that maps an input grid to an output grid, given the examples below.\n\n"

View file

@ -4,7 +4,7 @@ Arithmetic tasks for training reasoning capabilities:
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
from .chain_sum import ChainSum, ChainSumConfig from .chain_sum import ChainSumConfig, ChainSumDataset
from .count_bits import CountBitsConfig, CountBitsDataset from .count_bits import CountBitsConfig, CountBitsDataset
from .dice import DiceConfig, DiceDataset from .dice import DiceConfig, DiceDataset
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
@ -14,12 +14,13 @@ from .lcm import LCMConfig, LCMDataset
from .leg_counting import LegCountingConfig, LegCountingDataset from .leg_counting import LegCountingConfig, LegCountingDataset
from .power_function import PowerFunctionConfig, PowerFunctionDataset from .power_function import PowerFunctionConfig, PowerFunctionDataset
from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
from .products import ProductsConfig, ProductsDataset
from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset
__all__ = [ __all__ = [
"BasicArithmeticDataset", "BasicArithmeticDataset",
"BasicArithmeticDatasetConfig", "BasicArithmeticDatasetConfig",
"ChainSum", "ChainSumDataset",
"ChainSumConfig", "ChainSumConfig",
"CalendarArithmeticConfig", "CalendarArithmeticConfig",
"CalendarArithmeticDataset", "CalendarArithmeticDataset",
@ -31,8 +32,12 @@ __all__ = [
"LCMDataset", "LCMDataset",
"LegCountingConfig", "LegCountingConfig",
"LegCountingDataset", "LegCountingDataset",
"PowerFunctionConfig",
"PowerFunctionDataset",
"PrimeFactorizationConfig", "PrimeFactorizationConfig",
"PrimeFactorizationDataset", "PrimeFactorizationDataset",
"ProductsDataset",
"ProductsConfig",
"GSMSymbolicDatasetConfig", "GSMSymbolicDatasetConfig",
"GSMSymbolicDataset", "GSMSymbolicDataset",
"TimeIntervalsConfig", "TimeIntervalsConfig",

View file

@ -78,17 +78,17 @@ class BasicArithmeticDataset(ProceduralDataset):
- metadata: dict with generation parameters - metadata: dict with generation parameters
""" """
# Create deterministic RNG from base seed and idx # Create deterministic RNG from base seed and idx
item_rng = Random(self.seed + idx) rng = Random(self.seed + idx)
num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms) num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits) num_digits = rng.randint(self.config.min_digits, self.config.max_digits)
if self.config.allow_parentheses: if self.config.allow_parentheses:
expression, result = self._generate_complex_task(item_rng, num_terms, num_digits) expression, result = self._generate_complex_task(rng, num_terms, num_digits)
else: else:
expression, result = self._generate_simple_task(item_rng, num_terms, num_digits) expression, result = self._generate_simple_task(rng, num_terms, num_digits)
question = self._format_question(item_rng, expression) question = self._format_question(rng, expression)
return { return {
"question": question, "question": question,

View file

@ -122,9 +122,9 @@ class CalendarArithmeticDataset(ProceduralDataset):
self.tasks = [self.task_handlers[task] for task in self.config.tasks] self.tasks = [self.task_handlers[task] for task in self.config.tasks]
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
item_rng = random.Random(self.seed + idx) rng = random.Random(self.seed + idx)
task = item_rng.choice(self.tasks) task = rng.choice(self.tasks)
question, answer, metadata = task(item_rng) question, answer, metadata = task(rng)
return { return {
"question": question, "question": question,
"answer": str(answer), "answer": str(answer),

View file

@ -32,7 +32,7 @@ class ChainSumConfig:
assert 10 ** (self.min_digits - 1) >= 1, "min_digits would result in invalid number range" assert 10 ** (self.min_digits - 1) >= 1, "min_digits would result in invalid number range"
class ChainSum(ProceduralDataset): class ChainSumDataset(ProceduralDataset):
"""Generates simple arithmetic tasks using only + and - operators""" """Generates simple arithmetic tasks using only + and - operators"""
def __init__(self, config: ChainSumConfig): def __init__(self, config: ChainSumConfig):
@ -51,16 +51,16 @@ class ChainSum(ProceduralDataset):
- metadata: dict with generation parameters - metadata: dict with generation parameters
""" """
# Create deterministic RNG from base seed and idx # Create deterministic RNG from base seed and idx
item_rng = random.Random(self.seed + idx) rng = random.Random(self.seed + idx)
num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms) num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits) num_digits = rng.randint(self.config.min_digits, self.config.max_digits)
# Calculate value ranges based on number of digits # Calculate value ranges based on number of digits
min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1) # Special case for 1 digit min_value = 0 if num_digits == 1 else 10 ** (num_digits - 1) # Special case for 1 digit
max_value = (10**num_digits) - 1 # e.g., 999 for 3 digits max_value = (10**num_digits) - 1 # e.g., 999 for 3 digits
expression, result = self._generate_task(item_rng, num_terms, min_value, max_value) expression, result = self._generate_task(rng, num_terms, min_value, max_value)
return { return {
"question": f"{expression} =", "question": f"{expression} =",
@ -143,4 +143,4 @@ class ChainSumCurriculum(BaseCurriculum):
# Register the dataset # Register the dataset
register_dataset("chain_sum", ChainSum, ChainSumConfig) register_dataset("chain_sum", ChainSumDataset, ChainSumConfig)

View file

@ -0,0 +1,130 @@
import random
from dataclasses import dataclass
from typing import Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@dataclass
class ProductsConfig:
    """Settings that control how multiplication tasks are generated.

    Every field has a default, so ``ProductsConfig()`` is usable as-is.
    Call :meth:`validate` to check that the bounds are consistent.
    """

    # Inclusive bounds on the number of factors in one expression.
    min_terms: int = 2
    max_terms: int = 2
    # Inclusive bounds on the digit count drawn for each generated task.
    min_digits: int = 1
    max_digits: int = 5
    # Base seed handed to the dataset; None leaves the choice to the dataset base class.
    seed: Optional[int] = None
    # Number of items, handed to the dataset as its size.
    size: int = 500

    def validate(self) -> None:
        """Check that all bounds are positive and correctly ordered.

        Raises:
            AssertionError: on the first constraint that does not hold.
        """
        assert 0 < self.size, "size must be positive"
        assert 0 < self.min_terms, "min_terms must be positive"
        assert self.min_terms <= self.max_terms, "max_terms must be >= min_terms"
        assert 0 < self.min_digits, "min_digits must be positive"
        assert self.min_digits <= self.max_digits, "max_digits must be >= min_digits"
class ProductsDataset(ProceduralDataset):
    """Procedural dataset of integer multiplication problems.

    Each item is derived from ``self.seed + idx`` alone, so a given index
    always yields the same task for the same seed.
    NOTE(review): ``self.seed`` and ``self.config`` are presumably set by
    ``ProceduralDataset.__init__`` -- confirm against the base class.
    """

    def __init__(self, config: ProductsConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """Build the multiplication task stored at position ``idx``.

        Args:
            idx: Index of the item to generate.

        Returns:
            A dict with three keys:
                - question: the expression followed by `` =``
                - answer: the product rendered as a string
                - metadata: the expression plus the drawn difficulty values
        """
        # A fresh generator per item keeps every index reproducible on its own.
        rng = random.Random(self.seed + idx)

        cfg = self.config
        num_terms = rng.randint(cfg.min_terms, cfg.max_terms)
        num_digits = rng.randint(cfg.min_digits, cfg.max_digits)

        # One-digit operands span 0..9; otherwise the range starts at the
        # smallest number with that many digits (e.g. 100 for 3 digits).
        if num_digits == 1:
            lowest = 0
        else:
            lowest = 10 ** (num_digits - 1)
        highest = 10**num_digits - 1  # e.g. 999 for 3 digits

        expression, result = self._generate_task(rng, num_terms, lowest, highest)

        difficulty = {"num_terms": num_terms, "num_digits": num_digits}
        metadata = {"difficulty": difficulty, "expression": expression}
        return {
            "question": f"{expression} =",
            "answer": str(result),
            "metadata": metadata,
        }

    def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, int]:
        """Draw the factors and return the expression with its product.

        Args:
            rng: Random number generator to draw the factors from.
            num_terms: How many factors the expression contains.
            min_value: Smallest value a factor may take (inclusive).
            max_value: Largest value a factor may take (inclusive).

        Returns:
            Tuple of (expression string such as ``"3 * 7"``, integer product).
        """
        factors = [rng.randint(min_value, max_value) for _ in range(num_terms)]

        # Seed the running product with the first factor, then fold in the rest.
        product = factors[0]
        for factor in factors[1:]:
            product *= factor

        expression = " * ".join(str(factor) for factor in factors)
        return expression, product
class ProductsCurriculum(BaseCurriculum):
    """Curriculum for the products task with two range attributes.

    ``num_terms`` maps onto ``min_terms``/``max_terms`` and ``num_digits``
    maps onto ``min_digits``/``max_digits`` of ``ProductsConfig``.
    """

    def __init__(self):
        super().__init__(ProductsCurriculum.__name__, ProductsConfig)

        # Number of factors per expression.
        terms_attribute = RangeAttributeDefinition(
            name="num_terms",
            levels=[2, 3, 4, 5],
            default_level=0,  # Start with 2 terms
            description="Maximum number of terms in the expression",
            attr_type=AttributeType.APPEND,
            min_value=2,  # Ensure at least 2 terms
            lower_field_name="min_terms",
            upper_field_name="max_terms",
        )

        # Number of digits per operand.
        digits_attribute = RangeAttributeDefinition(
            name="num_digits",
            levels=[1, 2, 3, 4],
            default_level=0,  # Start with 1-digit numbers
            description="Number of digits in each operand",
            attr_type=AttributeType.APPEND,
            min_value=1,  # Ensure numbers are at least 1 digit
            lower_field_name="min_digits",
            upper_field_name="max_digits",
        )

        self._define_attributes(terms_attribute, digits_attribute)
# Register the dataset class and its config class under the name "products".
register_dataset("products", ProductsDataset, ProductsConfig)

View file

@ -82,14 +82,14 @@ class TimeIntervalsDataset(ProceduralDataset):
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
"""Generate a single time interval calculation task""" """Generate a single time interval calculation task"""
item_rng = random.Random(self.seed + idx) rng = random.Random(self.seed + idx)
# Randomly choose task type from config # Randomly choose task type from config
task_type = item_rng.choice(self.config.task_types) task_type = rng.choice(self.config.task_types)
start_time, end_time, format_str, expected_format = self._generate_times(item_rng, task_type) start_time, end_time, format_str, expected_format = self._generate_times(rng, task_type)
template = item_rng.choice(self.TEMPLATES) template = rng.choice(self.TEMPLATES)
question = template.format(start=start_time, end=end_time, format=expected_format) question = template.format(start=start_time, end=end_time, format=expected_format)
# Calculate the actual difference # Calculate the actual difference

View file

@ -1,6 +1,7 @@
import json
from dataclasses import dataclass from dataclasses import dataclass
from random import Random from random import Random
from typing import Dict, List, Optional, Tuple from typing import Dict, Optional
import cellpylib as cpl import cellpylib as cpl
@ -11,8 +12,8 @@ from ..factory import ProceduralDataset, register_dataset
class GameOfLifeConfig: class GameOfLifeConfig:
"""Configuration for sudoku puzzle generation""" """Configuration for sudoku puzzle generation"""
grid_size_x: int = 20 grid_size_x: int = 10
grid_size_y: int = 20 grid_size_y: int = 10
filled_cells: int = 100 # actually a max filled_cells: int = 100 # actually a max
simulation_steps: int = 1 simulation_steps: int = 1
seed: Optional[int] = None seed: Optional[int] = None
@ -31,7 +32,7 @@ class GameOfLifeDataset(ProceduralDataset):
def __init__(self, config: GameOfLifeConfig): def __init__(self, config: GameOfLifeConfig):
self._prompt_templates = [ self._prompt_templates = [
"What will this Game of Life board look like after {simulation_steps} steps of simulation?\n\n{board}" "What will this Game of Life board look like after {simulation_steps} steps of simulation? Reply as array of array representing rows in the grid from top to bottom in JSON format. (An empty 3x3 grid would look like this: [[0,0,0],[0,0,0],[0,0,0]])\n\n{board}."
] ]
super().__init__(config=config, seed=config.seed, size=config.size) super().__init__(config=config, seed=config.seed, size=config.size)
@ -59,11 +60,18 @@ class GameOfLifeDataset(ProceduralDataset):
# Simulate the result to get the answer # Simulate the result to get the answer
evolved = cpl.evolve2d( evolved = cpl.evolve2d(
board, timesteps=self.config.simulation_steps + 1, apply_rule=cpl.game_of_life_rule, memoize="recursive" board,
timesteps=self.config.simulation_steps + 1,
apply_rule=cpl.game_of_life_rule,
memoize="recursive",
) )
board_str = str(board[0]) rows = [json.dumps(board[0, i].tolist(), separators=(",", ":")) for i in range(board.shape[1])]
result_str = str(evolved[-1]) board_str = "[" + ", \n ".join(rows) + "]"
final_step = evolved[-1]
final_step_list = final_step.tolist()
result_str = json.dumps(final_step_list, separators=(",", ":"))
return { return {
"question": rng.choice(self._prompt_templates).format( "question": rng.choice(self._prompt_templates).format(
@ -93,10 +101,17 @@ class GameOfLifeDataset(ProceduralDataset):
if answer == None: if answer == None:
return 0.0 return 0.0
if answer.replace("\n", "") != entry["answer"].replace("\n", ""):
try:
ans_arr = json.loads(answer)
correct_arr = json.loads(entry["answer"])
if correct_arr != ans_arr:
return 0.01 return 0.01
else: else:
return 1.0 # Yay return 1.0 # Yay
except Exception as e:
return 0.01
register_dataset("game_of_life", GameOfLifeDataset, GameOfLifeConfig) register_dataset("game_of_life", GameOfLifeDataset, GameOfLifeConfig)

View file

@ -1,6 +1,6 @@
import pytest import pytest
from reasoning_gym.arithmetic import ChainSum, ChainSumConfig from reasoning_gym.arithmetic import ChainSumConfig, ChainSumDataset
from reasoning_gym.arithmetic.chain_sum import ChainSumCurriculum from reasoning_gym.arithmetic.chain_sum import ChainSumCurriculum
@ -18,8 +18,8 @@ def test_chain_sum_config_validation():
def test_chain_sum_deterministic(): def test_chain_sum_deterministic():
"""Test that dataset generates same items with same seed""" """Test that dataset generates same items with same seed"""
config = ChainSumConfig(seed=42, size=10) config = ChainSumConfig(seed=42, size=10)
dataset1 = ChainSum(config) dataset1 = ChainSumDataset(config)
dataset2 = ChainSum(config) dataset2 = ChainSumDataset(config)
for i in range(len(dataset1)): for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i] assert dataset1[i] == dataset2[i]
@ -28,7 +28,7 @@ def test_chain_sum_deterministic():
def test_chain_sum_items(): def test_chain_sum_items():
"""Test basic properties of generated items""" """Test basic properties of generated items"""
config = ChainSumConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42) config = ChainSumConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42)
dataset = ChainSum(config) dataset = ChainSumDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
@ -57,7 +57,7 @@ def test_chain_sum_number_ranges():
size=50, size=50,
seed=42, seed=42,
) )
dataset = ChainSum(config) dataset = ChainSumDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
@ -71,7 +71,7 @@ def test_chain_sum_number_ranges():
# Test 1-digit numbers # Test 1-digit numbers
config = ChainSumConfig(min_terms=2, max_terms=2, min_digits=1, max_digits=1, size=50, seed=42) config = ChainSumConfig(min_terms=2, max_terms=2, min_digits=1, max_digits=1, size=50, seed=42)
dataset = ChainSum(config) dataset = ChainSumDataset(config)
for i in range(len(dataset)): for i in range(len(dataset)):
item = dataset[i] item = dataset[i]
expression = item["metadata"]["expression"] expression = item["metadata"]["expression"]
@ -88,7 +88,7 @@ def test_chain_sum_negation():
config = ChainSumConfig( config = ChainSumConfig(
min_terms=2, max_terms=2, min_digits=2, max_digits=2, size=100, seed=42, allow_negation=True min_terms=2, max_terms=2, min_digits=2, max_digits=2, size=100, seed=42, allow_negation=True
) )
dataset = ChainSum(config) dataset = ChainSumDataset(config)
# Track if we see both positive and negative numbers # Track if we see both positive and negative numbers
has_positive = False has_positive = False
@ -112,7 +112,7 @@ def test_chain_sum_negation():
def test_chain_sum_iteration(): def test_chain_sum_iteration():
"""Test that iteration respects dataset size""" """Test that iteration respects dataset size"""
config = ChainSumConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing config = ChainSumConfig(min_terms=2, max_terms=2, size=5, seed=42) # Small size for testing
dataset = ChainSum(config) dataset = ChainSumDataset(config)
# Test manual iteration # Test manual iteration
items = [] items = []

View file

@ -5,7 +5,7 @@ from pathlib import Path
import pytest import pytest
from reasoning_gym.arithmetic.chain_sum import ChainSum, ChainSumConfig from reasoning_gym.arithmetic.chain_sum import ChainSumConfig, ChainSumDataset
from reasoning_gym.arithmetic.leg_counting import LegCountingConfig from reasoning_gym.arithmetic.leg_counting import LegCountingConfig
from reasoning_gym.coaching import Coach, GroupedScores from reasoning_gym.coaching import Coach, GroupedScores
from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec
@ -14,7 +14,7 @@ from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSp
def test_coach_with_chain_sum(): def test_coach_with_chain_sum():
# Create a small ChainSum dataset # Create a small ChainSum dataset
config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42) config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42)
dataset = ChainSum(config) dataset = ChainSumDataset(config)
coach = Coach(dataset) coach = Coach(dataset)
# Simulate an agent working on tasks # Simulate an agent working on tasks
@ -208,7 +208,7 @@ def test_coach_score_logging(tmp_path):
# Create dataset and coach with logging # Create dataset and coach with logging
config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42) config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42)
dataset = ChainSum(config) dataset = ChainSumDataset(config)
coach = Coach(dataset, score_log=log_file) coach = Coach(dataset, score_log=log_file)
# Score a few answers # Score a few answers

View file

@ -7,7 +7,7 @@ def test_game_of_life():
"""Test basic properties and solution of generated items""" """Test basic properties and solution of generated items"""
# Easy # Easy
config = GameOfLifeConfig(seed=42, size=1, grid_size_x=20, grid_size_y=20, filled_cells=10, simulation_steps=1) config = GameOfLifeConfig(seed=42, size=10, grid_size_x=20, grid_size_y=20, filled_cells=200, simulation_steps=1)
dataset = GameOfLifeDataset(config) dataset = GameOfLifeDataset(config)
for item in dataset: for item in dataset:

View file

@ -112,7 +112,7 @@ def test_polynomial_solutions_evaluation():
evaluated_value = poly_expr.subs(x, solution) evaluated_value = poly_expr.subs(x, solution)
# Ensure the evaluated value is close to zero (numerical stability threshold) # Ensure the evaluated value is close to zero (numerical stability threshold)
assert abs(evaluated_value) < 1e-6, ( assert abs(evaluated_value) < 1e-5, (
f"Solution {solution} does not satisfy the polynomial {poly_str}. " f"Solution {solution} does not satisfy the polynomial {poly_str}. "
f"Evaluated value: {evaluated_value}" f"Evaluated value: {evaluated_value}"
) )

144
tests/test_products.py Normal file
View file

@ -0,0 +1,144 @@
import pytest
from reasoning_gym.arithmetic import ProductsConfig, ProductsDataset
from reasoning_gym.arithmetic.products import ProductsCurriculum
def test_products_config_validation():
    """Check that invalid term settings are rejected with ``AssertionError``.

    Two configurations are exercised: a ``min_terms`` of zero, and a
    ``min_terms`` that exceeds ``max_terms``.
    """
    invalid_settings = (
        {"min_terms": 0},
        {"min_terms": 3, "max_terms": 2},
    )
    for kwargs in invalid_settings:
        # Construction stays inside the context manager, so an AssertionError
        # raised while building the config is accepted just like one raised
        # by validate().
        with pytest.raises(AssertionError):
            bad_config = ProductsConfig(**kwargs)
            bad_config.validate()
def test_products_deterministic():
    """Two datasets built from the same seeded config must agree item by item."""
    shared_config = ProductsConfig(seed=42, size=10)
    first = ProductsDataset(shared_config)
    second = ProductsDataset(shared_config)

    # Compare via indexing so that __getitem__ is the code path under test.
    positions = range(len(first))
    assert all(first[pos] == second[pos] for pos in positions)
def test_products_items():
    """Validate the structure and arithmetic of every generated item.

    Each item has to be a dict exposing ``question``, ``answer`` and
    ``metadata``; its expression may contain only digits, ``*`` and spaces,
    and evaluating that expression must reproduce the stored answer.
    """
    dataset = ProductsDataset(
        ProductsConfig(min_terms=2, max_terms=4, min_digits=1, max_digits=2, size=100, seed=42)
    )
    allowed_symbols = ["*", " "]

    for index in range(len(dataset)):
        entry = dataset[index]
        assert isinstance(entry, dict)
        for required_key in ("question", "answer", "metadata"):
            assert required_key in entry

        # Multiplication must be the only operator that appears.
        expr = entry["metadata"]["expression"]
        assert all(char.isdigit() or char in allowed_symbols for char in expr)

        # eval() is acceptable here: the expression comes from the dataset
        # under test and was just restricted to digits, "*" and spaces.
        assert str(eval(expr)) == entry["answer"]
def test_products_number_ranges():
    """Test that generated numbers respect digit constraints.

    Two fixed-width configurations are checked, each pinned to exactly two
    terms: 3-digit operands must lie in [100, 999] and 1-digit operands in
    [0, 9].
    """
    # (digit count, smallest allowed value, largest allowed value, label used
    # in the failure message)
    cases = (
        (3, 100, 999, "3 digits"),
        (1, 0, 9, "1 digit"),
    )
    for digits, lower, upper, label in cases:
        config = ProductsConfig(
            min_terms=2,
            max_terms=2,  # Fix to 2 terms for easier testing
            min_digits=digits,
            max_digits=digits,
            size=50,
            seed=42,
        )
        dataset = ProductsDataset(config)

        for i in range(len(dataset)):
            expression = dataset[i]["metadata"]["expression"]
            numbers = [int(token) for token in expression.split() if token.isdigit()]

            # Without this guard an expression from which no operand can be
            # parsed would leave `numbers` empty and the range check below
            # would pass vacuously.
            assert len(numbers) == 2, f"Expected 2 operands in {expression!r}, found {len(numbers)}"

            for num in numbers:
                assert lower <= num <= upper, f"Number {num} outside valid range for {label}"
def test_products_iteration():
    """Iterating the dataset yields exactly ``size`` items, and does so repeatably."""
    # A small size keeps the test quick.
    config = ProductsConfig(min_terms=2, max_terms=2, size=5, seed=42)
    dataset = ProductsDataset(config)

    # Element-by-element iteration
    collected = [entry for entry in dataset]
    assert len(collected) == config.size, "Iterator should yield exactly size items"

    # Conversion through list()
    as_list = list(dataset)
    assert len(as_list) == config.size, "Iterator should yield exactly size items"

    # Two separate passes over the same dataset must agree
    pass_one = list(dataset)
    pass_two = list(dataset)
    assert pass_one == pass_two, "Multiple iterations should yield same items"
def test_products_scoring():
    """Exercise ``score_answer`` on exact, wrong, embedded and missing answers."""
    dataset = ProductsDataset(ProductsConfig(min_terms=2, max_terms=2, size=10, seed=42))
    entry = dataset[0]
    truth = entry["answer"]

    # Each row is (response, expected score, failure message).
    expectations = [
        # exact match
        (truth, 1.0, "Exact match should score 1.0"),
        # wrong answer
        ("wrong", 0.01, "Wrong answer should score 0.01"),
        # correct answer embedded in a longer response
        (f"The answer is {truth}", 0.5, "Partial match should score 0.5"),
        # no answer at all
        (None, 0.0, "None should score 0.0"),
    ]
    for response, expected_score, message in expectations:
        assert dataset.score_answer(response, entry) == expected_score, message
def test_products_curriculum():
    """Check the configs the curriculum generates as attribute levels move up and down.

    The curriculum object is stateful, so the steps below must run in order.
    """
    curriculum = ProductsCurriculum()
    overrides = {"size": 150, "seed": 1}

    # Starting levels: the overrides are passed through, digits and terms sit
    # at their lowest ranges.
    initial: ProductsConfig = curriculum.generate_configuration(overrides)
    assert initial.seed == 1
    assert initial.size == 150
    assert (initial.min_digits, initial.max_digits) == (1, 1)
    assert (initial.min_terms, initial.max_terms) == (2, 2)

    # Raise both num_terms and num_digits by one level.
    curriculum.increment_attr_level("num_terms")
    curriculum.increment_attr_level("num_digits")
    raised = curriculum.generate_configuration(overrides)
    assert (raised.min_digits, raised.max_digits) == (1, 2)
    assert (raised.min_terms, raised.max_terms) == (2, 3)

    # Lower num_digits again; num_terms must keep its raised level.
    curriculum.decrement_attr_level("num_digits")
    lowered = curriculum.generate_configuration(overrides)
    assert (lowered.min_digits, lowered.max_digits) == (1, 1)
    assert (lowered.min_terms, lowered.max_terms) == (2, 3)