mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-24 17:05:03 +00:00
fix(envs): Add source dataset and index to metadata (#388)
* add source dataset and index to metadata
* fix typo
* fix coach class and its test
This commit is contained in:
parent
c6d01541aa
commit
4c47527130
104 changed files with 549 additions and 146 deletions
|
|
@ -5,6 +5,8 @@ from typing import Any, Literal, Optional
|
|||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "basic_arithmetic"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BasicArithmeticDatasetConfig:
|
||||
|
|
@ -95,6 +97,8 @@ class BasicArithmeticDataset(ProceduralDataset):
|
|||
"question": question,
|
||||
"answer": str(result),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"expression": expression,
|
||||
"num_terms": num_terms,
|
||||
"num_digits": num_digits,
|
||||
|
|
@ -260,4 +264,4 @@ class BasicArithmeticCurriculum(BaseCurriculum):
|
|||
|
||||
|
||||
# Register the dataset
|
||||
register_dataset("basic_arithmetic", BasicArithmeticDataset, BasicArithmeticDatasetConfig, BasicArithmeticCurriculum)
|
||||
register_dataset(DATASET_NAME, BasicArithmeticDataset, BasicArithmeticDatasetConfig, BasicArithmeticCurriculum)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ from typing import Any, Optional
|
|||
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "bitwise_arithmetic"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BitwiseArithmeticConfig:
|
||||
|
|
@ -155,7 +157,12 @@ class BitwiseArithmeticDataset(ProceduralDataset):
|
|||
return {
|
||||
"question": problem_str,
|
||||
"answer": answer,
|
||||
"metadata": {"problem": problem, "difficulty": {"difficulty": self.config.difficulty}},
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"problem": problem,
|
||||
"difficulty": {"difficulty": self.config.difficulty},
|
||||
},
|
||||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
|
|
@ -193,4 +200,4 @@ class BitwiseArithmeticCurriculum(BaseCurriculum):
|
|||
|
||||
|
||||
# Register the dataset with the factory.
|
||||
register_dataset("bitwise_arithmetic", BitwiseArithmeticDataset, BitwiseArithmeticConfig, BitwiseArithmeticCurriculum)
|
||||
register_dataset(DATASET_NAME, BitwiseArithmeticDataset, BitwiseArithmeticConfig, BitwiseArithmeticCurriculum)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ from typing import Any, Optional
|
|||
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "calendar_arithmetic"
|
||||
|
||||
|
||||
class Weekday(Enum):
|
||||
MONDAY = auto()
|
||||
|
|
@ -126,6 +128,8 @@ class CalendarArithmeticDataset(ProceduralDataset):
|
|||
rng = random.Random(self.seed + idx)
|
||||
task = rng.choice(self.tasks)
|
||||
question, answer, metadata = task(rng)
|
||||
metadata["source_dataset"] = DATASET_NAME
|
||||
metadata["source_index"] = idx
|
||||
metadata["difficulty"] = {
|
||||
"task_complexity": self.tasks.index(task),
|
||||
"date_range": self.config.offset_upper_bound,
|
||||
|
|
@ -523,6 +527,4 @@ class CalendarArithmeticCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset(
|
||||
"calendar_arithmetic", CalendarArithmeticDataset, CalendarArithmeticConfig, CalendarArithmeticCurriculum
|
||||
)
|
||||
register_dataset(DATASET_NAME, CalendarArithmeticDataset, CalendarArithmeticConfig, CalendarArithmeticCurriculum)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ from reasoning_gym import utils
|
|||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "chain_sum"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChainSumConfig:
|
||||
|
|
@ -64,6 +66,8 @@ class ChainSumDataset(ProceduralDataset):
|
|||
"question": f"State the final answer to the following arithmetic problem: {expression} =",
|
||||
"answer": str(result),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"num_terms": num_terms,
|
||||
"num_digits": num_digits,
|
||||
"expression": expression,
|
||||
|
|
@ -143,4 +147,4 @@ class ChainSumCurriculum(BaseCurriculum):
|
|||
|
||||
|
||||
# Register the dataset
|
||||
register_dataset("chain_sum", ChainSumDataset, ChainSumConfig, ChainSumCurriculum)
|
||||
register_dataset(DATASET_NAME, ChainSumDataset, ChainSumConfig, ChainSumCurriculum)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ from ..factory import ProceduralDataset, register_dataset
|
|||
|
||||
QUESTION_TEMPLATE = """How many 1 bits are there in the binary representation of the number {number}?"""
|
||||
|
||||
DATASET_NAME = "count_bits"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CountBitsConfig:
|
||||
|
|
@ -43,6 +45,8 @@ class CountBitsDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(number=number),
|
||||
"answer": str(answer),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"number": number,
|
||||
"solution": answer,
|
||||
"binary": binary,
|
||||
|
|
@ -70,4 +74,4 @@ class CountBitsCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("count_bits", CountBitsDataset, CountBitsConfig, CountBitsCurriculum)
|
||||
register_dataset(DATASET_NAME, CountBitsDataset, CountBitsConfig, CountBitsCurriculum)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ from typing import Any, Optional
|
|||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "decimal_arithmetic"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecimalArithmeticConfig:
|
||||
|
|
@ -189,6 +191,8 @@ class DecimalArithmeticDataset(ProceduralDataset):
|
|||
"question": problem_str,
|
||||
"answer": str(answer),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"decimal_places": decimal_places,
|
||||
"num_terms": terms,
|
||||
"difficulty": {
|
||||
|
|
@ -249,4 +253,4 @@ class DecimalArithmeticCurriculum(BaseCurriculum):
|
|||
|
||||
|
||||
# Register the dataset with the factory.
|
||||
register_dataset("decimal_arithmetic", DecimalArithmeticDataset, DecimalArithmeticConfig, DecimalArithmeticCurriculum)
|
||||
register_dataset(DATASET_NAME, DecimalArithmeticDataset, DecimalArithmeticConfig, DecimalArithmeticCurriculum)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ from typing import Any, Optional
|
|||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "decimal_chain_sum"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecimalChainSumConfig:
|
||||
|
|
@ -66,6 +68,8 @@ class DecimalChainSumDataset(ProceduralDataset):
|
|||
"question": f"State the final answer to the following arithmetic problem: {expression} =",
|
||||
"answer": str(result),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"num_terms": num_terms,
|
||||
"num_digits": num_digits,
|
||||
"expression": expression,
|
||||
|
|
@ -195,4 +199,4 @@ class DecimalChainSumCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("decimal_chain_sum", DecimalChainSumDataset, DecimalChainSumConfig, DecimalChainSumCurriculum)
|
||||
register_dataset(DATASET_NAME, DecimalChainSumDataset, DecimalChainSumConfig, DecimalChainSumCurriculum)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ from typing import Any, Optional
|
|||
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "dice"
|
||||
|
||||
|
||||
def compute_probability(dice, target):
|
||||
"""
|
||||
|
|
@ -124,6 +126,8 @@ class DiceDataset(ProceduralDataset):
|
|||
"question": puzzle_str,
|
||||
"answer": answer_str,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"puzzle": puzzle,
|
||||
"difficulty": {
|
||||
"num_dice": self.config.num_dice,
|
||||
|
|
@ -174,4 +178,4 @@ class DiceCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("dice", DiceDataset, DiceConfig, DiceCurriculum)
|
||||
register_dataset(DATASET_NAME, DiceDataset, DiceConfig, DiceCurriculum)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ from ..factory import ProceduralDataset, register_dataset
|
|||
|
||||
QUESTION_TEMPLATE = "Simplify the fraction {question_fraction} to its lowest terms. Give only the simplified fraction as your final answer."
|
||||
|
||||
DATASET_NAME = "fraction_simplification"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FractionSimplificationConfig:
|
||||
|
|
@ -114,6 +116,8 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(question_fraction=question_fraction),
|
||||
"answer": answer_fraction,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"numerator": num,
|
||||
"denominator": den,
|
||||
"simplified_numerator": simple_num,
|
||||
|
|
@ -184,7 +188,7 @@ class FractionSimplificationCurriculum(BaseCurriculum):
|
|||
|
||||
|
||||
register_dataset(
|
||||
"fraction_simplification",
|
||||
DATASET_NAME,
|
||||
FractionSimplificationDataset,
|
||||
FractionSimplificationConfig,
|
||||
FractionSimplificationCurriculum,
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ from typing import Optional
|
|||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "gcd"
|
||||
|
||||
|
||||
@dataclass
|
||||
class GCDConfig:
|
||||
|
|
@ -62,6 +64,8 @@ class GCDDataset(ProceduralDataset):
|
|||
"question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}. Give only the GCD as your final answer.",
|
||||
"answer": str(result),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"numbers": numbers,
|
||||
"result": result,
|
||||
"num_terms": num_terms,
|
||||
|
|
@ -96,4 +100,4 @@ class GCDCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("gcd", GCDDataset, GCDConfig)
|
||||
register_dataset(DATASET_NAME, GCDDataset, GCDConfig)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ from typing import Any, Callable, Optional
|
|||
|
||||
from reasoning_gym.factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "gsm_symbolic"
|
||||
|
||||
tasks_ok = [
|
||||
0,
|
||||
1,
|
||||
|
|
@ -151,6 +153,8 @@ class GSMSymbolicDataset(ProceduralDataset):
|
|||
generator = self.generators[generator_idx]
|
||||
example = generator(rng, self.config.difficulty)
|
||||
example["question"] += " Give the result as your final answer. Do not include units."
|
||||
example["metadata"]["source_dataset"] = DATASET_NAME
|
||||
example["metadata"]["source_index"] = idx
|
||||
return example
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
|
|
@ -174,4 +178,4 @@ class GSMSymbolicDataset(ProceduralDataset):
|
|||
return reward
|
||||
|
||||
|
||||
register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig)
|
||||
register_dataset(DATASET_NAME, GSMSymbolicDataset, GSMSymbolicDatasetConfig)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ from typing import Optional
|
|||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "lcm"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LCMConfig:
|
||||
|
|
@ -64,6 +66,8 @@ class LCMDataset(ProceduralDataset):
|
|||
"question": f"Find the Least Common Multiple (LCM) of these numbers: {numbers_str}",
|
||||
"answer": str(result),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"numbers": numbers,
|
||||
"result": result,
|
||||
"difficulty": {
|
||||
|
|
@ -98,4 +102,4 @@ class LCMCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("lcm", LCMDataset, LCMConfig, LCMCurriculum)
|
||||
register_dataset(DATASET_NAME, LCMDataset, LCMConfig, LCMCurriculum)
|
||||
|
|
|
|||
|
|
@ -60,6 +60,8 @@ QUESTION_TEMPLATE = """Your task is to count how many legs there are in total wh
|
|||
Now, how many legs are there in total if you have {animals}?
|
||||
"""
|
||||
|
||||
DATASET_NAME = "leg_counting"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LegCountingConfig:
|
||||
|
|
@ -118,6 +120,8 @@ class LegCountingDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(animals=", ".join(animal_list)),
|
||||
"answer": str(total_legs),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"animals": animals,
|
||||
"num_animals": len(animals),
|
||||
"total_legs": total_legs,
|
||||
|
|
@ -152,4 +156,4 @@ class LegCountingCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("leg_counting", LegCountingDataset, LegCountingConfig, LegCountingCurriculum)
|
||||
register_dataset(DATASET_NAME, LegCountingDataset, LegCountingConfig, LegCountingCurriculum)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ Your output should be only the number of interest.
|
|||
Now, pick the {size} number of the following candidates: {numbers}
|
||||
"""
|
||||
|
||||
DATASET_NAME = "number_format"
|
||||
|
||||
|
||||
@dataclass
|
||||
class NumberFormatConfig:
|
||||
|
|
@ -94,6 +96,8 @@ class NumberFormatDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(numbers=" ".join(formatted_candidates), size=size),
|
||||
"answer": str(answer),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"candidates": candidates,
|
||||
"solution": answer,
|
||||
"formatted_candidates": formatted_candidates,
|
||||
|
|
@ -138,4 +142,4 @@ class NumberFormatCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("number_format", NumberFormatDataset, NumberFormatConfig, NumberFormatCurriculum)
|
||||
register_dataset(DATASET_NAME, NumberFormatDataset, NumberFormatConfig, NumberFormatCurriculum)
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ Compute {base}^{exponent}. Return your final answer correct to 3 significant fig
|
|||
Provide your answer in scientific notation using 'e' notation (e.g., 1.23e+4).
|
||||
"""
|
||||
|
||||
DATASET_NAME = "power_function"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PowerFunctionConfig:
|
||||
|
|
@ -74,6 +76,8 @@ class PowerFunctionDataset(ProceduralDataset):
|
|||
"question": QUESTION_TEMPLATE.format(base=base, exponent=exponent),
|
||||
"answer": str(answer),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"base": base,
|
||||
"exponent": exponent,
|
||||
"solution": answer,
|
||||
|
|
@ -97,4 +101,4 @@ class PowerFunctionCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset("power_function", PowerFunctionDataset, PowerFunctionConfig, PowerFunctionCurriculum)
|
||||
register_dataset(DATASET_NAME, PowerFunctionDataset, PowerFunctionConfig, PowerFunctionCurriculum)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ from typing import Any, Optional
|
|||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "prime_factorization"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PrimeFactorizationConfig:
|
||||
|
|
@ -84,6 +86,8 @@ class PrimeFactorizationDataset(ProceduralDataset):
|
|||
),
|
||||
"answer": answer,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"number": number,
|
||||
"factors": factors,
|
||||
"difficulty": {
|
||||
|
|
@ -110,6 +114,4 @@ class PrimeFactorizationCurriculum(BaseCurriculum):
|
|||
)
|
||||
|
||||
|
||||
register_dataset(
|
||||
"prime_factorization", PrimeFactorizationDataset, PrimeFactorizationConfig, PrimeFactorizationCurriculum
|
||||
)
|
||||
register_dataset(DATASET_NAME, PrimeFactorizationDataset, PrimeFactorizationConfig, PrimeFactorizationCurriculum)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ from reasoning_gym import utils
|
|||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "products"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProductsConfig:
|
||||
|
|
@ -66,6 +68,8 @@ class ProductsDataset(ProceduralDataset):
|
|||
"question": f"Solve the following multiplication: {expression}. Give only the result as your final answer.",
|
||||
"answer": str(result),
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"expression": expression,
|
||||
"num_terms": num_terms,
|
||||
"num_digits": num_digits,
|
||||
|
|
@ -135,4 +139,4 @@ class ProductsCurriculum(BaseCurriculum):
|
|||
|
||||
|
||||
# Register the dataset
|
||||
register_dataset("products", ProductsDataset, ProductsConfig, ProductsCurriculum)
|
||||
register_dataset(DATASET_NAME, ProductsDataset, ProductsConfig, ProductsCurriculum)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ from dateutil import parser
|
|||
from ..coaching import BaseCurriculum, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
DATASET_NAME = "time_intervals"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimeIntervalsConfig:
|
||||
|
|
@ -134,6 +136,8 @@ class TimeIntervalsDataset(ProceduralDataset):
|
|||
"question": question,
|
||||
"answer": answer,
|
||||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"task_type": task_type,
|
||||
"start_time": start_dt,
|
||||
"end_time": end_dt,
|
||||
|
|
@ -346,4 +350,4 @@ class TimeIntervalsCurriculum(BaseCurriculum):
|
|||
|
||||
|
||||
# Register the dataset
|
||||
register_dataset("time_intervals", TimeIntervalsDataset, TimeIntervalsConfig, TimeIntervalsCurriculum)
|
||||
register_dataset(DATASET_NAME, TimeIntervalsDataset, TimeIntervalsConfig, TimeIntervalsCurriculum)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue