fix(envs): Add source dataset and index to metadata (#388)

* add source dataset and index to metadata

* fix typo

* fix coach class and its test
This commit is contained in:
Zafir Stojanovski 2025-03-20 12:12:14 +01:00 committed by GitHub
parent c6d01541aa
commit 4c47527130
104 changed files with 549 additions and 146 deletions

View file

@ -5,6 +5,8 @@ from typing import Any, Literal, Optional
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "basic_arithmetic"
@dataclass
class BasicArithmeticDatasetConfig:
@ -95,6 +97,8 @@ class BasicArithmeticDataset(ProceduralDataset):
"question": question,
"answer": str(result),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"expression": expression,
"num_terms": num_terms,
"num_digits": num_digits,
@ -260,4 +264,4 @@ class BasicArithmeticCurriculum(BaseCurriculum):
# Register the dataset
register_dataset("basic_arithmetic", BasicArithmeticDataset, BasicArithmeticDatasetConfig, BasicArithmeticCurriculum)
register_dataset(DATASET_NAME, BasicArithmeticDataset, BasicArithmeticDatasetConfig, BasicArithmeticCurriculum)

View file

@ -5,6 +5,8 @@ from typing import Any, Optional
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "bitwise_arithmetic"
@dataclass
class BitwiseArithmeticConfig:
@ -155,7 +157,12 @@ class BitwiseArithmeticDataset(ProceduralDataset):
return {
"question": problem_str,
"answer": answer,
"metadata": {"problem": problem, "difficulty": {"difficulty": self.config.difficulty}},
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"problem": problem,
"difficulty": {"difficulty": self.config.difficulty},
},
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
@ -193,4 +200,4 @@ class BitwiseArithmeticCurriculum(BaseCurriculum):
# Register the dataset with the factory.
register_dataset("bitwise_arithmetic", BitwiseArithmeticDataset, BitwiseArithmeticConfig, BitwiseArithmeticCurriculum)
register_dataset(DATASET_NAME, BitwiseArithmeticDataset, BitwiseArithmeticConfig, BitwiseArithmeticCurriculum)

View file

@ -9,6 +9,8 @@ from typing import Any, Optional
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "calendar_arithmetic"
class Weekday(Enum):
MONDAY = auto()
@ -126,6 +128,8 @@ class CalendarArithmeticDataset(ProceduralDataset):
rng = random.Random(self.seed + idx)
task = rng.choice(self.tasks)
question, answer, metadata = task(rng)
metadata["source_dataset"] = DATASET_NAME
metadata["source_index"] = idx
metadata["difficulty"] = {
"task_complexity": self.tasks.index(task),
"date_range": self.config.offset_upper_bound,
@ -523,6 +527,4 @@ class CalendarArithmeticCurriculum(BaseCurriculum):
)
register_dataset(
"calendar_arithmetic", CalendarArithmeticDataset, CalendarArithmeticConfig, CalendarArithmeticCurriculum
)
register_dataset(DATASET_NAME, CalendarArithmeticDataset, CalendarArithmeticConfig, CalendarArithmeticCurriculum)

View file

@ -7,6 +7,8 @@ from reasoning_gym import utils
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "chain_sum"
@dataclass
class ChainSumConfig:
@ -64,6 +66,8 @@ class ChainSumDataset(ProceduralDataset):
"question": f"State the final answer to the following arithmetic problem: {expression} =",
"answer": str(result),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"num_terms": num_terms,
"num_digits": num_digits,
"expression": expression,
@ -143,4 +147,4 @@ class ChainSumCurriculum(BaseCurriculum):
# Register the dataset
register_dataset("chain_sum", ChainSumDataset, ChainSumConfig, ChainSumCurriculum)
register_dataset(DATASET_NAME, ChainSumDataset, ChainSumConfig, ChainSumCurriculum)

View file

@ -9,6 +9,8 @@ from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """How many 1 bits are there in the binary representation of the number {number}?"""
DATASET_NAME = "count_bits"
@dataclass
class CountBitsConfig:
@ -43,6 +45,8 @@ class CountBitsDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(number=number),
"answer": str(answer),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"number": number,
"solution": answer,
"binary": binary,
@ -70,4 +74,4 @@ class CountBitsCurriculum(BaseCurriculum):
)
register_dataset("count_bits", CountBitsDataset, CountBitsConfig, CountBitsCurriculum)
register_dataset(DATASET_NAME, CountBitsDataset, CountBitsConfig, CountBitsCurriculum)

View file

@ -7,6 +7,8 @@ from typing import Any, Optional
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "decimal_arithmetic"
@dataclass
class DecimalArithmeticConfig:
@ -189,6 +191,8 @@ class DecimalArithmeticDataset(ProceduralDataset):
"question": problem_str,
"answer": str(answer),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"decimal_places": decimal_places,
"num_terms": terms,
"difficulty": {
@ -249,4 +253,4 @@ class DecimalArithmeticCurriculum(BaseCurriculum):
# Register the dataset with the factory.
register_dataset("decimal_arithmetic", DecimalArithmeticDataset, DecimalArithmeticConfig, DecimalArithmeticCurriculum)
register_dataset(DATASET_NAME, DecimalArithmeticDataset, DecimalArithmeticConfig, DecimalArithmeticCurriculum)

View file

@ -6,6 +6,8 @@ from typing import Any, Optional
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "decimal_chain_sum"
@dataclass
class DecimalChainSumConfig:
@ -66,6 +68,8 @@ class DecimalChainSumDataset(ProceduralDataset):
"question": f"State the final answer to the following arithmetic problem: {expression} =",
"answer": str(result),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"num_terms": num_terms,
"num_digits": num_digits,
"expression": expression,
@ -195,4 +199,4 @@ class DecimalChainSumCurriculum(BaseCurriculum):
)
register_dataset("decimal_chain_sum", DecimalChainSumDataset, DecimalChainSumConfig, DecimalChainSumCurriculum)
register_dataset(DATASET_NAME, DecimalChainSumDataset, DecimalChainSumConfig, DecimalChainSumCurriculum)

View file

@ -7,6 +7,8 @@ from typing import Any, Optional
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "dice"
def compute_probability(dice, target):
"""
@ -124,6 +126,8 @@ class DiceDataset(ProceduralDataset):
"question": puzzle_str,
"answer": answer_str,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"puzzle": puzzle,
"difficulty": {
"num_dice": self.config.num_dice,
@ -174,4 +178,4 @@ class DiceCurriculum(BaseCurriculum):
)
register_dataset("dice", DiceDataset, DiceConfig, DiceCurriculum)
register_dataset(DATASET_NAME, DiceDataset, DiceConfig, DiceCurriculum)

View file

@ -11,6 +11,8 @@ from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = "Simplify the fraction {question_fraction} to its lowest terms. Give only the simplified fraction as your final answer."
DATASET_NAME = "fraction_simplification"
@dataclass
class FractionSimplificationConfig:
@ -114,6 +116,8 @@ class FractionSimplificationDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(question_fraction=question_fraction),
"answer": answer_fraction,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"numerator": num,
"denominator": den,
"simplified_numerator": simple_num,
@ -184,7 +188,7 @@ class FractionSimplificationCurriculum(BaseCurriculum):
register_dataset(
"fraction_simplification",
DATASET_NAME,
FractionSimplificationDataset,
FractionSimplificationConfig,
FractionSimplificationCurriculum,

View file

@ -9,6 +9,8 @@ from typing import Optional
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "gcd"
@dataclass
class GCDConfig:
@ -62,6 +64,8 @@ class GCDDataset(ProceduralDataset):
"question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}. Give only the GCD as your final answer.",
"answer": str(result),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"numbers": numbers,
"result": result,
"num_terms": num_terms,
@ -96,4 +100,4 @@ class GCDCurriculum(BaseCurriculum):
)
register_dataset("gcd", GCDDataset, GCDConfig)
# Register the dataset together with its curriculum. GCDCurriculum is defined
# above in this module; without passing it here it is never registered,
# unlike the sibling datasets (lcm, dice, count_bits, ...) that pass theirs.
register_dataset(DATASET_NAME, GCDDataset, GCDConfig, GCDCurriculum)

View file

@ -7,6 +7,8 @@ from typing import Any, Callable, Optional
from reasoning_gym.factory import ProceduralDataset, register_dataset
DATASET_NAME = "gsm_symbolic"
tasks_ok = [
0,
1,
@ -151,6 +153,8 @@ class GSMSymbolicDataset(ProceduralDataset):
generator = self.generators[generator_idx]
example = generator(rng, self.config.difficulty)
example["question"] += " Give the result as your final answer. Do not include units."
example["metadata"]["source_dataset"] = DATASET_NAME
example["metadata"]["source_index"] = idx
return example
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
@ -174,4 +178,4 @@ class GSMSymbolicDataset(ProceduralDataset):
return reward
register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig)
register_dataset(DATASET_NAME, GSMSymbolicDataset, GSMSymbolicDatasetConfig)

View file

@ -9,6 +9,8 @@ from typing import Optional
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "lcm"
@dataclass
class LCMConfig:
@ -64,6 +66,8 @@ class LCMDataset(ProceduralDataset):
"question": f"Find the Least Common Multiple (LCM) of these numbers: {numbers_str}",
"answer": str(result),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"numbers": numbers,
"result": result,
"difficulty": {
@ -98,4 +102,4 @@ class LCMCurriculum(BaseCurriculum):
)
register_dataset("lcm", LCMDataset, LCMConfig, LCMCurriculum)
register_dataset(DATASET_NAME, LCMDataset, LCMConfig, LCMCurriculum)

View file

@ -60,6 +60,8 @@ QUESTION_TEMPLATE = """Your task is to count how many legs there are in total wh
Now, how many legs are there in total if you have {animals}?
"""
DATASET_NAME = "leg_counting"
@dataclass
class LegCountingConfig:
@ -118,6 +120,8 @@ class LegCountingDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(animals=", ".join(animal_list)),
"answer": str(total_legs),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"animals": animals,
"num_animals": len(animals),
"total_legs": total_legs,
@ -152,4 +156,4 @@ class LegCountingCurriculum(BaseCurriculum):
)
register_dataset("leg_counting", LegCountingDataset, LegCountingConfig, LegCountingCurriculum)
register_dataset(DATASET_NAME, LegCountingDataset, LegCountingConfig, LegCountingCurriculum)

View file

@ -14,6 +14,8 @@ Your output should be only the number of interest.
Now, pick the {size} number of the following candidates: {numbers}
"""
DATASET_NAME = "number_format"
@dataclass
class NumberFormatConfig:
@ -94,6 +96,8 @@ class NumberFormatDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(numbers=" ".join(formatted_candidates), size=size),
"answer": str(answer),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"candidates": candidates,
"solution": answer,
"formatted_candidates": formatted_candidates,
@ -138,4 +142,4 @@ class NumberFormatCurriculum(BaseCurriculum):
)
register_dataset("number_format", NumberFormatDataset, NumberFormatConfig, NumberFormatCurriculum)
register_dataset(DATASET_NAME, NumberFormatDataset, NumberFormatConfig, NumberFormatCurriculum)

View file

@ -15,6 +15,8 @@ Compute {base}^{exponent}. Return your final answer correct to 3 significant fig
Provide your answer in scientific notation using 'e' notation (e.g., 1.23e+4).
"""
DATASET_NAME = "power_function"
@dataclass
class PowerFunctionConfig:
@ -74,6 +76,8 @@ class PowerFunctionDataset(ProceduralDataset):
"question": QUESTION_TEMPLATE.format(base=base, exponent=exponent),
"answer": str(answer),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"base": base,
"exponent": exponent,
"solution": answer,
@ -97,4 +101,4 @@ class PowerFunctionCurriculum(BaseCurriculum):
)
register_dataset("power_function", PowerFunctionDataset, PowerFunctionConfig, PowerFunctionCurriculum)
register_dataset(DATASET_NAME, PowerFunctionDataset, PowerFunctionConfig, PowerFunctionCurriculum)

View file

@ -8,6 +8,8 @@ from typing import Any, Optional
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "prime_factorization"
@dataclass
class PrimeFactorizationConfig:
@ -84,6 +86,8 @@ class PrimeFactorizationDataset(ProceduralDataset):
),
"answer": answer,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"number": number,
"factors": factors,
"difficulty": {
@ -110,6 +114,4 @@ class PrimeFactorizationCurriculum(BaseCurriculum):
)
register_dataset(
"prime_factorization", PrimeFactorizationDataset, PrimeFactorizationConfig, PrimeFactorizationCurriculum
)
register_dataset(DATASET_NAME, PrimeFactorizationDataset, PrimeFactorizationConfig, PrimeFactorizationCurriculum)

View file

@ -7,6 +7,8 @@ from reasoning_gym import utils
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "products"
@dataclass
class ProductsConfig:
@ -66,6 +68,8 @@ class ProductsDataset(ProceduralDataset):
"question": f"Solve the following multiplication: {expression}. Give only the result as your final answer.",
"answer": str(result),
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"expression": expression,
"num_terms": num_terms,
"num_digits": num_digits,
@ -135,4 +139,4 @@ class ProductsCurriculum(BaseCurriculum):
# Register the dataset
register_dataset("products", ProductsDataset, ProductsConfig, ProductsCurriculum)
register_dataset(DATASET_NAME, ProductsDataset, ProductsConfig, ProductsCurriculum)

View file

@ -9,6 +9,8 @@ from dateutil import parser
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "time_intervals"
@dataclass
class TimeIntervalsConfig:
@ -134,6 +136,8 @@ class TimeIntervalsDataset(ProceduralDataset):
"question": question,
"answer": answer,
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"task_type": task_type,
"start_time": start_dt,
"end_time": end_dt,
@ -346,4 +350,4 @@ class TimeIntervalsCurriculum(BaseCurriculum):
# Register the dataset
register_dataset("time_intervals", TimeIntervalsDataset, TimeIntervalsConfig, TimeIntervalsCurriculum)
register_dataset(DATASET_NAME, TimeIntervalsDataset, TimeIntervalsConfig, TimeIntervalsCurriculum)