reasoning-gym/reasoning_gym/core/attributes.py
2025-02-09 12:38:16 +00:00

153 lines
No EOL
6.3 KiB
Python

"""
Core definitions for curriculum attributes and types.
"""
from typing import Dict, List, Union, Any, Set, Optional
from dataclasses import dataclass
from enum import Enum
import random
class AttributeType(Enum):
"""Defines how attribute levels should be interpreted"""
STATIC = "static" # Each level is independent
UBOUND = "ubound" # Each level is an upper bound
APPEND = "append" # Each level includes all previous levels
@dataclass
class AttributeDefinition:
"""Defines a difficulty attribute with its possible levels and properties"""
levels: List[Any]
default_level: int
description: str
attr_type: AttributeType = AttributeType.STATIC # Default to static
min_value: Optional[Union[int, float]] = None # Minimum value for numeric attributes
@classmethod
def validate_attributes(cls, attributes: Dict[str, 'AttributeDefinition'], valid_types: Set[AttributeType], curriculum: str) -> None:
"""
Validates that all attributes use types from the valid_types set.
Args:
attributes: Dictionary of attribute definitions
valid_types: Set of allowed AttributeTypes for this curriculum
curriculum: A string identifier for the curriculum or class that owns these attributes
Raises:
ValueError: If any attribute uses an invalid type or has invalid configuration
"""
if not valid_types:
raise ValueError(f"Curriculum {curriculum} has no valid attribute types defined")
if not attributes:
raise ValueError(f"Curriculum {curriculum} has no attributes defined")
for name, attr in attributes.items():
# Check attribute type is valid
if attr.attr_type not in valid_types:
curriculum_class = f"{curriculum}." if curriculum else ""
raise ValueError(
f"Attribute '{curriculum_class}{name}' uses type {attr.attr_type.value} "
f"which is not in the curriculum's valid types: {[t.value for t in valid_types]}"
)
# Check levels exist
if not attr.levels:
raise ValueError(f"Attribute '{curriculum}.{name}' has no levels defined")
# Check default level is valid
if not 0 <= attr.default_level < len(attr.levels):
raise ValueError(
f"Invalid default level: {attr.default_level} for attribute '{curriculum}.{name}'. "
f"Must be between 0 and {len(attr.levels)-1}"
)
@classmethod
def check_attribute_exists(cls, attributes: Dict[str, 'AttributeDefinition'], attr_name: str, curriculum: str) -> 'AttributeDefinition':
"""
Check if attribute exists and return its definition.
Args:
attributes: Dictionary of attribute definitions
attr_name: Name of the attribute to check
curriculum: Name of the curriculum
Returns:
The AttributeDefinition for the attribute
Raises:
KeyError: If attribute doesn't exist
"""
if attr_name not in attributes:
raise KeyError(f"Attribute '{curriculum}.{attr_name}' does not exist")
return attributes[attr_name]
@classmethod
def validate_level(cls, attr: 'AttributeDefinition', level: int, attr_name: str, curriculum: str) -> None:
"""
Validate that a level is valid for an attribute.
Args:
attr: The attribute definition
level: Level to validate
attr_name: Name of the attribute
curriculum: Name of the curriculum
Raises:
ValueError: If level is invalid
"""
# TODO: if > set as [-1], if <0 set as [0]
if not 0 <= level < len(attr.levels):
raise ValueError(
f"Invalid level: {level} for attribute '{curriculum}.{attr_name}'. "
f"Must be between 0 and {len(attr.levels)-1}"
)
@classmethod
def get_level_value(cls, attr: 'AttributeDefinition', level: int, attr_name: str, curriculum: str) -> Any:
"""
Get the value for an attribute at a specific level based on its type.
Args:
attr: The attribute definition
level: Level to get value for
attr_name: Name of the attribute
curriculum: Name of the curriculum
Returns:
Value for the attribute based on its level and type
"""
if attr.attr_type == AttributeType.STATIC:
return attr.levels[level]
elif attr.attr_type == AttributeType.UBOUND:
return attr.levels[level]
elif attr.attr_type == AttributeType.APPEND:
return attr.levels[:level + 1]
raise ValueError(f"Unknown attribute type: {attr.attr_type} for attribute '{curriculum}.{attr_name}'")
def get_generator(self, level: int, rng: random.Random):
"""Returns a generator function based on attribute type and current level"""
match self.attr_type:
case AttributeType.STATIC:
# Returns exactly the value at current level
return lambda: self.levels[level]
case AttributeType.UBOUND:
# Returns random value up to current level bound, respecting min_value
max_val = self.levels[level]
min_val = self.min_value if self.min_value is not None else 0
# Handle both float and int values
if isinstance(max_val, float) or isinstance(min_val, float):
return lambda: rng.uniform(min_val, max_val)
return lambda: rng.randint(min_val, max_val)
case AttributeType.APPEND:
# Returns random choice from accumulated values up to current level
available_values = self.levels[:level + 1]
if isinstance(self.levels[0], list):
available_values = sum(available_values, [])
elif isinstance(self.levels[0], dict):
available_values = [{k: v for d in available_values for k, v in d.items()}]
return lambda: rng.choice(available_values)
raise ValueError(f"Unknown attribute type: {self.attr_type} for attribute '{self.description}'")