gsm_symbolic generator changes

This commit is contained in:
Andreas Koepf 2025-02-04 13:47:57 +01:00
parent b84e29a8b6
commit afb95508ef
10 changed files with 9007 additions and 7360 deletions

View file

@ -0,0 +1,815 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# create open-router client, place your OPENROUTER_API_KEY in .env file\n",
"# .env contents:\n",
"# OPENROUTER_API_KEY=sk-or-v1- ...\n",
"\n",
"%load_ext dotenv\n",
"%dotenv\n",
"import os\n",
"import re\n",
"from random import Random\n",
"from pathlib import Path\n",
"from typing import Any, Iterable, Optional\n",
"import json\n",
"from openai import OpenAI\n",
"from openai.types.chat import ChatCompletion, ChatCompletionMessageParam\n",
"import time\n",
"import reasoning_gym\n",
"\n",
"\n",
"def llm_generate(\n",
" client: OpenAI,\n",
" messages: Iterable[ChatCompletionMessageParam],\n",
" sampling_params: dict[str, Any],\n",
") -> ChatCompletion:\n",
" max_retry = 3\n",
" for trial in range(max_retry):\n",
" try:\n",
" return client.chat.completions.create(\n",
" messages=messages,\n",
" **sampling_params,\n",
" )\n",
" except Exception as e:\n",
" print(\"failure response:\", e)\n",
" time.sleep(trial * trial) # quadratic backoff\n",
" if trial == max_retry - 1:\n",
" raise\n",
"\n",
"def generate_simple_request(user_prompt: str, developer_prompt: Optional[str] = None) -> list[dict]:\n",
" prompt = []\n",
" if developer_prompt is not None:\n",
" prompt.append( { \"role\": \"system\", \"content\": developer_prompt } )\n",
" \n",
" prompt.append( { \"role\": \"user\", \"content\": user_prompt })\n",
" return prompt\n",
"\n",
"open_router_client = OpenAI(\n",
" base_url=\"https://openrouter.ai/api/v1\",\n",
" api_key=os.getenv(\"OPENROUTER_API_KEY\"),\n",
" timeout=90.0,\n",
")\n",
"\n",
"sampling_params = {\n",
" \"model\": \"anthropic/claude-3.5-sonnet\",\n",
" \"max_tokens\": 4096,\n",
"}\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"48"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(reasoning_gym.factory.DATASETS)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# test all gsm_symbolic generators\n",
"import reasoning_gym.arithmetic.gsm_symbolic\n",
"x = reasoning_gym.create_dataset(\"gsm_symbolic\")\n",
"\n",
"generators = x.generators"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import reasoning_gym.utils\n",
"\n",
"difficulty = 1.0\n",
"\n",
"prompt_template = \"Solve the following math task and return the answer (just the number) in <answer></answer> tags:\\n\\n{question}\"\n",
"\n",
"def query_llm(x: dict) -> tuple[int, int]:\n",
" q = x[\"question\"]\n",
" ground_truth = x[\"answer\"]\n",
" user_prompt = prompt_template.format(question=q)\n",
" msgs = generate_simple_request(user_prompt)\n",
" output = llm_generate(client=open_router_client, messages=msgs, sampling_params=sampling_params)\n",
" full_answer = output.choices[0].message.content\n",
" answer = reasoning_gym.utils.extract_answer(completion=full_answer, tag_name=\"answer\").strip()\n",
" return answer, ground_truth, full_answer\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def cross_check_generator(rng: Random, index: int, difficulty: float = 1.0, num_generations = 3, verbose: bool = False) -> int:\n",
" num_matching = 0\n",
" try:\n",
" g = generators[index] \n",
" for j in range(num_generations):\n",
" try:\n",
" x = g(rng, difficulty=difficulty)\n",
" a, gt, full_answer = query_llm(x)\n",
"\n",
" print(f\"[{index}.{j}], llm={a}, ground_truth={gt}, match={a==gt}\")\n",
" if verbose:\n",
" print(x[\"question\"])\n",
" print(full_answer)\n",
" if a == gt:\n",
" num_matching += 1\n",
" except Exception as ex:\n",
" print(f\"[{index}.{j}] error: {ex}\")\n",
" except Exception as ex:\n",
" print(f\"[{index}] generator failure: {ex}\")\n",
" return -1\n",
" return num_matching\n",
"\n",
"def cross_check_generators(rng: Random, difficulty: float = 1.0, num_generations = 3):\n",
" results = [0] * len(generators)\n",
" for i in range(len(generators)):\n",
" results[i] = cross_check_generator(rng, index=i, difficulty=difficulty, num_generations=num_generations)\n",
" return results\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[35.0] error: Could not find valid time_per_room\n",
"[35.1], llm=30.196, ground_truth=30, match=False\n",
"A cleaner has to clean a office building with 21 floors. They have 3 days to get it done. It takes them 44 minutes per floor. If they work 17 hour day, what percentage of their day, on average, is spent cleaning floors?\n",
"Let me solve this step by step:\n",
"\n",
"1. Total floors to clean: 21\n",
"2. Time per floor: 44 minutes\n",
"3. Total time needed: 21 × 44 = 924 minutes\n",
"4. Working hours per day: 17 hours = 17 × 60 = 1020 minutes\n",
"5. Days available: 3\n",
"6. Time needed per day: 924 ÷ 3 = 308 minutes\n",
"7. Percentage of day spent cleaning: (308 ÷ 1020) × 100 = 30.196%\n",
"\n",
"<answer>30.196</answer>\n",
"[35.2], llm=34.23, ground_truth=34, match=False\n",
"A cleaner has to clean a hospital with 12 floors. They have 4 days to get it done. It takes them 89 minutes per floor. If they work 13 hour day, what percentage of their day, on average, is spent cleaning floors?\n",
"Let me solve this step by step:\n",
"\n",
"1. Total floors to clean = 12\n",
"2. Total time for all floors = 12 × 89 minutes = 1,068 minutes\n",
"3. Days available = 4\n",
"4. Time needed per day = 1,068 ÷ 4 = 267 minutes\n",
"5. Hours per day working = 13\n",
"6. Minutes in work day = 13 × 60 = 780 minutes\n",
"7. Percentage calculation = (267 ÷ 780) × 100 = 34.23%\n",
"\n",
"<answer>34.23</answer>\n",
"[35.3], llm=61, ground_truth=61, match=True\n",
"A cleaner has to clean a university with 20 floors. They have 10 days to get it done. It takes them 202 minutes per floor. If they work 11 hour day, what percentage of their day, on average, is spent cleaning floors?\n",
"Let me solve this step by step:\n",
"\n",
"1. First, let's calculate total time needed to clean all floors:\n",
" * 202 minutes × 20 floors = 4040 minutes total\n",
"\n",
"2. They have 10 days to do it, so per day:\n",
" * 4040 ÷ 10 = 404 minutes per day cleaning\n",
"\n",
"3. 11 hour workday in minutes:\n",
" * 11 × 60 = 660 minutes per day working\n",
"\n",
"4. Calculate percentage:\n",
" * (404 ÷ 660) × 100 = 61.21212121...%\n",
"\n",
"<answer>61</answer>\n",
"[35.4], llm=61.83, ground_truth=61, match=False\n",
"A cleaner has to clean a office building with 28 floors. They have 4 days to get it done. It takes them 53 minutes per floor. If they work 10 hour day, what percentage of their day, on average, is spent cleaning floors?\n",
"Let me solve this step by step:\n",
"\n",
"1. Time per floor = 53 minutes\n",
"2. Total floors = 28\n",
"3. Total minutes needed = 53 × 28 = 1,484 minutes\n",
"4. Days available = 4\n",
"5. Minutes needed per day = 1,484 ÷ 4 = 371 minutes\n",
"6. Hours per day working = 10\n",
"7. Minutes in work day = 10 × 60 = 600 minutes\n",
"8. Percentage = (371 ÷ 600) × 100 = 61.833...%\n",
"\n",
"<answer>61.83</answer>\n"
]
},
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rng = Random(200)\n",
"re_check = [2,11,21,27,32,35,37]\n",
"\n",
"# for i in re_check:\n",
"# cross_check_generator(rng, index=i, difficulty=1.0, num_generations=3)\n",
"# 11 not ok\n",
"\n",
"cross_check_generator(rng, index=35, difficulty=1.0, num_generations=5, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.0], llm=43, ground_truth=43, match=True\n",
"[0.1], llm=104, ground_truth=104, match=True\n",
"[0.2], llm=21, ground_truth=21, match=True\n",
"[0.3], llm=300, ground_truth=300, match=True\n",
"[0.4], llm=76, ground_truth=76, match=True\n",
"[1.0], llm=59, ground_truth=59, match=True\n",
"[1.1], llm=61, ground_truth=61, match=True\n",
"[1.2], llm=42, ground_truth=42, match=True\n",
"[1.3], llm=76, ground_truth=76, match=True\n",
"[1.4], llm=80, ground_truth=80, match=True\n",
"[2.0], llm=47, ground_truth=47, match=True\n",
"[2.1], llm=62, ground_truth=61, match=False\n",
"[2.2], llm=36, ground_truth=36, match=True\n",
"[2.3], llm=41, ground_truth=40, match=False\n",
"[2.4], llm=70, ground_truth=70, match=True\n",
"[3.0], llm=55, ground_truth=55, match=True\n",
"[3.1], llm=8, ground_truth=8, match=True\n",
"[3.2], llm=33, ground_truth=33, match=True\n",
"[3.3], llm=7, ground_truth=7, match=True\n",
"[3.4], llm=24, ground_truth=24, match=True\n",
"[4.0], llm=15, ground_truth=15, match=True\n",
"[4.1], llm=15, ground_truth=15, match=True\n",
"[4.2], llm=4, ground_truth=4, match=True\n",
"[4.3], llm=9, ground_truth=9, match=True\n",
"[4.4], llm=3, ground_truth=3, match=True\n",
"[5.0], llm=20, ground_truth=20, match=True\n",
"[5.1], llm=43, ground_truth=43, match=True\n",
"[5.2], llm=10, ground_truth=10, match=True\n",
"[5.3], llm=63, ground_truth=63, match=True\n",
"[5.4], llm=58, ground_truth=58, match=True\n",
"[6.0], llm=120, ground_truth=120, match=True\n",
"[6.1], llm=124, ground_truth=124, match=True\n",
"[6.2], llm=24, ground_truth=24, match=True\n",
"[6.3], llm=55, ground_truth=55, match=True\n",
"[6.4], llm=92, ground_truth=92, match=True\n",
"[7.0], llm=527, ground_truth=527, match=True\n",
"[7.1], llm=515, ground_truth=515, match=True\n",
"[7.2], llm=401, ground_truth=401, match=True\n",
"[7.3], llm=44, ground_truth=44, match=True\n",
"[7.4], llm=218, ground_truth=218, match=True\n",
"[8.0], llm=1014, ground_truth=1014, match=True\n",
"[8.1], llm=1010, ground_truth=1010, match=True\n",
"[8.2], llm=300, ground_truth=300, match=True\n",
"[8.3], llm=540, ground_truth=540, match=True\n",
"[8.4], llm=864, ground_truth=864, match=True\n",
"[9.0], llm=40, ground_truth=40, match=True\n",
"[9.1], llm=0, ground_truth=0, match=True\n",
"[9.2], llm=30, ground_truth=30, match=True\n",
"[9.3], llm=80, ground_truth=80, match=True\n",
"[9.4], llm=4, ground_truth=4, match=True\n",
"[10.0], llm=45, ground_truth=45, match=True\n",
"[10.1], llm=44, ground_truth=44, match=True\n",
"[10.2], llm=24, ground_truth=24, match=True\n",
"[10.3], llm=67, ground_truth=67, match=True\n",
"[10.4], llm=90, ground_truth=90, match=True\n",
"[11.0], llm=401, ground_truth=330, match=False\n",
"[11.1], llm=481, ground_truth=386, match=False\n",
"[11.2], llm=219, ground_truth=390, match=False\n",
"[11.3], llm=212, ground_truth=386, match=False\n",
"[11.4], llm=268, ground_truth=231, match=False\n",
"[12.0], llm=351, ground_truth=351, match=True\n",
"[12.1], llm=269, ground_truth=269, match=True\n",
"[12.2], llm=286, ground_truth=286, match=True\n",
"[12.3], llm=72, ground_truth=72, match=True\n",
"[12.4], llm=368, ground_truth=368, match=True\n",
"[13.0], llm=2, ground_truth=2, match=True\n",
"[13.1], llm=2, ground_truth=2, match=True\n",
"[13.2], llm=2, ground_truth=2, match=True\n",
"[13.3], llm=2, ground_truth=2, match=True\n",
"[13.4], llm=2, ground_truth=2, match=True\n",
"[14.0], llm=128, ground_truth=128, match=True\n",
"[14.1], llm=132, ground_truth=132, match=True\n",
"[14.2], llm=120, ground_truth=120, match=True\n",
"[14.3], llm=188, ground_truth=188, match=True\n",
"[14.4], llm=168, ground_truth=168, match=True\n",
"[15.0], llm=67, ground_truth=67, match=True\n",
"[15.1], llm=89, ground_truth=89, match=True\n",
"[15.2], llm=90, ground_truth=90, match=True\n",
"[15.3], llm=67, ground_truth=67, match=True\n",
"[15.4], llm=96, ground_truth=96, match=True\n",
"[16.0], llm=51, ground_truth=51, match=True\n",
"[16.1], llm=57, ground_truth=57, match=True\n",
"[16.2], llm=32, ground_truth=32, match=True\n",
"[16.3], llm=38, ground_truth=38, match=True\n",
"[16.4], llm=32, ground_truth=32, match=True\n",
"[17.0], llm=280, ground_truth=280, match=True\n",
"[17.1], llm=210, ground_truth=210, match=True\n",
"[17.2], llm=770, ground_truth=770, match=True\n",
"[17.3], llm=190, ground_truth=190, match=True\n",
"[17.4], llm=1060, ground_truth=1060, match=True\n",
"[18.0], llm=775, ground_truth=775, match=True\n",
"[18.1], llm=484, ground_truth=484, match=True\n",
"[18.2], llm=359, ground_truth=359, match=True\n",
"[18.3], llm=697, ground_truth=697, match=True\n",
"[18.4], llm=740, ground_truth=740, match=True\n",
"[19.0], llm=885, ground_truth=885, match=True\n",
"[19.1], llm=950, ground_truth=950, match=True\n",
"[19.2], llm=695, ground_truth=695, match=True\n",
"[19.3], llm=1530, ground_truth=1530, match=True\n",
"[19.4], llm=475, ground_truth=475, match=True\n",
"[20.0], llm=4, ground_truth=4, match=True\n",
"[20.1], llm=18, ground_truth=18, match=True\n",
"[20.2], llm=1, ground_truth=1, match=True\n",
"[20.3], llm=3, ground_truth=3, match=True\n",
"[20.4], llm=3, ground_truth=3, match=True\n",
"[21.0], llm=888, ground_truth=888, match=True\n",
"[21.1], llm=306, ground_truth=498, match=False\n",
"[21.2], llm=738, ground_truth=738, match=True\n",
"[21.3], llm=360, ground_truth=648, match=False\n",
"[21.4], llm=0, ground_truth=-80, match=False\n",
"[22.0], llm=8800, ground_truth=8800, match=True\n",
"[22.1], llm=46000, ground_truth=46000, match=True\n",
"[22.2], llm=51200, ground_truth=51200, match=True\n",
"[22.3], llm=69800, ground_truth=69800, match=True\n",
"[22.4], llm=67400, ground_truth=67400, match=True\n",
"[23.0], llm=478, ground_truth=478, match=True\n",
"[23.1], llm=389, ground_truth=389, match=True\n",
"[23.2], llm=206, ground_truth=206, match=True\n",
"[23.3], llm=99, ground_truth=99, match=True\n",
"[23.4], llm=389, ground_truth=389, match=True\n",
"[24.0], llm=29, ground_truth=29, match=True\n",
"[24.1], llm=20, ground_truth=20, match=True\n",
"[24.2], llm=3, ground_truth=3, match=True\n",
"[24.3], llm=41, ground_truth=41, match=True\n",
"[24.4], llm=1, ground_truth=1, match=True\n",
"[25.0], llm=48, ground_truth=48, match=True\n",
"[25.1], llm=22, ground_truth=22, match=True\n",
"[25.2], llm=10, ground_truth=10, match=True\n",
"[25.3], llm=15, ground_truth=15, match=True\n",
"[25.4], llm=14, ground_truth=14, match=True\n",
"[26.0], llm=1, ground_truth=1, match=True\n",
"[26.1], llm=2, ground_truth=2, match=True\n",
"[26.2], llm=9, ground_truth=9, match=True\n",
"[26.3], llm=2, ground_truth=2, match=True\n",
"[26.4], llm=8, ground_truth=8, match=True\n",
"[27.0], llm=9800, ground_truth=9800, match=True\n",
"[27.1], llm=3864, ground_truth=3864, match=True\n",
"[27.2], llm=8930.25, ground_truth=8930.25, match=True\n",
"[27.3], llm=2868.75, ground_truth=2868.75, match=True\n",
"[27.4], llm=787.50, ground_truth=787.5, match=False\n",
"[28.0], llm=200, ground_truth=200, match=True\n",
"[28.1], llm=324, ground_truth=324, match=True\n",
"[28.2], llm=214, ground_truth=214, match=True\n",
"[28.3], llm=568, ground_truth=568, match=True\n",
"[28.4], llm=295, ground_truth=295, match=True\n",
"[29.0], llm=20, ground_truth=20, match=True\n",
"[29.1], llm=25, ground_truth=25, match=True\n",
"[29.2], llm=20, ground_truth=20, match=True\n",
"[29.3], llm=20, ground_truth=20, match=True\n",
"[29.4], llm=25, ground_truth=25, match=True\n",
"[30.0], llm=50, ground_truth=50, match=True\n",
"[30.1], llm=20, ground_truth=20, match=True\n",
"[30.2], llm=40, ground_truth=40, match=True\n",
"[30.3], llm=58, ground_truth=58, match=True\n",
"[30.4], llm=89, ground_truth=89, match=True\n",
"[31.0], llm=26, ground_truth=26, match=True\n",
"[31.1], llm=34, ground_truth=34, match=True\n",
"[31.2], llm=26, ground_truth=26, match=True\n",
"[31.3], llm=24, ground_truth=24, match=True\n",
"[31.4], llm=26, ground_truth=26, match=True\n",
"[32.0], llm=70, ground_truth=70, match=True\n",
"[32.1], llm=33, ground_truth=33, match=True\n",
"[32.2], llm=45, ground_truth=45, match=True\n",
"[32.3], llm=42, ground_truth=42, match=True\n",
"[32.4], llm=90, ground_truth=120, match=False\n",
"[33.0], llm=2695, ground_truth=2695, match=True\n",
"[33.1], llm=1715, ground_truth=1715, match=True\n",
"[33.2], llm=2940, ground_truth=2940, match=True\n",
"[33.3], llm=1764, ground_truth=1764, match=True\n",
"[33.4], llm=1960, ground_truth=1960, match=True\n",
"[34.0], llm=89, ground_truth=89, match=True\n",
"[34.1], llm=13, ground_truth=13, match=True\n",
"[34.2], llm=63, ground_truth=63, match=True\n",
"[34.3], llm=116, ground_truth=116, match=True\n",
"[34.4], llm=52, ground_truth=52, match=True\n",
"[35.0], llm=30, ground_truth=30, match=True\n",
"[35.1], llm=27, ground_truth=26, match=False\n",
"[35.2], llm=70, ground_truth=70, match=True\n",
"[35.3], llm=60.9375, ground_truth=60, match=False\n",
"[35.4], llm=50.83, ground_truth=50, match=False\n",
"[36.0], llm=52, ground_truth=52, match=True\n",
"[36.1], llm=78, ground_truth=78, match=True\n",
"[36.2], llm=25, ground_truth=25, match=True\n",
"[36.3], llm=36, ground_truth=36, match=True\n",
"[36.4], llm=60, ground_truth=60, match=True\n",
"[37.0], llm=18630, ground_truth=18630, match=True\n",
"[37.1], llm=18451.2, ground_truth=17856, match=False\n",
"[37.2], llm=32640, ground_truth=32640, match=True\n",
"[37.3], llm=25344, ground_truth=25344, match=True\n",
"[37.4], llm=15642.6, ground_truth=15283, match=False\n",
"[38.0], llm=174, ground_truth=174, match=True\n",
"[38.1], llm=200, ground_truth=200, match=True\n",
"[38.2], llm=365, ground_truth=365, match=True\n",
"[38.3], llm=272, ground_truth=272, match=True\n",
"[38.4], llm=268, ground_truth=268, match=True\n",
"[39.0], llm=38, ground_truth=38, match=True\n",
"[39.1], llm=24, ground_truth=24, match=True\n",
"[39.2], llm=40, ground_truth=40, match=True\n",
"[39.3], llm=44, ground_truth=44, match=True\n",
"[39.4], llm=117, ground_truth=117, match=True\n",
"[40.0], llm=352, ground_truth=352, match=True\n",
"[40.1], llm=132, ground_truth=132, match=True\n",
"[40.2], llm=198, ground_truth=198, match=True\n",
"[40.3], llm=290, ground_truth=290, match=True\n",
"[40.4], llm=252, ground_truth=252, match=True\n",
"[41.0], llm=235, ground_truth=235, match=True\n",
"[41.1], llm=415, ground_truth=415, match=True\n",
"[41.2], llm=290, ground_truth=290, match=True\n",
"[41.3], llm=305, ground_truth=305, match=True\n",
"[41.4], llm=170, ground_truth=170, match=True\n",
"[42.0], llm=229400, ground_truth=229400, match=True\n",
"[42.1], llm=85300, ground_truth=85300, match=True\n",
"[42.2], llm=548800, ground_truth=548800, match=True\n",
"[42.3], llm=300700, ground_truth=300700, match=True\n",
"[42.4], llm=414400, ground_truth=414400, match=True\n",
"[43.0], llm=7, ground_truth=7, match=True\n",
"[43.1], llm=8, ground_truth=8, match=True\n",
"[43.2], llm=34, ground_truth=34, match=True\n",
"[43.3], llm=9, ground_truth=9, match=True\n",
"[43.4], llm=21, ground_truth=21, match=True\n",
"[44.0], llm=183, ground_truth=183, match=True\n",
"[44.1], llm=301, ground_truth=301, match=True\n",
"[44.2], llm=197, ground_truth=197, match=True\n",
"[44.3], llm=369, ground_truth=369, match=True\n",
"[44.4], llm=432, ground_truth=432, match=True\n",
"[45.0], llm=25, ground_truth=25, match=True\n",
"[45.1], llm=13, ground_truth=13, match=True\n",
"[45.2], llm=22, ground_truth=22, match=True\n",
"[45.3], llm=17, ground_truth=17, match=True\n",
"[45.4], llm=18, ground_truth=18, match=True\n",
"[46.0], llm=22, ground_truth=22, match=True\n",
"[46.1], llm=18, ground_truth=18, match=True\n",
"[46.2], llm=15, ground_truth=15, match=True\n",
"[46.3], llm=30, ground_truth=30, match=True\n",
"[46.4], llm=13, ground_truth=13, match=True\n",
"[47.0], llm=139, ground_truth=139, match=True\n",
"[47.1], llm=187, ground_truth=187, match=True\n",
"[47.2], llm=292, ground_truth=292, match=True\n",
"[47.3], llm=248, ground_truth=248, match=True\n",
"[47.4], llm=225, ground_truth=225, match=True\n",
"[48.0], llm=31, ground_truth=31, match=True\n",
"[48.1], llm=15, ground_truth=15, match=True\n",
"[48.2], llm=5, ground_truth=5, match=True\n",
"[48.3], llm=50, ground_truth=50, match=True\n",
"[48.4], llm=11, ground_truth=11, match=True\n",
"[49.0], llm=770, ground_truth=770, match=True\n",
"[49.1], llm=810, ground_truth=810, match=True\n",
"[49.2], llm=749, ground_truth=749, match=True\n",
"[49.3], llm=1799, ground_truth=1799, match=True\n",
"[49.4], llm=1150, ground_truth=1150, match=True\n",
"[50.0], llm=633, ground_truth=633, match=True\n",
"[50.1], llm=642, ground_truth=642, match=True\n",
"[50.2], llm=695, ground_truth=695, match=True\n",
"[50.3], llm=855, ground_truth=855, match=True\n",
"[50.4], llm=1135, ground_truth=1135, match=True\n",
"[51.0], llm=4, ground_truth=4, match=True\n",
"[51.1], llm=5, ground_truth=5, match=True\n",
"[51.2], llm=4, ground_truth=4, match=True\n",
"[51.3], llm=1, ground_truth=1, match=True\n",
"[51.4], llm=2, ground_truth=2, match=True\n",
"[52.0], llm=312, ground_truth=312, match=True\n",
"[52.1], llm=140, ground_truth=140, match=True\n",
"[52.2], llm=224, ground_truth=224, match=True\n",
"[52.3], llm=312, ground_truth=312, match=True\n",
"[52.4], llm=408, ground_truth=408, match=True\n",
"[53.0], llm=25, ground_truth=25, match=True\n",
"[53.1], llm=25, ground_truth=25, match=True\n",
"[53.2], llm=25, ground_truth=25, match=True\n",
"[53.3], llm=25, ground_truth=25, match=True\n",
"[53.4], llm=50, ground_truth=50, match=True\n",
"[54.0], llm=59, ground_truth=59, match=True\n",
"[54.1], llm=19, ground_truth=19, match=True\n",
"[54.2], llm=16, ground_truth=16, match=True\n",
"[54.3], llm=9, ground_truth=9, match=True\n",
"[54.4], llm=36, ground_truth=36, match=True\n",
"[55.0], llm=237, ground_truth=237, match=True\n",
"[55.1], llm=159, ground_truth=159, match=True\n",
"[55.2], llm=216, ground_truth=216, match=True\n",
"[55.3], llm=123, ground_truth=123, match=True\n",
"[55.4], llm=87, ground_truth=87, match=True\n",
"[56.0], llm=5570, ground_truth=5570, match=True\n",
"[56.1], llm=5005, ground_truth=5005, match=True\n",
"[56.2], llm=4608, ground_truth=4608, match=True\n",
"[56.3], llm=5895, ground_truth=5895, match=True\n",
"[56.4], llm=4864, ground_truth=4864, match=True\n",
"[57.0], llm=69, ground_truth=69, match=True\n",
"[57.1], llm=84, ground_truth=84, match=True\n",
"[57.2], llm=76, ground_truth=76, match=True\n",
"[57.3], llm=97, ground_truth=97, match=True\n",
"[57.4], llm=78, ground_truth=78, match=True\n",
"[58.0], llm=300, ground_truth=300, match=True\n",
"[58.1], llm=420, ground_truth=420, match=True\n",
"[58.2], llm=240, ground_truth=240, match=True\n",
"[58.3], llm=195, ground_truth=195, match=True\n",
"[58.4], llm=270, ground_truth=270, match=True\n",
"[59.0], llm=996, ground_truth=996, match=True\n",
"[59.1], llm=396, ground_truth=396, match=True\n",
"[59.2], llm=2784, ground_truth=2784, match=True\n",
"[59.3], llm=304, ground_truth=304, match=True\n",
"[59.4], llm=2375, ground_truth=2375, match=True\n",
"[60.0], llm=72, ground_truth=72, match=True\n",
"[60.1], llm=95, ground_truth=95, match=True\n",
"[60.2], llm=95, ground_truth=95, match=True\n",
"[60.3], llm=72, ground_truth=72, match=True\n",
"[60.4], llm=100, ground_truth=100, match=True\n",
"[61.0], llm=450.53, ground_truth=512, match=False\n",
"[61.1], llm=571.04, ground_truth=602, match=False\n",
"[61.2], llm=417.59, ground_truth=418, match=False\n",
"[61.3], llm=449.93, ground_truth=431, match=False\n",
"[61.4], llm=653.34, ground_truth=639, match=False\n",
"[62.0], llm=5, ground_truth=5, match=True\n",
"[62.1], llm=2, ground_truth=2, match=True\n",
"[62.2], llm=6, ground_truth=6, match=True\n",
"[62.3], llm=5, ground_truth=5, match=True\n",
"[62.4], llm=8, ground_truth=8, match=True\n",
"[63.0], llm=2272, ground_truth=2272, match=True\n",
"[63.1], llm=4212, ground_truth=3600, match=False\n",
"[63.2], llm=5372, ground_truth=4852, match=False\n",
"[63.3], llm=4570.90, ground_truth=4252, match=False\n",
"[63.4], llm=4584, ground_truth=4584, match=True\n",
"[64.0], llm=35, ground_truth=35, match=True\n",
"[64.1], llm=53, ground_truth=53, match=True\n",
"[64.2], llm=60, ground_truth=60, match=True\n",
"[64.3], llm=33, ground_truth=33, match=True\n",
"[64.4], llm=31, ground_truth=31, match=True\n",
"[65.0], llm=144, ground_truth=173, match=False\n",
"[65.1], llm=431.65, ground_truth=380, match=False\n",
"[65.2], llm=363, ground_truth=311, match=False\n",
"[65.3], llm=131, ground_truth=159, match=False\n",
"[65.4], llm=Cannot be determined - missing weights for 5 fish, ground_truth=242, match=False\n",
"[66.0], llm=3, ground_truth=3, match=True\n",
"[66.1], llm=1, ground_truth=1, match=True\n",
"[66.2], llm=6, ground_truth=6, match=True\n",
"[66.3], llm=7, ground_truth=7, match=True\n",
"[66.4], llm=3, ground_truth=3, match=True\n",
"[67.0], llm=1488, ground_truth=1488, match=True\n",
"[67.1], llm=299, ground_truth=299, match=True\n",
"[67.2], llm=436, ground_truth=436, match=True\n",
"[67.3], llm=718, ground_truth=718, match=True\n",
"[67.4], llm=1445, ground_truth=1445, match=True\n",
"[68.0], llm=270, ground_truth=270, match=True\n",
"[68.1], llm=1296, ground_truth=1296, match=True\n",
"[68.2], llm=3456, ground_truth=3456, match=True\n",
"[68.3], llm=1512, ground_truth=1512, match=True\n",
"[68.4], llm=162, ground_truth=162, match=True\n",
"[69.0], llm=165, ground_truth=165, match=True\n",
"[69.1], llm=174, ground_truth=174, match=True\n",
"[69.2], llm=42, ground_truth=42, match=True\n",
"[69.3], llm=41, ground_truth=41, match=True\n",
"[69.4], llm=87, ground_truth=87, match=True\n",
"[70.0], llm=56, ground_truth=56, match=True\n",
"[70.1], llm=6, ground_truth=6, match=True\n",
"[70.2], llm=21, ground_truth=21, match=True\n",
"[70.3], llm=34, ground_truth=34, match=True\n",
"[70.4], llm=25, ground_truth=25, match=True\n",
"[71.0], llm=1275, ground_truth=1275, match=True\n",
"[71.1], llm=1151, ground_truth=1151, match=True\n",
"[71.2], llm=1382, ground_truth=1382, match=True\n",
"[71.3], llm=1271, ground_truth=1271, match=True\n",
"[71.4], llm=1047, ground_truth=1047, match=True\n",
"[72.0], llm=19, ground_truth=19, match=True\n",
"[72.1], llm=1, ground_truth=1, match=True\n",
"[72.2], llm=19, ground_truth=19, match=True\n",
"[72.3], llm=8, ground_truth=8, match=True\n",
"[72.4], llm=1, ground_truth=1, match=True\n",
"[73.0], llm=1630, ground_truth=1630, match=True\n",
"[73.1], llm=1664, ground_truth=1664, match=True\n",
"[73.2], llm=2050, ground_truth=2050, match=True\n",
"[73.3], llm=1460, ground_truth=1460, match=True\n",
"[73.4], llm=1821, ground_truth=1821, match=True\n",
"[74.0], llm=50, ground_truth=50, match=True\n",
"[74.1], llm=20, ground_truth=20, match=True\n",
"[74.2], llm=2.22, ground_truth=2, match=False\n",
"[74.3], llm=100, ground_truth=100, match=True\n",
"[74.4], llm=20, ground_truth=20, match=True\n",
"[75.0], llm=24, ground_truth=24, match=True\n",
"[75.1], llm=14, ground_truth=14, match=True\n",
"[75.2], llm=32, ground_truth=32, match=True\n",
"[75.3], llm=48, ground_truth=48, match=True\n",
"[75.4], llm=17, ground_truth=17, match=True\n",
"[76.0], llm=42, ground_truth=42, match=True\n",
"[76.1], llm=6, ground_truth=6, match=True\n",
"[76.2], llm=18, ground_truth=18, match=True\n",
"[76.3], llm=27, ground_truth=27, match=True\n",
"[76.4], llm=7, ground_truth=6, match=False\n",
"[77.0], llm=3, ground_truth=10, match=False\n",
"[77.1], llm=5, ground_truth=6, match=False\n",
"[77.2], llm=20, ground_truth=22, match=False\n",
"[77.3], llm=3, ground_truth=11, match=False\n",
"[77.4], llm=31, ground_truth=54, match=False\n",
"[78.0], llm=75, ground_truth=75, match=True\n",
"[78.1], llm=74, ground_truth=74, match=True\n",
"[78.2], llm=43, ground_truth=43, match=True\n",
"[78.3], llm=48, ground_truth=48, match=True\n",
"[78.4], llm=41, ground_truth=41, match=True\n",
"[79.0], llm=168, ground_truth=120, match=False\n",
"[79.1], llm=315, ground_truth=315, match=True\n",
"[79.2], llm=54, ground_truth=54, match=True\n",
"[79.3], llm=75, ground_truth=75, match=True\n",
"[79.4], llm=102, ground_truth=102, match=True\n",
"[80.0], llm=40, ground_truth=40, match=True\n",
"[80.1], llm=56, ground_truth=56, match=True\n",
"[80.2], llm=39, ground_truth=39, match=True\n",
"[80.3], llm=40, ground_truth=40, match=True\n",
"[80.4], llm=50, ground_truth=50, match=True\n",
"[81.0], llm=29160, ground_truth=29160, match=True\n",
"[81.1], llm=31200, ground_truth=31200, match=True\n",
"[81.2], llm=28800, ground_truth=28800, match=True\n",
"[81.3], llm=7200, ground_truth=7200, match=True\n",
"[81.4], llm=32760, ground_truth=32760, match=True\n",
"[82.0], llm=240, ground_truth=240, match=True\n",
"[82.1], llm=288, ground_truth=288, match=True\n",
"[82.2], llm=672, ground_truth=672, match=True\n",
"[82.3], llm=540, ground_truth=540, match=True\n",
"[82.4], llm=588, ground_truth=588, match=True\n",
"[83.0], llm=53, ground_truth=53, match=True\n",
"[83.1], llm=91, ground_truth=91, match=True\n",
"[83.2], llm=88, ground_truth=88, match=True\n",
"[83.3], llm=78, ground_truth=78, match=True\n",
"[83.4], llm=18, ground_truth=18, match=True\n",
"[84.0], llm=145, ground_truth=145, match=True\n",
"[84.1], llm=192, ground_truth=192, match=True\n",
"[84.2], llm=78, ground_truth=78, match=True\n",
"[84.3], llm=54, ground_truth=54, match=True\n",
"[84.4], llm=76, ground_truth=76, match=True\n",
"[85.0], llm=152, ground_truth=152, match=True\n",
"[85.1], llm=178, ground_truth=178, match=True\n",
"[85.2], llm=44, ground_truth=44, match=True\n",
"[85.3], llm=306, ground_truth=306, match=True\n",
"[85.4], llm=130, ground_truth=130, match=True\n",
"[86.0], llm=3, ground_truth=2, match=False\n",
"[86.1], llm=4, ground_truth=3, match=False\n",
"[86.2], llm=5, ground_truth=4, match=False\n",
"[86.3], llm=4, ground_truth=3, match=False\n",
"[86.4], llm=4, ground_truth=3, match=False\n",
"[87.0], llm=8.5, ground_truth=8, match=False\n",
"[87.1], llm=4, ground_truth=4, match=True\n",
"[87.2], llm=2, ground_truth=2, match=True\n",
"[87.3], llm=4.5, ground_truth=4, match=False\n",
"[87.4], llm=6, ground_truth=6, match=True\n",
"[88.0], llm=2, ground_truth=2, match=True\n",
"[88.1], llm=5, ground_truth=5, match=True\n",
"[88.2], llm=5, ground_truth=5, match=True\n",
"[88.3], llm=7, ground_truth=7, match=True\n",
"[88.4], llm=5, ground_truth=5, match=True\n",
"[89.0], llm=9, ground_truth=9, match=True\n",
"[89.1], llm=63, ground_truth=63, match=True\n",
"[89.2], llm=66, ground_truth=66, match=True\n",
"[89.3], llm=27, ground_truth=27, match=True\n",
"[89.4], llm=42, ground_truth=42, match=True\n",
"[90.0], llm=11.76, ground_truth=11, match=False\n",
"[90.1], llm=10.95, ground_truth=10, match=False\n",
"[90.2], llm=15.28, ground_truth=15, match=False\n",
"[90.3], llm=7.81, ground_truth=7, match=False\n",
"[90.4], llm=11.20, ground_truth=11, match=False\n",
"[91.0], llm=14400, ground_truth=14400, match=True\n",
"[91.1], llm=5040, ground_truth=5040, match=True\n",
"[91.2], llm=3520, ground_truth=3520, match=True\n",
"[91.3], llm=6300, ground_truth=6300, match=True\n",
"[91.4], llm=33630, ground_truth=33630, match=True\n",
"[92.0], llm=406, ground_truth=406, match=True\n",
"[92.1], llm=308, ground_truth=308, match=True\n",
"[92.2], llm=325, ground_truth=325, match=True\n",
"[92.3], llm=278, ground_truth=278, match=True\n",
"[92.4], llm=315, ground_truth=315, match=True\n",
"[93.0], llm=225, ground_truth=225, match=True\n",
"[93.1], llm=25, ground_truth=25, match=True\n",
"[93.2], llm=150, ground_truth=150, match=True\n",
"[93.3], llm=50, ground_truth=50, match=True\n",
"[93.4], llm=150, ground_truth=150, match=True\n",
"[94.0], llm=1406, ground_truth=1406, match=True\n",
"[94.1], llm=504, ground_truth=504, match=True\n",
"[94.2], llm=1320, ground_truth=1320, match=True\n",
"[94.3], llm=1656, ground_truth=1656, match=True\n",
"[94.4], llm=108, ground_truth=108, match=True\n",
"[95.0], llm=360, ground_truth=360, match=True\n",
"[95.1], llm=510, ground_truth=510, match=True\n",
"[95.2], llm=112, ground_truth=112, match=True\n",
"[95.3], llm=91, ground_truth=91, match=True\n",
"[95.4], llm=450, ground_truth=450, match=True\n",
"[96.0], llm=808, ground_truth=808, match=True\n",
"[96.1], llm=352, ground_truth=352, match=True\n",
"[96.2], llm=1062, ground_truth=1062, match=True\n",
"[96.3], llm=1203, ground_truth=1203, match=True\n",
"[96.4], llm=347, ground_truth=347, match=True\n",
"[97.0], llm=11.136, ground_truth=11, match=False\n",
"[97.1], llm=22.272, ground_truth=22, match=False\n",
"[97.2], llm=16.704, ground_truth=16, match=False\n",
"[97.3], llm=26.57, ground_truth=26, match=False\n",
"[97.4], llm=71.64, ground_truth=72, match=False\n",
"[98.0], llm=82, ground_truth=82, match=True\n",
"[98.1], llm=70, ground_truth=70, match=True\n",
"[98.2], llm=83.25, ground_truth=83, match=False\n",
"[98.3], llm=88.25, ground_truth=88, match=False\n",
"[98.4], llm=70.5, ground_truth=70, match=False\n",
"[99.0], llm=30.00, ground_truth=30, match=False\n",
"[99.1], llm=51.00, ground_truth=51, match=False\n",
"[99.2], llm=59.00, ground_truth=59, match=False\n",
"[99.3], llm=2.00, ground_truth=2, match=False\n",
"[99.4], llm=44.00, ground_truth=44, match=False\n"
]
}
],
"source": [
"rng = Random(55)\n",
"result_1 = cross_check_generators(rng, difficulty=1.0, num_generations=5)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"good = [0,1,3,4,5,6,7,8,9,10,12,13,14,15,16,17,18,19,20,22,23,24,25,26,28,29,30,31,33,34,36,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,62,64,66,67,68,69,70,71,72,73,75,78,80,81,82,83,84,85,88,89,91,92,93,94,95,96]\n",
"not_good = [2,11,21,27,32,35,37,61,63,65,74,76,77,79,86,87,90,97,98,99]\n"
]
}
],
"source": [
"good = [str(i) for i in range(len(result_1)) if result_1[i] == 5]\n",
"not_good = [str(i) for i in range(len(result_1)) if result_1[i] < 5]\n",
"\n",
"print('good = [' + \",\".join(good) + ']')\n",
"print('not_good = [' + \",\".join(not_good) + ']')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "reasoning-gym",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View file

@ -12,8 +12,7 @@ from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDat
from .chain_sum import ChainSum, ChainSumConfig
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
from .gcd import GCDConfig, GCDDataset
# from .gsm_symbolic.gsm_symbolic_datasets import GSMSymbolicDataset, GSMSymbolicDatasetConfig
from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
from .lcm import LCMConfig, LCMDataset
from .leg_counting import LegCountingConfig, LegCountingDataset
from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
@ -39,8 +38,8 @@ __all__ = [
"LegCountingDataset",
"PrimeFactorizationConfig",
"PrimeFactorizationDataset",
# "GSMSymbolicDatasetConfig",
# "GSMSymbolicDataset",
"GSMSymbolicDatasetConfig",
"GSMSymbolicDataset",
"TimeIntervalsConfig",
"TimeIntervalsDataset",
]

View file

@ -0,0 +1,6 @@
# Re-export the GSM-Symbolic dataset classes at the subpackage level.
from .gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig

# Public API of this subpackage.
__all__ = [
    "GSMSymbolicDatasetConfig",
    "GSMSymbolicDataset",
]

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,154 @@
"""GSM Symblic dataset generator"""
from dataclasses import dataclass
from random import Random
from typing import Any, Callable, Optional
from reasoning_gym.factory import ProceduralDataset, register_dataset
# Task ids flagged for repair (names suggest their generators still need
# fixes); not referenced by GSMSymbolicDataset, kept for bookkeeping.
tasks_need_fix = [32, 35, 37, 61, 63, 65, 74, 76, 77, 79, 86, 87, 90, 97, 98]

# Task ids whose generators are sampled by GSMSymbolicDataset. This is exactly
# the complement of `tasks_need_fix` within 0..99.
tasks_ok = [task_id for task_id in range(100) if task_id not in tasks_need_fix]
@dataclass
class GSMSymbolicDatasetConfig:
    """Settings controlling GSM-Symbolic task generation."""

    difficulty: float = 1.0  # scaling knob passed to every generator
    seed: Optional[int] = None  # base RNG seed; None leaves seeding to the dataset base class
    size: int = 500  # number of items the dataset exposes

    def validate(self) -> None:
        """Check the configuration, raising AssertionError on invalid values."""
        assert self.size > 0, "size must be positive"
        # Equivalent to 1.0 <= difficulty <= 1.0: only 1.0 is supported so far.
        assert self.difficulty == 1.0
class GSMSymbolicDataset(ProceduralDataset):
    """Procedurally sampled GSM-Symbolic word problems.

    Each item is produced by one of the per-task generator functions in the
    sibling ``generators_00_49`` / ``generators_50_99`` modules; only the
    task ids listed in ``tasks_ok`` are ever sampled.
    """

    def __init__(self, config: GSMSymbolicDatasetConfig):
        super().__init__(config, config.seed, config.size)
        # Lazily populated task-id -> generator mapping (see `generators`).
        # Annotated Optional because it is None until first access.
        self._generators: Optional[dict[int, Callable[[Random, float], dict[str, Any]]]] = None
        # Pre-draw which task generator serves each dataset index so item
        # access is O(1) and fully determined by the seed.
        # NOTE(review): assumes ProceduralDataset gives self.seed a non-None
        # value (Random(None + ...) would raise) — verify in the base class.
        self.task_indices = Random(self.seed).choices(tasks_ok, k=self.size)

    @property
    def generators(self) -> dict[int, Callable[[Random, float], dict[str, Any]]]:
        """Lazy-load the generator table on first access."""
        if self._generators is None:
            self._generators = self._load_generators()
        return self._generators

    def _load_generators(self) -> dict[int, Callable[[Random, float], dict[str, Any]]]:
        """Build the mapping from numeric task id to generator function.

        The generator modules are imported lazily because they are large.
        """
        from . import generators_00_49, generators_50_99

        prefix = "generate_"
        gs: dict[int, Callable[[Random, float], dict[str, Any]]] = {}
        for module in (generators_00_49, generators_50_99):
            for name in dir(module):
                if name.startswith(prefix):
                    # Key on the numeric suffix, e.g. "generate_07" -> 7.
                    gs[int(name[len(prefix):])] = getattr(module, name)
        return gs

    def __getitem__(self, idx: int) -> dict:
        """Generate the item at `idx` deterministically from seed + idx."""
        rng = Random(self.seed + idx)
        generator = self.generators[self.task_indices[idx]]
        return generator(rng, self.config.difficulty)


register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig)

View file

@ -1,59 +0,0 @@
"""GSM Symblic dataset generator"""
from dataclasses import dataclass
from random import Random
from typing import List, Optional
from reasoning_gym.factory import ProceduralDataset, register_dataset
from . import generators
@dataclass
class GSMSymbolicDatasetConfig:
    """Configuration for GSM symbolic task generation"""

    seed: Optional[int] = None  # base RNG seed; None means unseeded
    size: int = 500  # number of items the dataset exposes

    def validate(self) -> None:
        """Validate configuration parameters"""
        # No constraints are enforced in this version.
        pass
class GSMSymbolicDataset(ProceduralDataset):
    """Procedural dataset dispatching to `generate_*` example generators."""

    def __init__(self, config, seed=None, size=500):
        super().__init__(config, seed, size)
        # Initialize as None so the generator table is built lazily on first access.
        self._generators = None

    @property
    def generators(self):
        """Lazy load generators only when first accessed"""
        if self._generators is None:
            self._generators = self.get_generators()
        return self._generators

    def get_generators(self):
        """
        Generates mapper from task identifiers (keys) to example generator functions
        """
        # Keys are the numeric suffixes of the `generate_*` names kept as
        # strings (e.g. "generate_7" -> "7").
        prefix = "generate_"
        return {self.strip_prefix(n, prefix): getattr(generators, n) for n in dir(generators) if n.startswith(prefix)}

    def strip_prefix(self, s, prefix):
        # Drop the leading `prefix` from `s` (assumed present by the caller).
        return s[len(prefix) :]

    def __getitem__(self, idx) -> dict:
        """Generate a single GSM symbolic dataset"""
        rng = Random(self.seed + idx)
        # Stringify the random integer generated from the random number generator
        # NOTE(review): assumes generator names form a gap-free sequence
        # generate_0..generate_{N-1}; otherwise this string key can miss —
        # verify against the `generators` module.
        generator_idx = str(rng.randint(0, len(self.generators) - 1))
        generator = self.generators[generator_idx]
        # `res` is the generated example as a dict
        res = generator(rng)
        return res


register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig)

View file

@ -1,5 +1,8 @@
import math
import re
from typing import Optional
from decimal import Decimal, InvalidOperation
from fractions import Fraction
from typing import Any, Optional, Union
# DeepSeek Zero system prompt
SYSTEM_PROMPTS = {
@ -22,3 +25,52 @@ def extract_answer(completion: str, tag_name: str = "answer") -> Optional[str]:
if not matches:
return None
return matches[-1].group(1)
def format_number(num: Union[int, float], max_decimals: int = 2) -> str:
    """Convert a number to string representation with controlled decimal places.

    Args:
        num: Number to format
        max_decimals: Maximum allowed decimal places

    Returns:
        String representation of the number

    Raises:
        ValueError: If the number requires more decimal places than allowed,
            or the formatted string fails to round-trip to the original value.
    """
    # Integers and integral floats render without a decimal point.
    if isinstance(num, int) or num.is_integer():
        return str(int(num))

    # Convert to Decimal for exact decimal arithmetic when counting the
    # number of places actually required.
    d = Decimal(str(num))
    str_val = f"{d:f}".rstrip("0").rstrip(".")
    if "." in str_val:
        required_decimals = len(str_val.split(".")[1])
        if required_decimals > max_decimals:
            raise ValueError(f"Number {num} requires {required_decimals} decimals but only {max_decimals} allowed")

    # Format with at most max_decimals, dropping trailing zeros.
    result = f"{num:.{max_decimals}f}".rstrip("0").rstrip(".")

    # Verify the result parses back to the original value. This check is
    # deliberately NOT inside a try/except: the previous version raised the
    # mismatch ValueError inside its own `except (ValueError, ...)` clause,
    # which caught and re-wrapped it. float(result) itself cannot fail here
    # because `result` is the output of our own f-string formatting.
    if not math.isclose(float(result), num, rel_tol=1e-9):
        raise ValueError(f"String representation {result} does not match original value {num}")
    return result
def is_integer(obj: Any) -> bool:
    """Return True if `obj` is a whole number: an int, an integral float,
    or a Fraction whose denominator is 1. Anything else is False."""
    if isinstance(obj, Fraction):
        return obj.denominator == 1
    if isinstance(obj, int):
        return True
    if isinstance(obj, float):
        return obj.is_integer()
    return False

View file

@ -0,0 +1,92 @@
from random import Random
import pytest
from reasoning_gym.arithmetic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
def test_gsm_symbolic_config_validation():
    """A non-positive size must be rejected with an AssertionError."""
    invalid = GSMSymbolicDatasetConfig(size=-1)
    with pytest.raises(AssertionError):
        invalid.validate()
def test_gsm_symbolic_deterministic():
    """Two datasets built from the same seeded config yield identical items."""
    cfg = GSMSymbolicDatasetConfig(seed=42, size=10)
    first = GSMSymbolicDataset(cfg)
    second = GSMSymbolicDataset(cfg)
    for idx in range(len(first)):
        assert first[idx] == second[idx]
def test_gsm_symbolic_items():
    """Every item is a dict carrying string `question` and `answer` fields."""
    dataset = GSMSymbolicDataset(GSMSymbolicDatasetConfig(size=100, seed=42))
    for idx in range(len(dataset)):
        item = dataset[idx]
        assert isinstance(item, dict)
        for field in ("question", "answer"):
            assert field in item
            assert isinstance(item[field], str)
def test_gsm_symbolic_iteration():
    """Iteration yields exactly `size` items and is repeatable."""
    config = GSMSymbolicDatasetConfig(size=5, seed=42)  # Small size for testing
    dataset = GSMSymbolicDataset(config)

    # Manual iteration.
    collected = [item for item in dataset]
    assert len(collected) == config.size, "Iterator should yield exactly size items"

    # list() conversion.
    assert len(list(dataset)) == config.size, "Iterator should yield exactly size items"

    # Repeated passes must produce the same sequence.
    assert list(dataset) == list(dataset), "Multiple iterations should yield same items"
def test_gsm_symbolic_generators():
"""Test generator loading and access"""
config = GSMSymbolicDatasetConfig()
dataset = GSMSymbolicDataset(config)
# Test lazy loading
assert dataset._generators is None
_ = dataset.generators # Access to trigger loading
assert dataset._generators is not None
# Test generator mapping
assert isinstance(dataset.generators, dict)
assert len(dataset.generators) > 0
i = 0
rng = Random(18)
for key in sorted(dataset.generators.keys()):
generator = dataset.generators[key]
assert callable(generator)
print(i, key)
answer_set = set()
question_set = set()
for j in range(10):
x = generator(rng, difficulty=1.0)
question_set.add(x["question"])
answer_set.add(x["answer"])
# if j == 123:
# print(f"[{j}] q: {x['question']}")
# print(f"a: {x['answer']}")
# print()
print(f"ok: q={len(question_set)}, a={len(answer_set)}")
i += 1