reasoning-gym/reasoning_gym/algebra/complex_arithmetic.py
Zafir Stojanovski ce0a6c4878
fix(envs): Add source dataset and index to metadata (#388)
* add source dataset and index to metadata

* fix typo

* fix coach class and its test
2025-03-20 11:12:14 +00:00

227 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cmath
import math
import random
from dataclasses import dataclass, field
from typing import Optional
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
# Identifier passed to register_dataset() for this dataset; it is also written
# into every generated item's metadata under the "source_dataset" key.
DATASET_NAME = "complex_arithmetic"
@dataclass
class ComplexArithmeticConfig:
    """Configuration for complex number arithmetic problem generation.

    ``operations`` and ``operations_weights`` are parallel sequences: the weight
    at position *i* is the sampling weight of the operator at position *i*.
    """

    # Inclusive bounds for the randomly drawn integer real parts.
    min_real: int = -10
    max_real: int = 10
    # Inclusive bounds for the randomly drawn integer imaginary parts.
    min_imag: int = -10
    max_imag: int = 10
    # Operators a problem may use; each must be one of "+", "-", "*", "/".
    operations: tuple[str, ...] = ("+", "-", "*", "/")
    # Sampling weight for each entry of ``operations`` (same order, same length).
    operations_weights: list[float] = field(default_factory=lambda: [0.4, 0.4, 0.1, 0.1])
    # Forwarded to the dataset base class as its seed.
    seed: Optional[int] = None
    # Forwarded to the dataset base class as its size (presumably the item count).
    size: int = 500

    def validate(self) -> None:
        """Validate configuration parameters.

        Raises:
            AssertionError: if a range is inverted, an operator is unknown, the
                weights do not line up one-to-one with the operators, a weight
                is negative, the weights do not sum to (approximately) 1.0, or
                division is enabled while the only number that can be generated
                is zero.
        """
        assert self.max_real >= self.min_real, "max_real must be >= min_real"
        assert self.max_imag >= self.min_imag, "max_imag must be >= min_imag"
        assert all(op in ("+", "-", "*", "/") for op in self.operations), "invalid operator"
        # random.choices() rejects a weights list whose length differs from the
        # population; fail here instead of on the first item access.
        assert len(self.operations_weights) == len(
            self.operations
        ), "operations_weights must contain exactly one weight per operation"
        assert all(w >= 0 for w in self.operations_weights), "operations_weights must be non-negative"
        assert round(sum(self.operations_weights), 1) == 1.0, "operations_weights must sum to 1.0"
        # The dataset resamples the divisor until it is non-zero, which can never
        # succeed when both ranges collapse to the single value 0.
        only_zero = self.min_real == self.max_real == 0 and self.min_imag == self.max_imag == 0
        division_enabled = any(
            op == "/" and weight > 0 for op, weight in zip(self.operations, self.operations_weights)
        )
        assert not (
            only_zero and division_enabled
        ), "division requires a range that contains a non-zero complex number"
class ComplexArithmeticDataset(ProceduralDataset):
    """Generates complex number arithmetic problems.

    Each item is a dict with a natural-language ``question``, the formatted
    ``answer`` and a ``metadata`` dict describing the operands, the operator and
    the integer-valued result.
    """

    def __init__(self, config: ComplexArithmeticConfig):
        """Store the per-operator prompt templates and initialise the base dataset."""
        self._prompt_templates = {
            "+": "Add the complex numbers: ({a}) + ({b})",
            "-": "Subtract the complex numbers: ({a}) - ({b})",
            "*": "Multiply the complex numbers: ({a}) × ({b})",
            "/": "Divide the complex numbers: ({a}) ÷ ({b})",
        }
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _generate_complex(self, rng: random.Random) -> complex:
        """Generate a random complex number.

        Real and imaginary parts are integers drawn with ``rng.randint`` from the
        configured inclusive ranges.
        """
        real = rng.randint(self.config.min_real, self.config.max_real)
        imag = rng.randint(self.config.min_imag, self.config.max_imag)
        return complex(real, imag)

    def _format_complex(self, z: complex) -> str:
        """Format a complex number as text using ``i`` for the imaginary unit.

        * (near-)zero imaginary part -> real part with 2 decimals, e.g. ``"3.00"``
        * (near-)zero real part      -> imaginary part with 2 decimals, e.g. ``"-2.00i"``
        * otherwise                  -> ``"<real> <sign> <|imag|>i"`` using the
          default float formatting, e.g. ``"3.0 + 2.0i"``

        NOTE(review): only the first two branches round to 2 decimals; the mixed
        branch uses the plain float repr (which switches to exponent notation for
        magnitudes >= 1e16). Kept as-is because it determines the emitted
        question/answer strings.
        """
        real, imag = z.real, z.imag
        if abs(imag) < 1e-10:
            return f"{real:.2f}"
        elif abs(real) < 1e-10:
            return f"{imag:.2f}i"
        else:
            sign = "+" if imag >= 0 else "-"
            return f"{real} {sign} {abs(imag)}i"

    def __getitem__(self, idx: int) -> dict:
        """Build the problem at position ``idx``.

        The random generator is seeded with ``self.seed + idx``, so the same
        index always yields the same item.

        Returns:
            A dict with ``question``, ``answer`` and ``metadata`` keys.
        """
        rng = random.Random(self.seed + idx)

        # Choose random operation
        op = rng.choices(self.config.operations, weights=self.config.operations_weights, k=1)[0]

        if op == "/":
            # For division, first generate the quotient and the divisor, then
            # derive the dividend as quotient * divisor so the division is exact.
            result = self._generate_complex(rng)  # the intended final answer
            b = self._generate_complex(rng)
            while b == 0:  # Ensure non-zero divisor
                b = self._generate_complex(rng)
            a = result * b  # dividend shown in the question
        else:
            # For other operations, generate the operands directly
            a = self._generate_complex(rng)
            b = self._generate_complex(rng)

            if op == "+":
                result = a + b
            elif op == "-":
                result = a - b
            else:  # op == "*"
                result = a * b

        question = self._prompt_templates[op].format(a=self._format_complex(a), b=self._format_complex(b))

        return {
            "question": question,
            "answer": self._format_complex(result),
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
                # Operands are stored as (real, imag) float pairs.
                "num1": (a.real, a.imag),
                "num2": (b.real, b.imag),
                "operation": op,
                # Operands are built from integers, so the result parts are whole
                # numbers and are stored as ints.
                # NOTE(review): products beyond 2**53 are not exactly representable
                # as floats; verify exactness for very wide configured ranges.
                "result": (int(result.real), int(result.imag)),
                "difficulty": {
                    "min_real": self.config.min_real,
                    "max_real": self.config.max_real,
                    "min_imag": self.config.min_imag,
                    "max_imag": self.config.max_imag,
                    "operations_weights": self.config.operations_weights,
                },
            },
        }

    @staticmethod
    def parse_string_to_complex(answer: str) -> Optional[complex]:
        """Parse a textual complex number such as ``"3 + 2i"`` into a ``complex``.

        Spaces are removed, the text is lower-cased, enclosing parentheses are
        stripped and ``i`` is treated as the imaginary unit. Accepted shapes are a
        plain real (``"4.5"``), a pure imaginary (``"i"``, ``"-i"``, ``"7i"``) and
        ``real±imag i`` with an optional implicit coefficient (``"3+i"``). Text
        with the imaginary unit anywhere but at the end (``"3i+2"``) is rejected.

        Returns:
            The parsed number, or ``None`` if ``answer`` is ``None``, not a string
            or cannot be parsed.
        """
        if answer is None:
            return None

        try:
            # Normalize the answer string by removing spaces and converting to lowercase
            answer = answer.replace(" ", "").lower()

            # Remove any number of enclosing bracket pairs
            while len(answer) > 1 and answer[0] == "(" and answer[-1] == ")":
                answer = answer[1:-1]

            # Convert mathematical notation 'i' to Python's 'j' for complex numbers
            answer = answer.replace("i", "j")

            # Handle real numbers (no imaginary part)
            if "j" not in answer:
                return complex(float(answer))

            # Handle the bare imaginary unit
            if answer == "j":
                return complex(0, 1)
            if answer == "-j":
                return complex(0, -1)

            # Handle cases like "7j" or "-7j" (no real part). The bare "j"/"-j"
            # forms were returned above, so the coefficient is never empty or a
            # lone sign here; a malformed coefficient makes float() raise.
            if (answer.endswith("j") and not any(c in answer[:-1] for c in "+-")) or (
                answer.startswith("-") and not any(c in answer[1:-1] for c in "+-")
            ):
                return complex(0, float(answer[:-1]))

            # The imaginary unit must be the last character; "3j+2" is invalid
            if not answer.endswith("j"):
                return None

            # Handle cases like "3+j" / "3-j" (implicit coefficient of 1)
            if "+j" in answer:
                answer = answer.replace("+j", "+1j")
            if "-j" in answer:
                answer = answer.replace("-j", "-1j")

            # Parse the normalized string into a complex number
            return complex(answer)

        except (ValueError, TypeError, AttributeError):
            # AttributeError covers non-string input that has no .replace()
            return None

    def score_answer(self, answer: Optional[str], entry: dict) -> float:
        """Score the answer using exponential distance-based scoring.

        The score is ``exp(-|student - expected|)``: 1.0 for an exact match,
        decaying towards 0.0 as the answer moves away from the expected result.

        Returns:
            A float in ``[0.0, 1.0]``; 0.0 when the answer is missing, cannot be
            parsed, or does not lie at a finite distance from the expected result.
        """
        if answer is None:
            return 0.0

        metadata = entry["metadata"]

        student_result = self.parse_string_to_complex(answer)
        if student_result is None:
            return 0.0

        try:
            expected_result = complex(*metadata["result"])
            distance = abs(student_result - expected_result)
        except (ValueError, TypeError, OverflowError):
            # OverflowError: abs() of a complex number with huge components
            return 0.0

        # A NaN answer yields a NaN distance, and min(1.0, nan) would evaluate to
        # 1.0, i.e. full credit. Reject any non-finite distance explicitly.
        if not math.isfinite(distance):
            return 0.0

        # distance >= 0, so exp(-distance) already lies within [0.0, 1.0]
        return math.exp(-distance)
class ComplexArithmeticCurriculum(BaseCurriculum):
    """Curriculum for complex number arithmetic problems.

    Declares one scalar attribute per tunable config field (the four operand
    bounds and the operator sampling weights), each with four levels.
    """

    def __init__(self):
        super().__init__(ComplexArithmeticCurriculum.__name__, ComplexArithmeticConfig)

        # One row per attribute, in definition order:
        # (attribute / config field name, levels, description)
        attribute_specs = (
            (
                "min_real",
                [-10, -100, -10000, -100000000],
                "Minimum real part for complex numbers",
            ),
            (
                "max_real",
                [10, 100, 10000, 100000000],
                "Maximum real part for complex numbers",
            ),
            (
                "min_imag",
                [-10, -100, -10000, -100000000],
                "Minimum imaginary part for complex numbers",
            ),
            (
                "max_imag",
                [10, 100, 10000, 100000000],
                "Maximum imaginary part for complex numbers",
            ),
            (
                "operations_weights",
                [[0.4, 0.4, 0.1, 0.1], [0.25, 0.25, 0.25, 0.25], [0.2, 0.2, 0.3, 0.3], [0.1, 0.1, 0.4, 0.4]],
                "Operations weights to sample operation to use for each complex arithmetic problem",
            ),
        )

        # The attribute name and the config field it drives are identical.
        definitions = [
            ScalarAttributeDefinition(
                name=attr_name,
                field_name=attr_name,
                levels=attr_levels,
                description=attr_description,
            )
            for attr_name, attr_levels, attr_description in attribute_specs
        ]
        self._define_attributes(*definitions)
# Hand the dataset class, its config class and its curriculum class to the
# factory's register_dataset() under DATASET_NAME (runs at import time).
register_dataset(DATASET_NAME, ComplexArithmeticDataset, ComplexArithmeticConfig, ComplexArithmeticCurriculum)