Add complex arithmetic curriculum (#310)

* Add complex arithmetic curriculum
This commit is contained in:
Adefioye 2025-03-09 18:28:51 -05:00 committed by GitHub
parent a1dc28aa73
commit f5141b32c5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 139 additions and 5 deletions

View file

@ -1,9 +1,10 @@
import cmath
import math
import random
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -14,6 +15,7 @@ class ComplexArithmeticConfig:
min_imag: int = -10
max_imag: int = 10
operations: tuple[str, ...] = ("+", "-", "*", "/")
operations_weights: list[float] = field(default_factory=lambda: [0.4, 0.4, 0.1, 0.1])
seed: Optional[int] = None
size: int = 500
@ -22,6 +24,7 @@ class ComplexArithmeticConfig:
assert self.max_real >= self.min_real, "max_real must be >= min_real"
assert self.max_imag >= self.min_imag, "max_imag must be >= min_imag"
assert all(op in ("+", "-", "*", "/") for op in self.operations), "invalid operator"
assert round(sum(self.operations_weights), 1) == 1.0, "operations_weights must sum to 1.0"
class ComplexArithmeticDataset(ProceduralDataset):
@ -57,7 +60,7 @@ class ComplexArithmeticDataset(ProceduralDataset):
rng = random.Random(self.seed + idx)
# Choose random operation
op = rng.choice(self.config.operations)
op = rng.choices(self.config.operations, weights=self.config.operations_weights, k=1)[0]
if op == "/":
# For division, first generate the quotient (a) and divisor (b)
@ -91,6 +94,13 @@ class ComplexArithmeticDataset(ProceduralDataset):
"num2": (b.real, b.imag),
"operation": op,
"result": (int(result.real), int(result.imag)), # Convert to int since we ensure whole numbers
"difficulty": {
"min_real": self.config.min_real,
"max_real": self.config.max_real,
"min_imag": self.config.min_imag,
"max_imag": self.config.max_imag,
"operations_weights": self.config.operations_weights,
},
},
}
@ -169,4 +179,60 @@ class ComplexArithmeticDataset(ProceduralDataset):
return 0.0
register_dataset("complex_arithmetic", ComplexArithmeticDataset, ComplexArithmeticConfig)
class ComplexArithmeticCurriculum(BaseCurriculum):
"""Curriculum for complex number arithmetic problems."""
def __init__(self):
super().__init__(ComplexArithmeticCurriculum.__name__, ComplexArithmeticConfig)
# Define attributes
self._define_attributes(
ScalarAttributeDefinition(
name="min_real",
field_name="min_real",
levels=[-10, -100, -10000, -100000000],
default_level=0,
description="Minimum real part for complex numbers",
attr_type=AttributeType.STATIC,
min_value=-10,
),
ScalarAttributeDefinition(
name="max_real",
field_name="max_real",
levels=[10, 100, 10000, 100000000],
default_level=0,
description="Maximum real part for complex numbers",
attr_type=AttributeType.STATIC,
min_value=10,
),
ScalarAttributeDefinition(
name="min_imag",
field_name="min_imag",
levels=[-10, -100, -10000, -100000000],
default_level=0,
description="Minimum imaginary part for complex numbers",
attr_type=AttributeType.STATIC,
min_value=-10,
),
ScalarAttributeDefinition(
name="max_imag",
field_name="max_imag",
levels=[10, 100, 10000, 100000000],
default_level=0,
description="Maximum imaginary part for complex numbers",
attr_type=AttributeType.STATIC,
min_value=10,
),
ScalarAttributeDefinition(
name="operations_weights",
field_name="operations_weights",
levels=[[0.4, 0.4, 0.1, 0.1], [0.25, 0.25, 0.25, 0.25], [0.2, 0.2, 0.3, 0.3], [0.1, 0.1, 0.4, 0.4]],
default_level=0,
description="Operations weights to sample operation to use for each complex arithmetic problem",
attr_type=AttributeType.STATIC,
min_value=[0.4, 0.4, 0.1, 0.1],
),
)
register_dataset("complex_arithmetic", ComplexArithmeticDataset, ComplexArithmeticConfig, ComplexArithmeticCurriculum)