mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-30 17:40:45 +00:00
Refactor Curriculum Attributes (#335)
* remove min_value from AttributeDefinition * remove type from AttributeDefinition * Add CurriculumContext * add ensure_interval option for RangeAttributes * docs: Add legend explaining curriculum indicators in dataset gallery * update GALLERY.md
This commit is contained in:
parent
4e7d9296ee
commit
d2c895f1d3
101 changed files with 286 additions and 677 deletions
|
|
@ -1,25 +1,13 @@
|
|||
from collections import abc
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class AttributeType(StrEnum):
|
||||
"""Defines how attribute levels should be interpreted"""
|
||||
|
||||
STATIC = "static" # Each level is independent
|
||||
UBOUND = "ubound" # Each level is an upper bound
|
||||
APPEND = "append" # Each level includes all previous levels
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class AttributeDefinition:
|
||||
name: str
|
||||
levels: list
|
||||
default_level: int
|
||||
default_level: int = 0
|
||||
description: Optional[str] = None
|
||||
attr_type: AttributeType = AttributeType.STATIC # Default to static
|
||||
min_value: Optional[int | float] = None # Minimum value for numeric attributes
|
||||
|
||||
def validate_level(self, level: int, curriculum: str) -> None:
|
||||
"""
|
||||
|
|
@ -37,7 +25,7 @@ class AttributeDefinition:
|
|||
f"Must be between 0 and {len(self.levels)-1}"
|
||||
)
|
||||
|
||||
def get_level_value(self, level: int, curriculum: str) -> Any:
|
||||
def get_level_value(self, level: int) -> Any:
|
||||
"""
|
||||
Get the value for an attribute at a specific level based on its type.
|
||||
Args:
|
||||
|
|
@ -46,14 +34,7 @@ class AttributeDefinition:
|
|||
Returns:
|
||||
Value for the attribute based on its level and type
|
||||
"""
|
||||
if self.attr_type == AttributeType.STATIC:
|
||||
return self.levels[level]
|
||||
elif self.attr_type == AttributeType.UBOUND:
|
||||
return self.levels[level]
|
||||
elif self.attr_type == AttributeType.APPEND:
|
||||
return self.levels[: level + 1]
|
||||
|
||||
raise ValueError(f"Unknown attribute type: {self.attr_type} for attribute '{curriculum}.{self.name}'")
|
||||
return self.levels[level]
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
|
|
@ -65,9 +46,4 @@ class ScalarAttributeDefinition(AttributeDefinition):
|
|||
class RangeAttributeDefinition(AttributeDefinition):
|
||||
lower_field_name: str
|
||||
upper_field_name: str
|
||||
|
||||
def get_level_value(self, level: int, curriculum: str) -> Any:
|
||||
v = super().get_level_value(level, curriculum)
|
||||
if not isinstance(v, abc.Iterable):
|
||||
return [v]
|
||||
return v
|
||||
ensure_interval: bool = False # When True, ensures the range is always an interval between two distinct values
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue