mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-22 16:49:06 +00:00
Add complex arithmetic curriculum (#310)
* Add complex arithmetic curriculum
This commit is contained in:
parent
a1dc28aa73
commit
f5141b32c5
2 changed files with 139 additions and 5 deletions
|
|
@ -1,9 +1,10 @@
|
|||
import cmath
|
||||
import math
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -14,6 +15,7 @@ class ComplexArithmeticConfig:
|
|||
min_imag: int = -10
|
||||
max_imag: int = 10
|
||||
operations: tuple[str, ...] = ("+", "-", "*", "/")
|
||||
operations_weights: list[float] = field(default_factory=lambda: [0.4, 0.4, 0.1, 0.1])
|
||||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
|
||||
|
|
@ -22,6 +24,7 @@ class ComplexArithmeticConfig:
|
|||
assert self.max_real >= self.min_real, "max_real must be >= min_real"
|
||||
assert self.max_imag >= self.min_imag, "max_imag must be >= min_imag"
|
||||
assert all(op in ("+", "-", "*", "/") for op in self.operations), "invalid operator"
|
||||
assert round(sum(self.operations_weights), 1) == 1.0, "operations_weights must sum to 1.0"
|
||||
|
||||
|
||||
class ComplexArithmeticDataset(ProceduralDataset):
|
||||
|
|
@ -57,7 +60,7 @@ class ComplexArithmeticDataset(ProceduralDataset):
|
|||
rng = random.Random(self.seed + idx)
|
||||
|
||||
# Choose random operation
|
||||
op = rng.choice(self.config.operations)
|
||||
op = rng.choices(self.config.operations, weights=self.config.operations_weights, k=1)[0]
|
||||
|
||||
if op == "/":
|
||||
# For division, first generate the quotient (a) and divisor (b)
|
||||
|
|
@ -91,6 +94,13 @@ class ComplexArithmeticDataset(ProceduralDataset):
|
|||
"num2": (b.real, b.imag),
|
||||
"operation": op,
|
||||
"result": (int(result.real), int(result.imag)), # Convert to int since we ensure whole numbers
|
||||
"difficulty": {
|
||||
"min_real": self.config.min_real,
|
||||
"max_real": self.config.max_real,
|
||||
"min_imag": self.config.min_imag,
|
||||
"max_imag": self.config.max_imag,
|
||||
"operations_weights": self.config.operations_weights,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -169,4 +179,60 @@ class ComplexArithmeticDataset(ProceduralDataset):
|
|||
return 0.0
|
||||
|
||||
|
||||
register_dataset("complex_arithmetic", ComplexArithmeticDataset, ComplexArithmeticConfig)
|
||||
class ComplexArithmeticCurriculum(BaseCurriculum):
|
||||
"""Curriculum for complex number arithmetic problems."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(ComplexArithmeticCurriculum.__name__, ComplexArithmeticConfig)
|
||||
|
||||
# Define attributes
|
||||
self._define_attributes(
|
||||
ScalarAttributeDefinition(
|
||||
name="min_real",
|
||||
field_name="min_real",
|
||||
levels=[-10, -100, -10000, -100000000],
|
||||
default_level=0,
|
||||
description="Minimum real part for complex numbers",
|
||||
attr_type=AttributeType.STATIC,
|
||||
min_value=-10,
|
||||
),
|
||||
ScalarAttributeDefinition(
|
||||
name="max_real",
|
||||
field_name="max_real",
|
||||
levels=[10, 100, 10000, 100000000],
|
||||
default_level=0,
|
||||
description="Maximum real part for complex numbers",
|
||||
attr_type=AttributeType.STATIC,
|
||||
min_value=10,
|
||||
),
|
||||
ScalarAttributeDefinition(
|
||||
name="min_imag",
|
||||
field_name="min_imag",
|
||||
levels=[-10, -100, -10000, -100000000],
|
||||
default_level=0,
|
||||
description="Minimum imaginary part for complex numbers",
|
||||
attr_type=AttributeType.STATIC,
|
||||
min_value=-10,
|
||||
),
|
||||
ScalarAttributeDefinition(
|
||||
name="max_imag",
|
||||
field_name="max_imag",
|
||||
levels=[10, 100, 10000, 100000000],
|
||||
default_level=0,
|
||||
description="Maximum imaginary part for complex numbers",
|
||||
attr_type=AttributeType.STATIC,
|
||||
min_value=10,
|
||||
),
|
||||
ScalarAttributeDefinition(
|
||||
name="operations_weights",
|
||||
field_name="operations_weights",
|
||||
levels=[[0.4, 0.4, 0.1, 0.1], [0.25, 0.25, 0.25, 0.25], [0.2, 0.2, 0.3, 0.3], [0.1, 0.1, 0.4, 0.4]],
|
||||
default_level=0,
|
||||
description="Operations weights to sample operation to use for each complex arithmetic problem",
|
||||
attr_type=AttributeType.STATIC,
|
||||
min_value=[0.4, 0.4, 0.1, 0.1],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
register_dataset("complex_arithmetic", ComplexArithmeticDataset, ComplexArithmeticConfig, ComplexArithmeticCurriculum)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue