(evals): Medium configs (#415)

* updated medium configs

* fix problematic curriculum values / small issues causing exceptions to be raised

* optimus alpha config

* all configs so far

* fix tests
This commit is contained in:
Zafir Stojanovski 2025-04-14 08:25:31 +02:00 committed by GitHub
parent cd1a9ea58b
commit 290bfc4fdd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 7050 additions and 63 deletions

View file

@ -42,6 +42,12 @@ class ReArcConfig:
assert self.min_examples <= self.max_examples, "min_examples must be <= max_examples"
assert self.diff_lb <= self.diff_ub, "diff_lb must be <= diff_ub."
assert self.size > 0, "Size of dataset must be positive."
assert len(self.rng_difficulty_ranges) == len(
self.rng_difficulty_weights
), "rng_difficulty_ranges and rng_difficulty_weights must have the same length."
assert len(self.pso_difficulty_ranges) == len(
self.pso_difficulty_weights
), "pso_difficulty_ranges and pso_difficulty_weights must have the same length."
class ReArcDataset(ProceduralDataset):
@ -93,6 +99,7 @@ class ReArcDataset(ProceduralDataset):
Generate a single ReArc task
"""
rng = Random(self.seed + idx)
pso_difficulty_range = rng.choices(
self.config.pso_difficulty_ranges, weights=self.config.pso_difficulty_weights, k=1
)[0]
@ -154,14 +161,13 @@ class ReArcCurriculum(BaseCurriculum):
field_name="pso_difficulty_weights",
description="The range of PSO difficulty for the Arc problem",
levels=[
[1, 0, 0, 0, 0, 0, 0, 0], # only sample/generate the easiest tasks wrs PSO difficulty
[0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1],
[1, 0, 0, 0, 0, 0, 0], # only sample/generate the easiest tasks wrs PSO difficulty
[0, 1, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 1],
], # only sample/generate the hardest tasks PSO difficulty
),
ScalarAttributeDefinition(
@ -169,14 +175,13 @@ class ReArcCurriculum(BaseCurriculum):
field_name="rng_difficulty_weights",
description="The range of RNG difficulty for the Arc problem",
levels=[
[1, 0, 0, 0, 0, 0, 0, 0], # only sample/generate the easiest tasks wrs RNG difficulty
[0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1],
[1, 0, 0, 0, 0, 0, 0], # only sample/generate the easiest tasks wrs RNG difficulty
[0, 1, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 1],
], # only sample/generate the hardest tasks wrs RNG difficulty
),
)