time intervals curriculum (#363)

This commit is contained in:
Zafir Stojanovski 2025-03-14 16:11:55 +01:00 committed by GitHub
parent a71994ad03
commit 41f3ef876c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 61 additions and 3 deletions

View file

@ -6,6 +6,7 @@ from typing import Optional
import pytz
from dateutil import parser
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -138,6 +139,10 @@ class TimeIntervalsDataset(ProceduralDataset):
"end_time": end_dt,
"format": format_str,
"expected_format": expected_format,
"difficulty": {
"max_time_difference_seconds": self.config.max_time_difference_seconds,
"max_date_difference_days": self.config.max_date_difference_days,
},
},
}
@ -319,5 +324,32 @@ class TimeIntervalsDataset(ProceduralDataset):
return 0.0
class TimeIntervalsCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(TimeIntervalsCurriculum.__name__, TimeIntervalsConfig)
# Define attributes
self._define_attributes(
ScalarAttributeDefinition(
name="max_time_difference_seconds",
field_name="max_time_difference_seconds",
levels=[60, 24 * 60 * 60, 7 * 24 * 60 * 60, 30 * 24 * 60 * 60, 365 * 24 * 60 * 60],
default_level=0,
description="Maximum time difference in seconds",
attr_type=AttributeType.STATIC,
min_value=1,
),
ScalarAttributeDefinition(
name="max_date_difference_days",
field_name="max_date_difference_days",
levels=[1, 7, 30, 365, 5 * 365],
default_level=0,
description="Maximum date difference in days",
attr_type=AttributeType.STATIC,
min_value=1,
),
)
# Register the dataset
register_dataset("time_intervals", TimeIntervalsDataset, TimeIntervalsConfig)
register_dataset("time_intervals", TimeIntervalsDataset, TimeIntervalsConfig, TimeIntervalsCurriculum)