diff --git a/notebooks/codeio.ipynb b/notebooks/codeio.ipynb
new file mode 100644
index 00000000..10ef8ba4
--- /dev/null
+++ b/notebooks/codeio.ipynb
@@ -0,0 +1,719 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 1,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import abc\n",
+ "import asyncio\n",
+ "from collections import defaultdict\n",
+ "import json\n",
+ "import os\n",
+ "import re\n",
+ "from typing import Union\n",
+ "\n",
+ "import aiohttp\n",
+ "import datasets\n",
+ "from dotenv import load_dotenv\n",
+ "import numpy as np\n",
+ "from sentence_transformers import SentenceTransformer\n",
+ "from tenacity import (\n",
+ " AsyncRetrying,\n",
+ " retry_if_exception_type,\n",
+ " stop_after_attempt,\n",
+ " wait_exponential,\n",
+ ")\n",
+ "import torch\n",
+ "from tqdm.notebook import tqdm\n",
+ "from e2b_code_interpreter import Sandbox\n",
+ "from e2b import TimeoutException\n",
+ "\n",
+ "load_dotenv()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = datasets.load_dataset(\"hkust-nlp/CodeIO-PyEdu-Reasoning\")['train']"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Extract the relevant parts of the prompt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pattern = re.compile(\n",
+ " r'(?s)' # DOTALL so . matches newlines\n",
+ " r'You are given a question that requires some input and output variables as follows:\\s*(.*?)'\n",
+ " r'\\s*The input and output requirements are as follows:\\s*(.*?)'\n",
+ " r'\\s*Given the following.*?Tip: Here is a reference code snippet for this question\\. '\n",
+ " r'You can refer to this code to guide your reasoning but not copy spans of code directly\\.\\s*(.*)'\n",
+ ")\n",
+ "\n",
+ "seen = set()\n",
+ "duplicate = 0\n",
+ "\n",
+ "with open(\"data/codeio-pyedu-extracted.jsonl\", \"w+\") as f:\n",
+ " for i, item in tqdm(enumerate(dataset), total=len(dataset)):\n",
+ " match = pattern.search(item[\"prompt\"])\n",
+ " if match:\n",
+ " # Extract relevant info\n",
+ " task_description = match.group(1).strip()\n",
+ " input_output_spec = match.group(2).strip()\n",
+ " code_sample = match.group(3).strip()\n",
+ "\n",
+ " # Check if code sample is unique\n",
+ " hash_entry = (task_description, input_output_spec, code_sample)\n",
+ " if hash_entry in seen:\n",
+ " duplicate += 1\n",
+ " continue\n",
+ " seen.add(hash_entry)\n",
+ "\n",
+ " # Save to disk\n",
+ " json.dump({\n",
+ " \"task_description\": task_description,\n",
+ " \"input_output_spec\": input_output_spec,\n",
+ " \"code_sample\": code_sample\n",
+ " }, f)\n",
+ " f.write(\"\\n\")\n",
+ " else:\n",
+ " print(f\"No match found for item {i}\")\n",
+ "\n",
+ "print(f\"There were {duplicate} out of {len(dataset)} duplicate entries\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Subsample the data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class IdentitySampler:\n",
+ " def run(\n",
+ " self, features: Union[torch.Tensor, np.ndarray]\n",
+ " ) -> Union[torch.Tensor, np.ndarray]:\n",
+ " return features\n",
+ "\n",
+ "\n",
+ "class BaseSampler(abc.ABC):\n",
+ " def __init__(self, percentage: float):\n",
+ " if not 0 < percentage < 1:\n",
+ " raise ValueError(\"Percentage value not in (0, 1).\")\n",
+ " self.percentage = percentage\n",
+ "\n",
+ " @abc.abstractmethod\n",
+ " def run(\n",
+ " self, features: Union[torch.Tensor, np.ndarray]\n",
+ " ) -> Union[torch.Tensor, np.ndarray]:\n",
+ " pass\n",
+ "\n",
+ " def _store_type(self, features: Union[torch.Tensor, np.ndarray]) -> None:\n",
+ " self.features_is_numpy = isinstance(features, np.ndarray)\n",
+ " if not self.features_is_numpy:\n",
+ " self.features_device = features.device\n",
+ "\n",
+ " def _restore_type(self, features: torch.Tensor) -> Union[torch.Tensor, np.ndarray]:\n",
+ " if self.features_is_numpy:\n",
+ " return features.cpu().numpy()\n",
+ " return features.to(self.features_device)\n",
+ "\n",
+ "\n",
+ "class GreedyCoresetSampler(BaseSampler):\n",
+ " def __init__(\n",
+ " self,\n",
+ " percentage: float,\n",
+ " device: torch.device,\n",
+ " dtype: torch.dtype = torch.float32,\n",
+ " dimension_to_project_features_to=128,\n",
+ " ):\n",
+ " \"\"\"Greedy Coreset sampling base class.\"\"\"\n",
+ " super().__init__(percentage)\n",
+ "\n",
+ " self.device = device\n",
+ " self.dtype = dtype\n",
+ " self.dimension_to_project_features_to = dimension_to_project_features_to\n",
+ "\n",
+ " def _reduce_features(self, features):\n",
+ " if features.shape[1] == self.dimension_to_project_features_to:\n",
+ " return features\n",
+ " mapper = torch.nn.Linear(\n",
+ " features.shape[1], self.dimension_to_project_features_to, bias=False, dtype=self.dtype,\n",
+ " )\n",
+ " _ = mapper.to(self.device)\n",
+ " features = features.to(self.device)\n",
+ " return mapper(features)\n",
+ "\n",
+ " def run(\n",
+ " self, features: Union[torch.Tensor, np.ndarray]\n",
+ " ) -> Union[torch.Tensor, np.ndarray]:\n",
+ " \"\"\"Subsamples features using Greedy Coreset.\n",
+ "\n",
+ " Args:\n",
+ " features: [N x D]\n",
+ " \"\"\"\n",
+ " if self.percentage == 1:\n",
+ " return features\n",
+ " self._store_type(features)\n",
+ " if isinstance(features, np.ndarray):\n",
+ " features = torch.from_numpy(features)\n",
+ " reduced_features = self._reduce_features(features)\n",
+ " sample_indices = self._compute_greedy_coreset_indices(reduced_features)\n",
+ " return sample_indices\n",
+ "\n",
+ " @staticmethod\n",
+ " def _compute_batchwise_differences(\n",
+ " matrix_a: torch.Tensor, matrix_b: torch.Tensor\n",
+ " ) -> torch.Tensor:\n",
+ " \"\"\"Computes batchwise Euclidean distances using PyTorch.\"\"\"\n",
+ " a_times_a = matrix_a.unsqueeze(1).bmm(matrix_a.unsqueeze(2)).reshape(-1, 1)\n",
+ " b_times_b = matrix_b.unsqueeze(1).bmm(matrix_b.unsqueeze(2)).reshape(1, -1)\n",
+ " a_times_b = matrix_a.mm(matrix_b.T)\n",
+ "\n",
+ " return (-2 * a_times_b + a_times_a + b_times_b).clamp(0, None).sqrt()\n",
+ "\n",
+ " def _compute_greedy_coreset_indices(self, features: torch.Tensor) -> np.ndarray:\n",
+ " \"\"\"Runs iterative greedy coreset selection.\n",
+ "\n",
+ " Args:\n",
+ " features: [NxD] input feature bank to sample.\n",
+ " \"\"\"\n",
+ " distance_matrix = self._compute_batchwise_differences(features, features)\n",
+ " coreset_anchor_distances = torch.norm(distance_matrix, dim=1)\n",
+ "\n",
+ " coreset_indices = []\n",
+ " num_coreset_samples = int(len(features) * self.percentage)\n",
+ "\n",
+ " for _ in range(num_coreset_samples):\n",
+ " select_idx = torch.argmax(coreset_anchor_distances).item()\n",
+ " coreset_indices.append(select_idx)\n",
+ "\n",
+ " coreset_select_distance = distance_matrix[\n",
+ " :, select_idx : select_idx + 1 # noqa E203\n",
+ " ]\n",
+ " coreset_anchor_distances = torch.cat(\n",
+ " [coreset_anchor_distances.unsqueeze(-1), coreset_select_distance], dim=1\n",
+ " )\n",
+ " coreset_anchor_distances = torch.min(coreset_anchor_distances, dim=1).values\n",
+ "\n",
+ " return torch.tensor(coreset_indices, device=features.device, dtype=torch.int64)\n",
+ "\n",
+ "\n",
+ "class ApproximateGreedyCoresetSampler(GreedyCoresetSampler):\n",
+ " def __init__(\n",
+ " self,\n",
+ " percentage: float,\n",
+ " device: torch.device,\n",
+ " dtype: torch.dtype = torch.float32,\n",
+ " number_of_starting_points: int = 10,\n",
+ " dimension_to_project_features_to: int = 128,\n",
+ " ):\n",
+ " \"\"\"Approximate Greedy Coreset sampling base class.\"\"\"\n",
+ " self.number_of_starting_points = number_of_starting_points\n",
+ " super().__init__(percentage, device, dtype, dimension_to_project_features_to)\n",
+ "\n",
+ " def _compute_greedy_coreset_indices(self, features: torch.Tensor) -> np.ndarray:\n",
+ " \"\"\"Runs approximate iterative greedy coreset selection.\n",
+ "\n",
+ " This greedy coreset implementation does not require computation of the\n",
+ " full N x N distance matrix and thus requires a lot less memory, however\n",
+ " at the cost of increased sampling times.\n",
+ "\n",
+ " Args:\n",
+ " features: [NxD] input feature bank to sample.\n",
+ " \"\"\"\n",
+ " number_of_starting_points = np.clip(\n",
+ " self.number_of_starting_points, None, len(features)\n",
+ " )\n",
+ " start_points = np.random.choice(\n",
+ " len(features), number_of_starting_points, replace=False\n",
+ " ).tolist()\n",
+ "\n",
+ " approximate_distance_matrix = self._compute_batchwise_differences(\n",
+ " features, features[start_points]\n",
+ " )\n",
+ " approximate_coreset_anchor_distances = torch.mean(\n",
+ " approximate_distance_matrix, axis=-1\n",
+ " ).reshape(-1, 1)\n",
+ " coreset_indices = []\n",
+ " num_coreset_samples = int(len(features) * self.percentage)\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " for _ in tqdm(range(num_coreset_samples), desc=\"Subsampling...\"):\n",
+ " select_idx = torch.argmax(approximate_coreset_anchor_distances).item()\n",
+ " coreset_indices.append(select_idx)\n",
+ " coreset_select_distance = self._compute_batchwise_differences(\n",
+ " features, features[select_idx : select_idx + 1] # noqa: E203\n",
+ " )\n",
+ " approximate_coreset_anchor_distances = torch.cat(\n",
+ " [approximate_coreset_anchor_distances, coreset_select_distance],\n",
+ " dim=-1,\n",
+ " )\n",
+ " approximate_coreset_anchor_distances = torch.min(\n",
+ " approximate_coreset_anchor_distances, dim=1\n",
+ " ).values.reshape(-1, 1)\n",
+ "\n",
+ " return torch.tensor(coreset_indices, device=features.device, dtype=torch.int64)\n",
+ "\n",
+ "\n",
+ "class RandomSampler(BaseSampler):\n",
+ " def __init__(self, percentage: float):\n",
+ " super().__init__(percentage)\n",
+ "\n",
+ " def run(\n",
+ " self, features: Union[torch.Tensor, np.ndarray]\n",
+ " ) -> Union[torch.Tensor, np.ndarray]:\n",
+ " \"\"\"Randomly samples input feature collection.\n",
+ "\n",
+ " Args:\n",
+ " features: [N x D]\n",
+ " \"\"\"\n",
+ " num_random_samples = int(len(features) * self.percentage)\n",
+ " subset_indices = np.random.choice(\n",
+ " len(features), num_random_samples, replace=False\n",
+ " )\n",
+ " return torch.tensor(subset_indices, device=getattr(features, \"device\", None), dtype=torch.int64)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# I ran this cell on Google Colab because I don't have a GPU on my local machine,\n",
+ "# hence why you see the Google Drive paths\n",
+ "\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "model = SentenceTransformer(\"nomic-ai/modernbert-embed-base\")\n",
+ "print(model)\n",
+ "\n",
+ "def get_entry_info(entry) -> str:\n",
+ " return entry['task_description']\n",
+ "\n",
+ "def get_embeddings(text) -> torch.Tensor:\n",
+ " return torch.from_numpy(model.encode(text)).to(torch.bfloat16)\n",
+ "\n",
+ "embeddings = []\n",
+ "\n",
+ "with open(\"./drive/MyDrive/reasoning-gym/codeio-pyedu-extracted.jsonl\") as f:\n",
+ " for line in tqdm(f):\n",
+ " entry = json.loads(line)\n",
+ " entry_info = get_entry_info(entry)\n",
+ " embeddings.append(get_embeddings(entry_info))\n",
+ "\n",
+ "embeddings = torch.stack(embeddings).to(torch.bfloat16).to(device)\n",
+ "print(embeddings.shape)\n",
+ "\n",
+ "sampler = ApproximateGreedyCoresetSampler(\n",
+ " percentage=0.05, \n",
+ " device=device, \n",
+ " dtype=torch.bfloat16,\n",
+ " dimension_to_project_features_to=768,\n",
+ ")\n",
+ "subsampled = sampler.run(embeddings)\n",
+ "\n",
+ "indices = set(subsampled.cpu().tolist())\n",
+ "with open(\"./drive/MyDrive/reasoning-gym/codeio-pyedu-extracted.jsonl\", \"r\") as f_in, \\\n",
+ " open(\"./drive/MyDrive/reasoning-gym/codeio-pyedu-best-coverage.jsonl\", \"w+\") as f_out:\n",
+ " for i, line in enumerate(f_in):\n",
+ " if i in indices:\n",
+ " f_out.write(line)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Create input generators for each problem separately"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "SYSTEM_PROMPT = \"\"\"You are a helpful assistant that generates valid Python functions that act as input generators for a given code snippet.\n",
+ "\n",
+ "You have access to `random.Random`, therefore you SHOULD NOT import it again. You should use this random number generator to make the input generation process stochastic on each call.\n",
+ "\n",
+ "When the user asks you to generate an input for a code snippet, you should strictly respond in the following format:\n",
+ "\n",
+ "def generate_input(rng: Random) -> dict:\n",
+ " # Your code here\n",
+ " pass\n",
+ "\n",
+ "\n",
+ "The output of the function should be a dictionary where the keys are the variable names and the values are the generated values.\n",
+ "\n",
+ "It must contain all the variables that listed in the user's input specification, or more precisely in the `main_solution` function signature. \n",
+ "\"\"\"\n",
+ "\n",
+ "USER_PROMPT = \"\"\"Following are a task description, input/output specification, and relevant code snippet for a Python programming task.\n",
+ "\n",
+ "\n",
+ "{task_description}\n",
+ "\n",
+ "\n",
+ "\n",
+ "{input_output_spec}\n",
+ "\n",
+ "\n",
+ "\n",
+ "{code_sample}\n",
+ "\n",
+ "\n",
+ "Your task is to write a Python function `def generate_input(rng: Random) -> dict:` that generates valid inputs for the given code snippet, based on the provided information.\n",
+ "\"\"\"\n",
+ "\n",
+ "# We'll control concurrency with a semaphore\n",
+ "CONCURRENCY_LIMIT = 10\n",
+ "sem = asyncio.Semaphore(CONCURRENCY_LIMIT)\n",
+ "\n",
+ "async def fetch_input_generator(session: aiohttp.ClientSession, entry: dict) -> dict:\n",
+ " \"\"\"\n",
+ " Sends a POST request to OpenRouter with the system & user prompts,\n",
+ " extracts the function from the response, and returns the updated entry.\n",
+ " \"\"\"\n",
+ " url = \"https://openrouter.ai/api/v1/chat/completions\"\n",
+ " headers = {\n",
+ " \"Authorization\": f\"Bearer {os.getenv('OPENROUTER_API_KEY')}\",\n",
+ " \"Content-Type\": \"application/json\",\n",
+ " }\n",
+ "\n",
+ " payload = {\n",
+ " \"model\": \"deepseek/deepseek-chat\",\n",
+ " \"messages\": [\n",
+ " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
+ " {\n",
+ " \"role\": \"user\",\n",
+ " \"content\": USER_PROMPT.format(**entry)\n",
+ " },\n",
+ " ],\n",
+ " }\n",
+ "\n",
+ " async with sem:\n",
+ " async for attempt in AsyncRetrying(\n",
+ " stop=stop_after_attempt(5),\n",
+ " wait=wait_exponential(multiplier=1, min=1, max=60),\n",
+ " retry=retry_if_exception_type(\n",
+ " (aiohttp.ClientError, asyncio.TimeoutError, json.JSONDecodeError, ValueError)\n",
+ " ),\n",
+ " ):\n",
+ " with attempt:\n",
+ " async with session.post(url, headers=headers, json=payload) as response:\n",
+ " data = await response.json()\n",
+ "\n",
+ " # Basic checks for valid response\n",
+ " if \"choices\" not in data or not data[\"choices\"]:\n",
+ " print(\"No choices found in response\")\n",
+ " return entry\n",
+ "\n",
+ " content = data[\"choices\"][0][\"message\"][\"content\"]\n",
+ " match = re.search(r\"(def generate_input.*)\", content, re.DOTALL)\n",
+ " if not match:\n",
+ " print(\"Could not find ... block in response\")\n",
+ " return entry\n",
+ "\n",
+ " input_generator = match.group(1).strip()\n",
+ " entry[\"input_generator\"] = input_generator\n",
+ " return entry\n",
+ "\n",
+ " # If we exit the loop without returning, raise Exception\n",
+ " raise Exception(\"Failed to get valid input generator after retries\")\n",
+ "\n",
+ "async def process_file(input_file: str, output_file: str):\n",
+ " \"\"\"\n",
+ " Reads each line from `input_file`, processes each entry concurrently,\n",
+ " and writes augmented entries to `output_file`.\n",
+ " \"\"\"\n",
+ " # Read all lines first (synchronously)\n",
+ " with open(input_file, \"r\") as f_in:\n",
+ " lines = f_in.readlines()\n",
+ "\n",
+ " tasks = []\n",
+ " async with aiohttp.ClientSession() as session:\n",
+ " # Create a task for each line/entry\n",
+ " for line in lines:\n",
+ " entry = json.loads(line)\n",
+ " tasks.append(asyncio.create_task(fetch_input_generator(session, entry)))\n",
+ "\n",
+ " # We'll gather results while showing progress\n",
+ " results = []\n",
+ " for t in tqdm(asyncio.as_completed(tasks), total=len(tasks)):\n",
+ " result = await t\n",
+ " results.append(result)\n",
+ "\n",
+ " # Write all results out\n",
+ " with open(output_file, \"w\") as f_out:\n",
+ " for res in results:\n",
+ " f_out.write(json.dumps(res))\n",
+ " f_out.write(\"\\n\")\n",
+ "\n",
+ "# Finally, run the entire pipeline\n",
+ "await process_file(\n",
+ " input_file=\"data/codeio-pyedu-best-coverage.jsonl\",\n",
+ " output_file=\"data/codeio-pyedu-with-input-generator.jsonl\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Filter out invalid input generators"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If you want to install a template with custom package\n",
+ "\n",
+ "https://e2b.dev/docs/quickstart/install-custom-packages\n",
+ "\n",
+ "An example e2b.Dockerfile looks like this:\n",
+ "\n",
+ "```Dockerfile\n",
+ "FROM e2bdev/code-interpreter:latest\n",
+ "\n",
+ "RUN pip install numpy matplotlib scipy pandas scikit-learn sympy networkx requests pillow bs4 cryptography spacy numba pyyaml regex\n",
+ "```\n",
+ "\n",
+ "However, I am going with the default installed libraries: https://e2b.dev/docs/code-interpreting/analyze-data-with-ai/pre-installed-libraries "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Example usage of the Sandbox class\n",
+ "with Sandbox() as sandbox:\n",
+ "\n",
+ " # First initialize the sandbox\n",
+ " execution = sandbox.run_code(\"\"\"\n",
+ "from random import Random # <----- ALWAYS PREPEND THIS LINE TO YOUR CODE SNIPPET\n",
+ "\n",
+ "def hello_world():\n",
+ " return {\"a\": 5, \"b\": 10}\n",
+ "\n",
+ "def multiple_hello_worlds(rng: Random):\n",
+ " return [\n",
+ " {\"a\": rng.randint(1, 10), \"b\": rng.randint(10, 20)},\n",
+ " {\"a\": 10, \"b\": 20},\n",
+ " ]\n",
+ "\"\"\"\n",
+ " )\n",
+ " try:\n",
+ " # Run the code snippet\n",
+ " execution = sandbox.run_code(\"rng = Random(53);multiple_hello_worlds(rng)\", timeout=5)\n",
+ " print(execution)\n",
+ " if execution.error:\n",
+ " print(\"[!! FOUND ERROR !!]\")\n",
+ " else:\n",
+ " print(type(execution.text))\n",
+ " print(execution.text)\n",
+ " except TimeoutException as e:\n",
+ " print(e)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "371d38e1fe9e41d587b2cfa64ca9ef91",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/7053 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Response 404\n",
+ "Response 404\n",
+ "Response 404\n",
+ "Response 404\n",
+ "Response 404\n",
+ "Response 404\n",
+ "Response 404\n",
+ "Response 404\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "full_sampling_fails: 913\n",
+ "warmup_fails: 528\n",
+ "missing_input_generator: 36\n",
+ "cannot_initialize_code: 98\n",
+ "Total errors: 1575\n"
+ ]
+ }
+ ],
+ "source": [
+ "CODE_TEMPLATE = \"\"\"from random import Random\n",
+ "{code_sample}\n",
+ "\n",
+ "{input_generator}\n",
+ "\n",
+ "def multiple_eval(num_generations: int, seed: int = 42) -> tuple:\n",
+ " rng = Random(seed)\n",
+ " inputs = [generate_input(rng) for _ in range(num_generations)]\n",
+ " outputs = [main_solution(**inp) for inp in inputs]\n",
+ " return inputs, outputs\n",
+ "\"\"\"\n",
+ "\n",
+ "SAMPLING_TEMPLATE = \"multiple_eval({num_generations})\"\n",
+ "\n",
+ "WARMUP_GENERATIONS = 5\n",
+ "TOTAL_GENERATIONS = 1_000\n",
+ "TIMEOUT_CODE_INIT = 10\n",
+ "TIMEOUT_PER_SAMPLE = 0.01\n",
+ "\n",
+ "errors = defaultdict(int)\n",
+ "total_entries = sum(1 for _ in open(\"data/codeio-pyedu-with-input-generator.jsonl\", \"r\"))\n",
+ "\n",
+ "with open(\"data/codeio-pyedu-with-input-generator.jsonl\", \"r\") as f_in, \\\n",
+ " open(\"data/codeio-pyedu-with-input-generator-filtered.jsonl\", \"w+\") as f_out:\n",
+ "\n",
+ " iterator = tqdm(enumerate(f_in), total=total_entries)\n",
+ "\n",
+ " for i, line in iterator:\n",
+ " iterator.set_description(f\"Failures: \" + \" | \".join(f\"{k}: {v}\" for k, v in errors.items()) + f\" | total: {sum(errors.values())}\")\n",
+ " entry = json.loads(line)\n",
+ "\n",
+ " if \"input_generator\" not in entry:\n",
+ " errors[\"missing_input_generator\"] += 1\n",
+ " continue\n",
+ " \n",
+ " with Sandbox() as sandbox:\n",
+ " # 1. Initialize the sandbox\n",
+ " try: \n",
+ " execution = sandbox.run_code(\n",
+ " code=CODE_TEMPLATE.format(**entry), \n",
+ " timeout=TIMEOUT_CODE_INIT\n",
+ " )\n",
+ " assert not execution.error, \"Error in code snippet\"\n",
+ " except Exception as e:\n",
+ " errors[\"cannot_initialize_code\"] += 1\n",
+ " continue\n",
+ " \n",
+ " # 2. Warmup the sampling\n",
+ " try:\n",
+ " execution = sandbox.run_code(\n",
+ " code=SAMPLING_TEMPLATE.format(num_generations=WARMUP_GENERATIONS),\n",
+ " timeout=TIMEOUT_CODE_INIT,\n",
+ " )\n",
+ " assert not execution.error, \"Error in input generator (warmup)\"\n",
+ " assert execution.text, \"Empty input generator output (warmup)\"\n",
+ " inputs, outputs = eval(execution.text)\n",
+ " except Exception as e:\n",
+ " errors[\"warmup_fails\"] += 1\n",
+ " continue\n",
+ "\n",
+ " # 3. Run the full sampling\n",
+ " try:\n",
+ " execution = sandbox.run_code(\n",
+ " code=SAMPLING_TEMPLATE.format(num_generations=TOTAL_GENERATIONS),\n",
+ " timeout=int(TIMEOUT_PER_SAMPLE * TOTAL_GENERATIONS),\n",
+ " )\n",
+ " assert not execution.error, \"Error in input generator (full)\"\n",
+ " assert execution.text, \"Empty input generator output (full)\"\n",
+ " inputs, outputs = eval(execution.text)\n",
+ " assert len(inputs) == TOTAL_GENERATIONS, \"Mismatch in input generations\"\n",
+ " assert len(outputs) == TOTAL_GENERATIONS, \"Mismatch in output generations\"\n",
+ " unique_inputs = len(set(json.dumps(inp, sort_keys=True) for inp in inputs))\n",
+ " unique_outputs = len(set(json.dumps(out, sort_keys=True) for out in outputs))\n",
+ " except Exception:\n",
+ " errors[\"full_sampling_fails\"] += 1\n",
+ " continue\n",
+ " \n",
+ " # 4. Save the entry\n",
+ " entry = entry | {\n",
+ " \"unique_inputs\": unique_inputs,\n",
+ " \"unique_outputs\": unique_outputs,\n",
+ " \"total_generations\": TOTAL_GENERATIONS,\n",
+ " }\n",
+ " f_out.write(json.dumps(entry))\n",
+ " f_out.write(\"\\n\")\n",
+ "\n",
+ "for k, v in errors.items():\n",
+ " print(f\"{k}: {v}\")\n",
+ "print(f\"Total errors: {sum(errors.values())}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/notebooks/codeio/PreprocessCode.ipynb b/notebooks/codeio/PreprocessCode.ipynb
new file mode 100644
index 00000000..00f07980
--- /dev/null
+++ b/notebooks/codeio/PreprocessCode.ipynb
@@ -0,0 +1,603 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## CodeI/O\n",
+ "\n",
+ "Original paper (DeepSeek): https://arxiv.org/pdf/2502.07316\n",
+ "\n",
+ "The approach begins by obtaining high quality raw code data and preprocessing it by prompting an LLM. The output of this preprocessing, for each raw code file used, should be:\n",
+ "\n",
+ "- cleaned reference code, with a main entrypoint function\n",
+ "- a query, converting the reference code into a question (along the lines of \"given [function parameters...] how can we obtain [desired outputs...]\")\n",
+ "- a natural language description of all inputs (function parameters) and outputs (function return values)\n",
+ "- an input generator, which can generate a dictionary of valid inputs for the function\n",
+ "\n",
+ "This notebook seeks to experiment with prompting an LLM to this end, as a starting point. The raw code data is from this GitHub repository that the DeepSeek paper mentions as one of their raw code sources: https://github.com/TheAlgorithms/Python\n",
+ "\n",
+ "NOTE: Be careful with the raw code you input into this, as cells later execute the LLM-generated outputs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Cloning into 'Python'...\n",
+ "remote: Enumerating objects: 20925, done.\u001b[K\n",
+ "remote: Counting objects: 100% (13/13), done.\u001b[K\n",
+ "remote: Compressing objects: 100% (11/11), done.\u001b[K\n",
+ "remote: Total 20925 (delta 6), reused 2 (delta 2), pack-reused 20912 (from 3)\u001b[K\n",
+ "Receiving objects: 100% (20925/20925), 14.86 MiB | 17.27 MiB/s, done.\n",
+ "Resolving deltas: 100% (13469/13469), done.\n"
+ ]
+ }
+ ],
+ "source": [
+ "!git clone https://github.com/TheAlgorithms/Python.git\n",
+ "\n",
+ "import shutil\n",
+ "from pathlib import Path\n",
+ "\n",
+ "repo_dir = Path(\"Python\")\n",
+ "raw_code_dir = Path(\"raw_files\")\n",
+ "raw_code_dir.mkdir(exist_ok=True)\n",
+ "\n",
+ "def process_dir(directory: Path):\n",
+ " # Move all the Python code files to the raw code file directory\n",
+ " # Handles subdirectories recursively\n",
+ " dirname = directory.name\n",
+ " for file in directory.iterdir():\n",
+ " if file.is_dir():\n",
+ " process_dir(file)\n",
+ " elif file.name.endswith(\".py\") and file.name != \"__init__.py\":\n",
+ " file.rename(raw_code_dir / f\"{dirname}_{file.name}\")\n",
+ "\n",
+ "for repo_child in repo_dir.iterdir():\n",
+ " # For this repo, algorithms are divided into categories by subdirectories\n",
+ " if not repo_child.is_dir() or repo_child.name.startswith(\".\"):\n",
+ " continue\n",
+ " process_dir(repo_child)\n",
+ "\n",
+ "shutil.rmtree(repo_dir)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import random\n",
+ "from pathlib import Path\n",
+ "from dotenv import load_dotenv\n",
+ "load_dotenv()\n",
+ "raw_files = list(Path(\"raw_files/\").iterdir())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Note that the below prompt is built for DeepSeekV3. It may not work with other LLMs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "format_prompt_template = \"\"\"\n",
+ "You are tasked with preprocessing a raw file of Python code into a standard format. The format is made up of several components. Here is a very simple example of a raw code file:\n",
+ "\n",
+ "def kg_to_pounds(weights):\n",
+ " return [w * 2.20462 for w in weights]\n",
+ "\n",
+ "def filter_weekly(original_measurements, days):\n",
+ " return [m for i, m in enumerate(original_measurements) if i % 7 == 0]\n",
+ "\n",
+ "def main(kgs, days):\n",
+ " lbs = kg_to_pounds(kgs)\n",
+ "\n",
+ " for measurement in filter_weekly(lbs, days):\n",
+ " print(measurement)\n",
+ "\n",
+ "1. Cleaned reference code, with a main entrypoint function that takes all required arguments as parameters and returns all outputs.\n",
+ "\n",
+ "The name of the main entrypoint function should be `main`. The parameters should be clearly named but do not require type hints. The function should return a dict mapping output names to values. The function should contain all the necessary code to perform the functionality, without splitting into several functions. The function should not print or otherwise output anything; results should be returned as part of the result dict. Ensure you include any imports necessary, prior to the function definition.\n",
+ "\n",
+ "Example function signature: `def main(weights_kg, days):`\n",
+ "\n",
+ "2. A query, defined as natural language description of the question the function answers.\n",
+ "\n",
+ "Example query: \"You are given two lists of integers, `weights_kg` and `days`. The unit of `weights_kg` is kilograms. `days` refers to the number of days passed, starting from zero. Your task is to convert the integers to pounds and filter to only one weight measurement every 7 days. Return the list of integers in pounds.\"\n",
+ "\n",
+ "The query should be as detailed as the code requires to be fully explained. It should be clear what the function does, what the inputs are, and what the outputs are.\n",
+ "\n",
+ "3. A natural language description of all inputs (function parameters) and outputs (return values) of the function.\n",
+ "\n",
+ "Example description:\n",
+ "\n",
+ "Input:\n",
+ " weights_kg (list of int): List of weight values in kilograms.\n",
+ " days (list of int): List of integers representing the number of days passed, starting from zero.\n",
+ "\n",
+ "Output:\n",
+ " return (dict): A dictionary with one key:\n",
+ " - weights_lb (list of int): List of filtered weight values in pounds.\n",
+ "\n",
+ "4. Python 3.11 code for an input generator, which randomly generates valid sets of inputs for the functions.\n",
+ "\n",
+ "The input generator should return a dict mapping parameter names to values. The values should be randomly generated, but should be valid inputs for the function. You have access to `random` in the input generator. Do not import any other modules.\n",
+ "\n",
+ "Example input generator:\n",
+ "\n",
+ "def input_generator():\n",
+ " weights = [random.randint(1, 100) for _ in range(40)]\n",
+ " days = list(range(40))\n",
+ " return {{\"weights_kg\": weights, \"days\": days}}\n",
+ "\n",
+ "Using the guidelines and example above, preprocess the following raw code file into the standard format:\n",
+ "\n",
+ "{0}\n",
+ "\n",
+ "Output the components (reference code, query, description, input generator) in order. Separate each component with a line of dashes (---). Avoid code blocks and do not output any Markdown formatting. Respond only with the four components, no prefix or additional text.\n",
+ "\"\"\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Edit the cell below, or set the appropriate environment variables, to use a different API provider, etc."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import asyncio\n",
+ "import os\n",
+ "from openai import AsyncOpenAI\n",
+ "from openai.types.chat import ChatCompletion, ChatCompletionMessageParam\n",
+ "from typing import Any, Iterable\n",
+ "\n",
+ "# Cap concurrent requests. I had to set this to 1 for the DeepSeek API to work, YMMV\n",
+ "semaphore = asyncio.Semaphore(1)\n",
+ "\n",
+ "async def llm_generate(\n",
+ "    client: AsyncOpenAI,\n",
+ "    messages: Iterable[ChatCompletionMessageParam],\n",
+ "    sampling_params: dict[str, Any],\n",
+ "    retry_empty_response: bool = True,\n",
+ "    max_retries: int = 3,\n",
+ ") -> ChatCompletion:\n",
+ "    \"\"\"Call the chat completions API with retries under the global semaphore.\n",
+ "\n",
+ "    Retries on exceptions (re-raising after the last attempt) and, when\n",
+ "    retry_empty_response is True, also on empty message content. Always\n",
+ "    either returns a ChatCompletion or raises -- it never returns None,\n",
+ "    so callers can safely access response.choices.\n",
+ "    \"\"\"\n",
+ "    for trial in range(max_retries):\n",
+ "        async with semaphore:\n",
+ "            try:\n",
+ "                completion = await client.chat.completions.create(\n",
+ "                    messages=messages, **sampling_params\n",
+ "                )\n",
+ "                if completion.choices[0].message.content or not retry_empty_response:\n",
+ "                    return completion\n",
+ "                await asyncio.sleep(5)\n",
+ "            except Exception as e:\n",
+ "                print(f\"Failure response (trial {trial}):\", e)\n",
+ "                await asyncio.sleep(3 * (trial + 1))\n",
+ "                if trial == max_retries - 1:\n",
+ "                    raise\n",
+ "    # Every attempt returned empty content: fail loudly instead of the\n",
+ "    # previous implicit `return None`, which broke callers downstream.\n",
+ "    raise RuntimeError(f\"Empty model response after {max_retries} attempts\")\n",
+ "\n",
+ "# OpenAI-compatible async client; endpoint and key come from the environment\n",
+ "# (.env via load_dotenv in the first cell), so no credentials are hardcoded.\n",
+ "client = AsyncOpenAI(\n",
+ "    base_url=os.getenv(\"API_BASE_URL\"),\n",
+ "    api_key=os.getenv(\"API_KEY\"),\n",
+ "    timeout=120.0,\n",
+ ")\n",
+ "\n",
+ "# Shared kwargs passed straight into chat.completions.create().\n",
+ "sampling_params = {\n",
+ "    \"model\": \"deepseek-chat\",  # For DeepSeek API\n",
+ "    #\"model\": \"deepseek/deepseek-chat:free\",  # For OpenRouter\n",
+ "    \"max_tokens\": 8192,\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Demo cell to illustrate the LLM preprocessing:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "raw_files/genetic_algorithm_basic_string.py\n",
+ "def main(target: str, genes: list[str], debug: bool = True) -> dict:\n",
+ " if N_POPULATION < N_SELECTED:\n",
+ " raise ValueError(f\"{N_POPULATION} must be bigger than {N_SELECTED}\")\n",
+ " \n",
+ " not_in_genes_list = sorted({c for c in target if c not in genes})\n",
+ " if not_in_genes_list:\n",
+ " raise ValueError(f\"{not_in_genes_list} is not in genes list, evolution cannot converge\")\n",
+ " \n",
+ " population = []\n",
+ " for _ in range(N_POPULATION):\n",
+ " population.append(\"\".join([random.choice(genes) for _ in range(len(target))]))\n",
+ " \n",
+ " generation, total_population = 0, 0\n",
+ " \n",
+ " while True:\n",
+ " generation += 1\n",
+ " total_population += len(population)\n",
+ " \n",
+ " population_score = [evaluate(item, target) for item in population]\n",
+ " population_score = sorted(population_score, key=lambda x: x[1], reverse=True)\n",
+ " \n",
+ " if population_score[0][0] == target:\n",
+ " return {\n",
+ " \"generation\": generation,\n",
+ " \"total_population\": total_population,\n",
+ " \"best_match\": population_score[0][0]\n",
+ " }\n",
+ " \n",
+ " if debug and generation % 10 == 0:\n",
+ " print(\n",
+ " f\"\\nGeneration: {generation}\"\n",
+ " f\"\\nTotal Population:{total_population}\"\n",
+ " f\"\\nBest score: {population_score[0][1]}\"\n",
+ " f\"\\nBest string: {population_score[0][0]}\"\n",
+ " )\n",
+ " \n",
+ " population_best = population[: int(N_POPULATION / 3)]\n",
+ " population.clear()\n",
+ " population.extend(population_best)\n",
+ " population_score = [\n",
+ " (item, score / len(target)) for item, score in population_score\n",
+ " ]\n",
+ " \n",
+ " for i in range(N_SELECTED):\n",
+ " population.extend(select(population_score[int(i)], population_score, genes))\n",
+ " if len(population) > N_POPULATION:\n",
+ " break\n",
+ "\n",
+ "---\n",
+ "\n",
+ "You are given a target string and a list of genes. The target string represents the desired output of a genetic algorithm, and the genes list contains the possible characters that can be used to build the target string. The genetic algorithm works in phases: evaluation, selection, crossover, and mutation. The algorithm starts with a random population of strings and evolves them over generations to converge towards the target string. The function returns the number of generations it took to find a perfect match, the total population size processed, and the best matching string found.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "Input:\n",
+ " target (str): The target string that the genetic algorithm aims to converge to.\n",
+ " genes (list of str): A list of characters that can be used to build the target string.\n",
+ " debug (bool, optional): If True, prints progress every 10 generations. Defaults to True.\n",
+ "\n",
+ "Output:\n",
+ " return (dict): A dictionary with three keys:\n",
+ " - generation (int): The number of generations it took to find a perfect match.\n",
+ " - total_population (int): The total population size processed during the evolution.\n",
+ " - best_match (str): The best matching string found.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "def input_generator():\n",
+ " genes = list(\" ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz.,;!?+-*#@^'èéòà€ù=)(&%$£/\\\\\")\n",
+ " target_length = random.randint(10, 50)\n",
+ " target = \"\".join(random.choices(genes, k=target_length))\n",
+ " return {\"target\": target, \"genes\": genes, \"debug\": random.choice([True, False])}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Smoke test: run the preprocessing prompt on one randomly chosen raw file\n",
+ "# and print the model's four-component response.\n",
+ "raw_file = random.choice(raw_files)\n",
+ "\n",
+ "print(raw_file)\n",
+ "\n",
+ "raw_code = raw_file.read_text()\n",
+ "\n",
+ "prompt = format_prompt_template.format(raw_code)\n",
+ "\n",
+ "messages = [\n",
+ "    {\"role\": \"user\", \"content\": prompt},\n",
+ "]\n",
+ "\n",
+ "response = await llm_generate(client, messages, sampling_params)\n",
+ "print(response.choices[0].message.content)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Run the below cell to preprocess all the raw code files for real. This will send quite a lot of requests to the configured API provider."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Failure response (trial 1): Expecting value: line 1 column 1 (char 0)\n",
+ "Error processing file raw_files/graphs_page_rank.py Expecting value: line 1 column 1 (char 0)\n",
+ "Failure response (trial 1): Expecting value: line 1 column 1 (char 0)\n",
+ "Error processing file raw_files/problem_002_sol2.py Expecting value: line 1 column 1 (char 0)\n"
+ ]
+ }
+ ],
+ "source": [
+ "async def process_file(raw_file):\n",
+ "    \"\"\"Preprocess one raw code file through the LLM.\n",
+ "\n",
+ "    Returns a (code, query, parameters, generator) tuple, or None when the\n",
+ "    request failed or the response did not contain exactly four\n",
+ "    dash-separated components.\n",
+ "    \"\"\"\n",
+ "    raw_code = raw_file.read_text()\n",
+ "    prompt = format_prompt_template.format(raw_code)\n",
+ "    messages = [{\"role\": \"user\", \"content\": prompt}]\n",
+ "\n",
+ "    try:\n",
+ "        response = await llm_generate(client, messages, sampling_params)\n",
+ "        content = response.choices[0].message.content\n",
+ "        code, query, parameters, generator = [el.strip() for el in content.split(\"\\n---\\n\")]\n",
+ "        return code, query, parameters, generator\n",
+ "    except Exception as e:\n",
+ "        print(\"Error processing file\", raw_file, e)\n",
+ "        return None\n",
+ "\n",
+ "async def process_all_files(raw_code_files: list[Path], out_file: Path):\n",
+ "    \"\"\"Process all raw files concurrently, appending successes to a JSONL file.\"\"\"\n",
+ "    process_tasks = [asyncio.create_task(process_file(raw_file)) for raw_file in raw_code_files]\n",
+ "    # Open once for the whole run instead of re-opening per result.\n",
+ "    with out_file.open(\"a\") as f:\n",
+ "        for future in tqdm(asyncio.as_completed(process_tasks), total=len(process_tasks)):\n",
+ "            result = await future\n",
+ "            if result is None:\n",
+ "                # Failure was already logged by process_file; skip instead of\n",
+ "                # crashing on `None` tuple-unpacking as the old code did.\n",
+ "                continue\n",
+ "            code, query, parameters, generator = result\n",
+ "            out_object = {\"query\": query, \"reference_code\": code, \"parameters\": parameters, \"input_generator\": generator}\n",
+ "            f.write(json.dumps(out_object) + \"\\n\")\n",
+ "\n",
+ "out_file = Path(\"processed_code.jsonl\")\n",
+ "await process_all_files(raw_files, out_file)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Load one of the processed outputs to test the reference code and input generator.\n",
+ "\n",
+ "The below cell executes the loaded LLM-generated code, so exercise caution."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'particles': [{'x': 46.08733176390575, 'y': -79.53711508439847, 'z': 45.779499438274655, 'mass': 9.121897656796}, {'x': -37.62801734935914, 'y': 94.62608762267024, 'z': -88.900444530177, 'mass': 13.267310061939007}, {'x': 57.04088821817467, 'y': 42.54071907694012, 'z': -73.71739928081027, 'mass': 33.13376982254907}, {'x': -25.913090702690695, 'y': 97.27894813174453, 'z': -68.24577317209872, 'mass': 20.409856607552626}, {'x': -7.993371736001535, 'y': 5.784333365689022, 'z': 82.05216927454009, 'mass': 97.18903185914192}, {'x': 8.028265944329263, 'y': -16.980411042271342, 'z': -38.28350230155666, 'mass': 68.56437969046345}, {'x': 72.19027810108415, 'y': 40.80441736137902, 'z': -27.381163108822662, 'mass': 31.705269244558238}]}\n",
+ "{'particles': [{'x': -82.51989169298639, 'y': 79.31892816610184, 'z': 74.79703074246333, 'mass': 8.173913842116992}, {'x': 40.50078366091543, 'y': -81.62144939582438, 'z': -90.67215023121767, 'mass': 69.66013035036612}, {'x': 23.07410631316951, 'y': 52.57873390089097, 'z': -77.63883105258888, 'mass': 63.20676872636796}]}\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[({'particles': [{'x': 46.08733176390575,\n",
+ " 'y': -79.53711508439847,\n",
+ " 'z': 45.779499438274655,\n",
+ " 'mass': 9.121897656796},\n",
+ " {'x': -37.62801734935914,\n",
+ " 'y': 94.62608762267024,\n",
+ " 'z': -88.900444530177,\n",
+ " 'mass': 13.267310061939007},\n",
+ " {'x': 57.04088821817467,\n",
+ " 'y': 42.54071907694012,\n",
+ " 'z': -73.71739928081027,\n",
+ " 'mass': 33.13376982254907},\n",
+ " {'x': -25.913090702690695,\n",
+ " 'y': 97.27894813174453,\n",
+ " 'z': -68.24577317209872,\n",
+ " 'mass': 20.409856607552626},\n",
+ " {'x': -7.993371736001535,\n",
+ " 'y': 5.784333365689022,\n",
+ " 'z': 82.05216927454009,\n",
+ " 'mass': 97.18903185914192},\n",
+ " {'x': 8.028265944329263,\n",
+ " 'y': -16.980411042271342,\n",
+ " 'z': -38.28350230155666,\n",
+ " 'mass': 68.56437969046345},\n",
+ " {'x': 72.19027810108415,\n",
+ " 'y': 40.80441736137902,\n",
+ " 'z': -27.381163108822662,\n",
+ " 'mass': 31.705269244558238}]},\n",
+ " {'center_of_mass': {'x': 12.23, 'y': 16.89, 'z': -0.42}}),\n",
+ " ({'particles': [{'x': -82.51989169298639,\n",
+ " 'y': 79.31892816610184,\n",
+ " 'z': 74.79703074246333,\n",
+ " 'mass': 8.173913842116992},\n",
+ " {'x': 40.50078366091543,\n",
+ " 'y': -81.62144939582438,\n",
+ " 'z': -90.67215023121767,\n",
+ " 'mass': 69.66013035036612},\n",
+ " {'x': 23.07410631316951,\n",
+ " 'y': 52.57873390089097,\n",
+ " 'z': -77.63883105258888,\n",
+ " 'mass': 63.20676872636796}]},\n",
+ " {'center_of_mass': {'x': 25.56, 'y': -12.15, 'z': -75.24}})]"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "rng = random.Random()\n",
+ "\n",
+ "sample_object = json.loads(out_file.read_text().splitlines()[0])\n",
+ "\n",
+ "def generate_io_pairs(main_code: str, input_generator_code: str, num_pairs: int = 100):\n",
+ "    \"\"\"exec() the reference code and input generator, then sample IO pairs.\n",
+ "\n",
+ "    WARNING: this executes LLM-generated code in-process; only run it on\n",
+ "    reviewed/trusted outputs (or inside a sandbox).\n",
+ "    \"\"\"\n",
+ "    # Use ONE shared namespace for both exec() calls. With separate\n",
+ "    # globals/locals dicts, top-level defs land in the locals dict while each\n",
+ "    # function's __globals__ is the (near-empty) globals dict -- so reference\n",
+ "    # code whose main() calls its own helper functions raised NameError.\n",
+ "    env = {\"random\": rng}\n",
+ "    exec(main_code, env)\n",
+ "    exec(input_generator_code, env)\n",
+ "    io_pairs = []\n",
+ "    for _ in range(num_pairs):\n",
+ "        inputs = env[\"input_generator\"]()\n",
+ "        outputs = env[\"main\"](**inputs)\n",
+ "        io_pairs.append((inputs, outputs))\n",
+ "    return io_pairs\n",
+ "\n",
+ "io_pairs = generate_io_pairs(sample_object[\"reference_code\"], sample_object[\"input_generator\"], num_pairs=2)\n",
+ "io_pairs"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next in the paper they synthesized chains of thought from the LLM for use in building a supervised finetuning dataset. Excerpt:\n",
+ "\n",
+ "> Since we aim for the input-output prediction tasks, we construct the prompt using a designed template to combine the function, the query, the reference code, and either a specific input or output. The response should ideally be a natural language CoT to reason about how to derive the correct output or a feasible input.\n",
+ "\n",
+ "The below prompts are also from the paper. Synthesized chains of thought are not our main goal, but the cells below provide a demo nonetheless."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "synthetic_cot_prompt_prefix = \"\"\"\n",
+ "You are given a question that requires some input and output variables as follows:\n",
+ "\n",
+ "{0}\n",
+ "\n",
+ "The input and output requirements are as follows:\n",
+ "\n",
+ "{1}\n",
+ "\"\"\"\n",
+ "\n",
+ "synthetic_cot_prompt_suffix = \"\"\"\n",
+ "Tip: Here is a reference code snippet for this question. You can refer to this code to guide your reasoning but not copy spans of code directly.\n",
+ "\n",
+ "{3}\n",
+ "\"\"\"\n",
+ "\n",
+ "synthetic_cot_prompt_input_prediction = synthetic_cot_prompt_prefix + \"\"\"\n",
+ "Given the following output:\n",
+ "\n",
+ "{2}\n",
+ "\n",
+ "Can you predict a feasible input without writing any code? Please reason and put your final answer in the following json format: {\"input\": <your input>}, where <your input> should be a dictionary, even if there is only one input variable, with keys strictly matching the input variables' names as specified.\n",
+ "\"\"\" + synthetic_cot_prompt_suffix\n",
+ "\n",
+ "synthetic_cot_prompt_output_prediction = synthetic_cot_prompt_prefix + \"\"\"\n",
+ "Given the following input:\n",
+ "\n",
+ "{2}\n",
+ "\n",
+ "Can you predict the output without writing any code? Please reason and put your final answer in the following json format: {\"output\": <your output>}, where <your output> should strictly match the output requirement as specified.\n",
+ "\"\"\" + synthetic_cot_prompt_suffix"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\"To predict a feasible input that would result in the given output `{'center_of_mass': {'x': 12.23, 'y': 16.89, 'z': -0.42}}`, we need to consider the formula for calculating the center of mass in 3D space. The center of mass is calculated as the weighted average of the positions of the particles, where the weights are the masses of the particles.\\n\\nThe formula for the center of mass is:\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_x} = \\\\frac{\\\\sum (x_i \\\\cdot m_i)}{\\\\sum m_i}\\n\\\\]\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_y} = \\\\frac{\\\\sum (y_i \\\\cdot m_i)}{\\\\sum m_i}\\n\\\\]\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_z} = \\\\frac{\\\\sum (z_i \\\\cdot m_i)}{\\\\sum m_i}\\n\\\\]\\n\\nGiven the output, we can work backward to estimate the input. Let's assume we have two particles for simplicity:\\n\\n1. **Particle 1**:\\n - Position: \\\\(x_1 = 10.0\\\\), \\\\(y_1 = 15.0\\\\), \\\\(z_1 = 0.0\\\\)\\n - Mass: \\\\(m_1 = 2.0\\\\)\\n\\n2. **Particle 2**:\\n - Position: \\\\(x_2 = 14.0\\\\), \\\\(y_2 = 18.0\\\\), \\\\(z_2 = -1.0\\\\)\\n - Mass: \\\\(m_2 = 3.0\\\\)\\n\\nNow, let's calculate the center of mass using these values:\\n\\n\\\\[\\n\\\\text{total\\\\_mass} = m_1 + m_2 = 2.0 + 3.0 = 5.0\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_x} = \\\\frac{(10.0 \\\\cdot 2.0) + (14.0 \\\\cdot 3.0)}{5.0} = \\\\frac{20.0 + 42.0}{5.0} = \\\\frac{62.0}{5.0} = 12.4\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_y} = \\\\frac{(15.0 \\\\cdot 2.0) + (18.0 \\\\cdot 3.0)}{5.0} = \\\\frac{30.0 + 54.0}{5.0} = \\\\frac{84.0}{5.0} = 16.8\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_z} = \\\\frac{(0.0 \\\\cdot 2.0) + (-1.0 \\\\cdot 3.0)}{5.0} = \\\\frac{0.0 - 3.0}{5.0} = \\\\frac{-3.0}{5.0} = -0.6\\n\\\\]\\n\\nThese values are close to the given output, but not exact. To get closer to the exact output, we can adjust the masses slightly:\\n\\n1. 
**Particle 1**:\\n - Position: \\\\(x_1 = 10.0\\\\), \\\\(y_1 = 15.0\\\\), \\\\(z_1 = 0.0\\\\)\\n - Mass: \\\\(m_1 = 2.1\\\\)\\n\\n2. **Particle 2**:\\n - Position: \\\\(x_2 = 14.0\\\\), \\\\(y_2 = 18.0\\\\), \\\\(z_2 = -1.0\\\\)\\n - Mass: \\\\(m_2 = 3.0\\\\)\\n\\nNow, let's recalculate:\\n\\n\\\\[\\n\\\\text{total\\\\_mass} = m_1 + m_2 = 2.1 + 3.0 = 5.1\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_x} = \\\\frac{(10.0 \\\\cdot 2.1) + (14.0 \\\\cdot 3.0)}{5.1} = \\\\frac{21.0 + 42.0}{5.1} = \\\\frac{63.0}{5.1} \\\\approx 12.35\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_y} = \\\\frac{(15.0 \\\\cdot 2.1) + (18.0 \\\\cdot 3.0)}{5.1} = \\\\frac{31.5 + 54.0}{5.1} = \\\\frac{85.5}{5.1} \\\\approx 16.76\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_z} = \\\\frac{(0.0 \\\\cdot 2.1) + (-1.0 \\\\cdot 3.0)}{5.1} = \\\\frac{0.0 - 3.0}{5.1} = \\\\frac{-3.0}{5.1} \\\\approx -0.59\\n\\\\]\\n\\nThese values are closer to the given output. To match the exact output, we can further adjust the masses:\\n\\n1. **Particle 1**:\\n - Position: \\\\(x_1 = 10.0\\\\), \\\\(y_1 = 15.0\\\\), \\\\(z_1 = 0.0\\\\)\\n - Mass: \\\\(m_1 = 2.2\\\\)\\n\\n2. 
**Particle 2**:\\n - Position: \\\\(x_2 = 14.0\\\\), \\\\(y_2 = 18.0\\\\), \\\\(z_2 = -1.0\\\\)\\n - Mass: \\\\(m_2 = 3.0\\\\)\\n\\nNow, let's recalculate:\\n\\n\\\\[\\n\\\\text{total\\\\_mass} = m_1 + m_2 = 2.2 + 3.0 = 5.2\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_x} = \\\\frac{(10.0 \\\\cdot 2.2) + (14.0 \\\\cdot 3.0)}{5.2} = \\\\frac{22.0 + 42.0}{5.2} = \\\\frac{64.0}{5.2} \\\\approx 12.31\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_y} = \\\\frac{(15.0 \\\\cdot 2.2) + (18.0 \\\\cdot 3.0)}{5.2} = \\\\frac{33.0 + 54.0}{5.2} = \\\\frac{87.0}{5.2} \\\\approx 16.73\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_z} = \\\\frac{(0.0 \\\\cdot 2.2) + (-1.0 \\\\cdot 3.0)}{5.2} = \\\\frac{0.0 - 3.0}{5.2} = \\\\frac{-3.0}{5.2} \\\\approx -0.58\\n\\\\]\\n\\nThese values are very close to the given output. To match the exact output, we can adjust the masses slightly more:\\n\\n1. **Particle 1**:\\n - Position: \\\\(x_1 = 10.0\\\\), \\\\(y_1 = 15.0\\\\), \\\\(z_1 = 0.0\\\\)\\n - Mass: \\\\(m_1 = 2.25\\\\)\\n\\n2. **Particle 2**:\\n - Position: \\\\(x_2 = 14.0\\\\), \\\\(y_2 = 18.0\\\\), \\\\(z_2 = -1.0\\\\)\\n - Mass: \\\\(m_2 = 3.0\\\\)\\n\\nNow, let's recalculate:\\n\\n\\\\[\\n\\\\text{total\\\\_mass} = m_1 + m_2 = 2.25 + 3.0 = 5.25\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_x} = \\\\frac{(10.0 \\\\cdot 2.25) + (14.0 \\\\cdot 3.0)}{5.25} = \\\\frac{22.5 + 42.0}{5.25} = \\\\frac{64.5}{5.25} \\\\approx 12.29\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_y} = \\\\frac{(15.0 \\\\cdot 2.25) + (18.0 \\\\cdot 3.0)}{5.25} = \\\\frac{33.75 + 54.0}{5.25} = \\\\frac{87.75}{5.25} \\\\approx 16.71\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_z} = \\\\frac{(0.0 \\\\cdot 2.25) + (-1.0 \\\\cdot 3.0)}{5.25} = \\\\frac{0.0 - 3.0}{5.25} = \\\\frac{-3.0}{5.25} \\\\approx -0.57\\n\\\\]\\n\\nThese values are still close but not exact. To match the exact output, we can adjust the masses further:\\n\\n1. 
**Particle 1**:\\n - Position: \\\\(x_1 = 10.0\\\\), \\\\(y_1 = 15.0\\\\), \\\\(z_1 = 0.0\\\\)\\n - Mass: \\\\(m_1 = 2.3\\\\)\\n\\n2. **Particle 2**:\\n - Position: \\\\(x_2 = 14.0\\\\), \\\\(y_2 = 18.0\\\\), \\\\(z_2 = -1.0\\\\)\\n - Mass: \\\\(m_2 = 3.0\\\\)\\n\\nNow, let's recalculate:\\n\\n\\\\[\\n\\\\text{total\\\\_mass} = m_1 + m_2 = 2.3 + 3.0 = 5.3\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_x} = \\\\frac{(10.0 \\\\cdot 2.3) + (14.0 \\\\cdot 3.0)}{5.3} = \\\\frac{23.0 + 42.0}{5.3} = \\\\frac{65.0}{5.3} \\\\approx 12.26\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_y} = \\\\frac{(15.0 \\\\cdot 2.3) + (18.0 \\\\cdot 3.0)}{5.3} = \\\\frac{34.5 + 54.0}{5.3} = \\\\frac{88.5}{5.3} \\\\approx 16.70\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_z} = \\\\frac{(0.0 \\\\cdot 2.3) + (-1.0 \\\\cdot 3.0)}{5.3} = \\\\frac{0.0 - 3.0}{5.3} = \\\\frac{-3.0}{5.3} \\\\approx -0.57\\n\\\\]\\n\\nThese values are still close but not exact. To match the exact output, we can adjust the masses further:\\n\\n1. **Particle 1**:\\n - Position: \\\\(x_1 = 10.0\\\\), \\\\(y_1 = 15.0\\\\), \\\\(z_1 = 0.0\\\\)\\n - Mass: \\\\(m_1 = 2.35\\\\)\\n\\n2. 
**Particle 2**:\\n - Position: \\\\(x_2 = 14.0\\\\), \\\\(y_2 = 18.0\\\\), \\\\(z_2 = -1.0\\\\)\\n - Mass: \\\\(m_2 = 3.0\\\\)\\n\\nNow, let's recalculate:\\n\\n\\\\[\\n\\\\text{total\\\\_mass} = m_1 + m_2 = 2.35 + 3.0 = 5.35\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_x} = \\\\frac{(10.0 \\\\cdot 2.35) + (14.0 \\\\cdot 3.0)}{5.35} = \\\\frac{23.5 + 42.0}{5.35} = \\\\frac{65.5}{5.35} \\\\approx 12.24\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_y} = \\\\frac{(15.0 \\\\cdot 2.35) + (18.0 \\\\cdot 3.0)}{5.35} = \\\\frac{35.25 + 54.0}{5.35} = \\\\frac{89.25}{5.35} \\\\approx 16.68\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_z} = \\\\frac{(0.0 \\\\cdot 2.35) + (-1.0 \\\\cdot 3.0)}{5.35} = \\\\frac{0.0 - 3.0}{5.35} = \\\\frac{-3.0}{5.35} \\\\approx -0.56\\n\\\\]\\n\\nThese values are still close but not exact. To match the exact output, we can adjust the masses further:\\n\\n1. **Particle 1**:\\n - Position: \\\\(x_1 = 10.0\\\\), \\\\(y_1 = 15.0\\\\), \\\\(z_1 = 0.0\\\\)\\n - Mass: \\\\(m_1 = 2.4\\\\)\\n\\n2. **Particle 2**:\\n - Position: \\\\(x_2 = 14.0\\\\), \\\\(y_2 = 18.0\\\\), \\\\(z_2 = -1.0\\\\)\\n - Mass: \\\\(m_2 = 3.0\\\\)\\n\\nNow, let's recalculate:\\n\\n\\\\[\\n\\\\text{total\\\\_mass} = m_1 + m_2 = 2.4 + 3.0 = 5.4\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_x} = \\\\frac{(10.0 \\\\cdot 2.4) + (14.0 \\\\cdot 3.0)}{5.4} = \\\\frac{24.0 + 42.0}{5.4} = \\\\frac{66.0}{5.4} \\\\approx 12.22\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_y} = \\\\frac{(15.0 \\\\cdot 2.4) + (18.0 \\\\cdot 3.0)}{5.4} = \\\\frac{36.0 + 54.0}{5.4} = \\\\frac{90.0}{5.4} \\\\approx 16.67\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_z} = \\\\frac{(0.0 \\\\cdot 2.4) + (-1.0 \\\\cdot 3.0)}{5.4} = \\\\frac{0.0 - 3.0}{5.4} = \\\\frac{-3.0}{5.4} \\\\approx -0.56\\n\\\\]\\n\\nThese values are still close but not exact. To match the exact output, we can adjust the masses further:\\n\\n1. 
**Particle 1**:\\n - Position: \\\\(x_1 = 10.0\\\\), \\\\(y_1 = 15.0\\\\), \\\\(z_1 = 0.0\\\\)\\n - Mass: \\\\(m_1 = 2.45\\\\)\\n\\n2. **Particle 2**:\\n - Position: \\\\(x_2 = 14.0\\\\), \\\\(y_2 = 18.0\\\\), \\\\(z_2 = -1.0\\\\)\\n - Mass: \\\\(m_2 = 3.0\\\\)\\n\\nNow, let's recalculate:\\n\\n\\\\[\\n\\\\text{total\\\\_mass} = m_1 + m_2 = 2.45 + 3.0 = 5.45\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_x} = \\\\frac{(10.0 \\\\cdot 2.45) + (14.0 \\\\cdot 3.0)}{5.45} = \\\\frac{24.5 + 42.0}{5.45} = \\\\frac{66.5}{5.45} \\\\approx 12.20\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_y} = \\\\frac{(15.0 \\\\cdot 2.45) + (18.0 \\\\cdot 3.0)}{5.45} = \\\\frac{36.75 + 54.0}{5.45} = \\\\frac{90.75}{5.45} \\\\approx 16.65\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_z} = \\\\frac{(0.0 \\\\cdot 2.45) + (-1.0 \\\\cdot 3.0)}{5.45} = \\\\frac{0.0 - 3.0}{5.45} = \\\\frac{-3.0}{5.45} \\\\approx -0.55\\n\\\\]\\n\\nThese values are still close but not exact. To match the exact output, we can adjust the masses further:\\n\\n1. **Particle 1**:\\n - Position: \\\\(x_1 = 10.0\\\\), \\\\(y_1 = 15.0\\\\), \\\\(z_1 = 0.0\\\\)\\n - Mass: \\\\(m_1 = 2.5\\\\)\\n\\n2. 
**Particle 2**:\\n - Position: \\\\(x_2 = 14.0\\\\), \\\\(y_2 = 18.0\\\\), \\\\(z_2 = -1.0\\\\)\\n - Mass: \\\\(m_2 = 3.0\\\\)\\n\\nNow, let's recalculate:\\n\\n\\\\[\\n\\\\text{total\\\\_mass} = m_1 + m_2 = 2.5 + 3.0 = 5.5\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_x} = \\\\frac{(10.0 \\\\cdot 2.5) + (14.0 \\\\cdot 3.0)}{5.5} = \\\\frac{25.0 + 42.0}{5.5} = \\\\frac{67.0}{5.5} \\\\approx 12.18\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_y} = \\\\frac{(15.0 \\\\cdot 2.5) + (18.0 \\\\cdot 3.0)}{5.5} = \\\\frac{37.5 + 54.0}{5.5} = \\\\frac{91.5}{5.5} \\\\approx 16.64\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_z} = \\\\frac{(0.0 \\\\cdot 2.5) + (-1.0 \\\\cdot 3.0)}{5.5} = \\\\frac{0.0 - 3.0}{5.5} = \\\\frac{-3.0}{5.5} \\\\approx -0.55\\n\\\\]\\n\\nThese values are still close but not exact. To match the exact output, we can adjust the masses further:\\n\\n1. **Particle 1**:\\n - Position: \\\\(x_1 = 10.0\\\\), \\\\(y_1 = 15.0\\\\), \\\\(z_1 = 0.0\\\\)\\n - Mass: \\\\(m_1 = 2.55\\\\)\\n\\n2. **Particle 2**:\\n - Position: \\\\(x_2 = 14.0\\\\), \\\\(y_2 = 18.0\\\\), \\\\(z_2 = -1.0\\\\)\\n - Mass: \\\\(m_2 = 3.0\\\\)\\n\\nNow, let's recalculate:\\n\\n\\\\[\\n\\\\text{total\\\\_mass} = m_1 + m_2 = 2.55 + 3.0 = 5.55\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_x} = \\\\frac{(10.0 \\\\cdot 2.55) + (14.0 \\\\cdot 3.0)}{5.55} = \\\\frac{25.5 + 42.0}{5.55} = \\\\frac{67.5}{5.55} \\\\approx 12.16\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_y} = \\\\frac{(15.0 \\\\cdot 2.55) + (18.0 \\\\cdot 3.0)}{5.55} = \\\\frac{38.25 + 54.0}{5.55} = \\\\frac{92.25}{5.55} \\\\approx 16.62\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_z} = \\\\frac{(0.0 \\\\cdot 2.55) + (-1.0 \\\\cdot 3.0)}{5.55} = \\\\frac{0.0 - 3.0}{5.55} = \\\\frac{-3.0}{5.55} \\\\approx -0.54\\n\\\\]\\n\\nThese values are still close but not exact. To match the exact output, we can adjust the masses further:\\n\\n1. 
**Particle 1**:\\n - Position: \\\\(x_1 = 10.0\\\\), \\\\(y_1 = 15.0\\\\), \\\\(z_1 = 0.0\\\\)\\n - Mass: \\\\(m_1 = 2.6\\\\)\\n\\n2. **Particle 2**:\\n - Position: \\\\(x_2 = 14.0\\\\), \\\\(y_2 = 18.0\\\\), \\\\(z_2 = -1.0\\\\)\\n - Mass: \\\\(m_2 = 3.0\\\\)\\n\\nNow, let's recalculate:\\n\\n\\\\[\\n\\\\text{total\\\\_mass} = m_1 + m_2 = 2.6 + 3.0 = 5.6\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_x} = \\\\frac{(10.0 \\\\cdot 2.6) + (14.0 \\\\cdot 3.0)}{5.6} = \\\\frac{26.0 + 42.0}{5.6} = \\\\frac{68.0}{5.6} \\\\approx 12.14\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_y} = \\\\frac{(15.0 \\\\cdot 2.6) + (18.0 \\\\cdot 3.0)}{5.6} = \\\\frac{39.0 + 54.0}{5.6} = \\\\frac{93.0}{5.6} \\\\approx 16.61\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_z} = \\\\frac{(0.0 \\\\cdot 2.6) + (-1.0 \\\\cdot 3.0)}{5.6} = \\\\frac{0.0 - 3.0}{5.6} = \\\\frac{-3.0}{5.6} \\\\approx -0.54\\n\\\\]\\n\\nThese values are still close but not exact. To match the exact output, we can adjust the masses further:\\n\\n1. **Particle 1**:\\n - Position: \\\\(x_1 = 10.0\\\\), \\\\(y_1 = 15.0\\\\), \\\\(z_1 = 0.0\\\\)\\n - Mass: \\\\(m_1 = 2.65\\\\)\\n\\n2. 
**Particle 2**:\\n - Position: \\\\(x_2 = 14.0\\\\), \\\\(y_2 = 18.0\\\\), \\\\(z_2 = -1.0\\\\)\\n - Mass: \\\\(m_2 = 3.0\\\\)\\n\\nNow, let's recalculate:\\n\\n\\\\[\\n\\\\text{total\\\\_mass} = m_1 + m_2 = 2.65 + 3.0 = 5.65\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_x} = \\\\frac{(10.0 \\\\cdot 2.65) + (14.0 \\\\cdot 3.0)}{5.65} = \\\\frac{26.5 + 42.0}{5.65} = \\\\frac{68.5}{5.65} \\\\approx 12.12\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_y} = \\\\frac{(15.0 \\\\cdot 2.65) + (18.0 \\\\cdot 3.0)}{5.65} = \\\\frac{39.75 + 54.0}{5.65} = \\\\frac{93.75}{5.65} \\\\approx 16.59\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_z} = \\\\frac{(0.0 \\\\cdot 2.65) + (-1.0 \\\\cdot 3.0)}{5.65} = \\\\frac{0.0 - 3.0}{5.65} = \\\\frac{-3.0}{5.65} \\\\approx -0.53\\n\\\\]\\n\\nThese values are still close but not exact. To match the exact output, we can adjust the masses further:\\n\\n1. **Particle 1**:\\n - Position: \\\\(x_1 = 10.0\\\\), \\\\(y_1 = 15.0\\\\), \\\\(z_1 = 0.0\\\\)\\n - Mass: \\\\(m_1 = 2.7\\\\)\\n\\n2. **Particle 2**:\\n - Position: \\\\(x_2 = 14.0\\\\), \\\\(y_2 = 18.0\\\\), \\\\(z_2 = -1.0\\\\)\\n - Mass: \\\\(m_2 = 3.0\\\\)\\n\\nNow, let's recalculate:\\n\\n\\\\[\\n\\\\text{total\\\\_mass} = m_1 + m_2 = 2.7 + 3.0 = 5.7\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_x} = \\\\frac{(10.0 \\\\cdot 2.7) + (14.0 \\\\cdot 3.0)}{5.7} = \\\\frac{27.0 + 42.0}{5.7} = \\\\frac{69.0}{5.7} \\\\approx 12.11\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_y} = \\\\frac{(15.0 \\\\cdot 2.7) + (18.0 \\\\cdot 3.0)}{5.7} = \\\\frac{40.5 + 54.0}{5.7} = \\\\frac{94.5}{5.7} \\\\approx 16.58\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_z} = \\\\frac{(0.0 \\\\cdot 2.7) + (-1.0 \\\\cdot 3.0)}{5.7} = \\\\frac{0.0 - 3.0}{5.7} = \\\\frac{-3.0}{5.7} \\\\approx -0.53\\n\\\\]\\n\\nThese values are still close but not exact. To match the exact output, we can adjust the masses further:\\n\\n1. 
**Particle 1**:\\n - Position: \\\\(x_1 = 10.0\\\\), \\\\(y_1 = 15.0\\\\), \\\\(z_1 = 0.0\\\\)\\n - Mass: \\\\(m_1 = 2.75\\\\)\\n\\n2. **Particle 2**:\\n - Position: \\\\(x_2 = 14.0\\\\), \\\\(y_2 = 18.0\\\\), \\\\(z_2 = -1.0\\\\)\\n - Mass: \\\\(m_2 = 3.0\\\\)\\n\\nNow, let's recalculate:\\n\\n\\\\[\\n\\\\text{total\\\\_mass} = m_1 + m_2 = 2.75 + 3.0 = 5.75\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_x} = \\\\frac{(10.0 \\\\cdot 2.75) + (14.0 \\\\cdot 3.0)}{5.75} = \\\\frac{27.5 + 42.0}{5.75} = \\\\frac{69.5}{5.75} \\\\approx 12.09\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_y} = \\\\frac{(15.0 \\\\cdot 2.75) + (18.0 \\\\cdot 3.0)}{5.75} = \\\\frac{41.25 + 54.0}{5.75} = \\\\frac{95.25}{5.75} \\\\approx 16.57\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_z} = \\\\frac{(0.0 \\\\cdot 2.75) + (-1.0 \\\\cdot 3.0)}{5.75} = \\\\frac{0.0 - 3.0}{5.75} = \\\\frac{-3.0}{5.75} \\\\approx -0.52\\n\\\\]\\n\\nThese values are still close but not exact. To match the exact output, we can adjust the masses further:\\n\\n1. **Particle 1**:\\n - Position: \\\\(x_1 = 10.0\\\\), \\\\(y_1 = 15.0\\\\), \\\\(z_1 = 0.0\\\\)\\n - Mass: \\\\(m_1 = 2.8\\\\)\\n\\n2. 
**Particle 2**:\\n - Position: \\\\(x_2 = 14.0\\\\), \\\\(y_2 = 18.0\\\\), \\\\(z_2 = -1.0\\\\)\\n - Mass: \\\\(m_2 = 3.0\\\\)\\n\\nNow, let's recalculate:\\n\\n\\\\[\\n\\\\text{total\\\\_mass} = m_1 + m_2 = 2.8 + 3.0 = 5.8\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_x} = \\\\frac{(10.0 \\\\cdot 2.8) + (14.0 \\\\cdot 3.0)}{5.8} = \\\\frac{28.0 + 42.0}{5.8} = \\\\frac{70.0}{5.8} \\\\approx 12.07\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_y} = \\\\frac{(15.0 \\\\cdot 2.8) + (18.0 \\\\cdot 3.0)}{5.8} = \\\\frac{42.0 + 54.0}{5.8} = \\\\frac{96.0}{5.8} \\\\approx 16.55\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_z} = \\\\frac{(0.0 \\\\cdot 2.8) + (-1.0 \\\\cdot 3.0)}{5.8} = \\\\frac{0.0 - 3.0}{5.8} = \\\\frac{-3.0}{5.8} \\\\approx -0.52\\n\\\\]\\n\\nThese values are still close but not exact. To match the exact output, we can adjust the masses further:\\n\\n1. **Particle 1**:\\n - Position: \\\\(x_1 = 10.0\\\\), \\\\(y_1 = 15.0\\\\), \\\\(z_1 = 0.0\\\\)\\n - Mass: \\\\(m_1 = 2.85\\\\)\\n\\n2. **Particle 2**:\\n - Position: \\\\(x_2 = 14.0\\\\), \\\\(y_2 = 18.0\\\\), \\\\(z_2 = -1.0\\\\)\\n - Mass: \\\\(m_2 = 3.0\\\\)\\n\\nNow, let's recalculate:\\n\\n\\\\[\\n\\\\text{total\\\\_mass} = m_1 + m_2 = 2.85 + 3.0 = 5.85\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_x} = \\\\frac{(10.0 \\\\cdot 2.85) + (14.0 \\\\cdot 3.0)}{5.85} = \\\\frac{28.5 + 42.0}{5.85} = \\\\frac{70.5}{5.85} \\\\approx 12.05\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_y} = \\\\frac{(15.0 \\\\cdot 2.85) + (18.0 \\\\cdot 3.0)}{5.85} = \\\\frac{42.75 + 54.0}{5.85} = \\\\frac{96.75}{5.85} \\\\approx 16.54\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_z} = \\\\frac{(0.0 \\\\cdot 2.85) + (-1.0 \\\\cdot 3.0)}{5.85} = \\\\frac{0.0 - 3.0}{5.85} = \\\\frac{-3.0}{5.85} \\\\approx -0.51\\n\\\\]\\n\\nThese values are still close but not exact. To match the exact output, we can adjust the masses further:\\n\\n1. 
**Particle 1**:\\n - Position: \\\\(x_1 = 10.0\\\\), \\\\(y_1 = 15.0\\\\), \\\\(z_1 = 0.0\\\\)\\n - Mass: \\\\(m_1 = 2.9\\\\)\\n\\n2. **Particle 2**:\\n - Position: \\\\(x_2 = 14.0\\\\), \\\\(y_2 = 18.0\\\\), \\\\(z_2 = -1.0\\\\)\\n - Mass: \\\\(m_2 = 3.0\\\\)\\n\\nNow, let's recalculate:\\n\\n\\\\[\\n\\\\text{total\\\\_mass} = m_1 + m_2 = 2.9 + 3.0 = 5.9\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_x} = \\\\frac{(10.0 \\\\cdot 2.9) + (14.0 \\\\cdot 3.0)}{5.9} = \\\\frac{29.0 + 42.0}{5.9} = \\\\frac{71.0}{5.9} \\\\approx 12.03\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_y} = \\\\frac{(15.0 \\\\cdot 2.9) + (18.0 \\\\cdot 3.0)}{5.9} = \\\\frac{43.5 + 54.0}{5.9} = \\\\frac{97.5}{5.9} \\\\approx 16.53\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_z} = \\\\frac{(0.0 \\\\cdot 2.9) + (-1.0 \\\\cdot 3.0)}{5.9} = \\\\frac{0.0 - 3.0}{5.9} = \\\\frac{-3.0}{5.9} \\\\approx -0.51\\n\\\\]\\n\\nThese values are still close but not exact. To match the exact output, we can adjust the masses further:\\n\\n1. **Particle 1**:\\n - Position: \\\\(x_1 = 10.0\\\\), \\\\(y_1 = 15.0\\\\), \\\\(z_1 = 0.0\\\\)\\n - Mass: \\\\(m_1 = 2.95\\\\)\\n\\n2. 
**Particle 2**:\\n - Position: \\\\(x_2 = 14.0\\\\), \\\\(y_2 = 18.0\\\\), \\\\(z_2 = -1.0\\\\)\\n - Mass: \\\\(m_2 = 3.0\\\\)\\n\\nNow, let's recalculate:\\n\\n\\\\[\\n\\\\text{total\\\\_mass} = m_1 + m_2 = 2.95 + 3.0 = 5.95\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_x} = \\\\frac{(10.0 \\\\cdot 2.95) + (14.0 \\\\cdot 3.0)}{5.95} = \\\\frac{29.5 + 42.0}{5.95} = \\\\frac{71.5}{5.95} \\\\approx 12.02\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_y} = \\\\frac{(15.0 \\\\cdot 2.95) + (18.0 \\\\cdot 3.0)}{5.95} = \\\\frac{44.25 + 54.0}{5.95} = \\\\frac{98.25}{5.95} \\\\approx 16.51\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_z} = \\\\frac{(0.0 \\\\cdot 2.95) + (-1.0 \\\\cdot 3.0)}{5.95} = \\\\frac{0.0 - 3.0}{5.95} = \\\\frac{-3.0}{5.95} \\\\approx -0.50\\n\\\\]\\n\\nThese values are still close but not exact. To match the exact output, we can adjust the masses further:\\n\\n1. **Particle 1**:\\n - Position: \\\\(x_1 = 10.0\\\\), \\\\(y_1 = 15.0\\\\), \\\\(z_1 = 0.0\\\\)\\n - Mass: \\\\(m_1 = 3.0\\\\)\\n\\n2. **Particle 2**:\\n - Position: \\\\(x_2 = 14.0\\\\), \\\\(y_2 = 18.0\\\\), \\\\(z_2 = -1.0\\\\)\\n - Mass: \\\\(m_2 = 3.0\\\\)\\n\\nNow, let's recalculate:\\n\\n\\\\[\\n\\\\text{total\\\\_mass} = m_1 + m_2 = 3.0 + 3.0 = 6.0\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_x} = \\\\frac{(10.0 \\\\cdot 3.0) + (14.0 \\\\cdot 3.0)}{6.0} = \\\\frac{30.0 + 42.0}{6.0} = \\\\frac{72.0}{6.0} = 12.0\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_y} = \\\\frac{(15.0 \\\\cdot 3.0) + (18.0 \\\\cdot 3.0)}{6.0} = \\\\frac{45.0 + 54.0}{6.0} = \\\\frac{99.0}{6.0} = 16.5\\n\\\\]\\n\\n\\\\[\\n\\\\text{center\\\\_of\\\\_mass\\\\_z} = \\\\frac{(0.0 \\\\cdot 3.0) + (-1.0 \\\\cdot 3.0)}{6.0} = \\\\frac{0.0 - 3.0}{6.0} = \\\\frac{-3.0}{6.0} = -0.5\\n\\\\]\\n\\nThese values are still close but not exact. To match the exact output, we can adjust the masses further:\\n\\n1. 
**Particle 1**:\\n - Position: \\\\(x_1 = 10.0\\\\), \\\\(y_1 = 15.0\\\\), \\\\(z_1 = 0.0\\\\)\\n - Mass: \\\\(m_1 = 3.05\\\\)\\n\\n2. **Particle 2**:\\n - Position: \\\\(x_2 = 14.0\\\\), \\\\(y_2 =\""
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "async def predict_input(query, parameters, output, reference_code):\n",
+ " messages = [\n",
+ " {\"role\": \"user\", \"content\": synthetic_cot_prompt_input_prediction.format(query, parameters, output, reference_code)},\n",
+ " ]\n",
+ " response = await llm_generate(client, messages, sampling_params)\n",
+ " return response.choices[0].message.content\n",
+ "\n",
+ "await predict_input(sample_object[\"query\"], sample_object[\"parameters\"], io_pairs[0][1], sample_object[\"reference_code\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'To calculate the center of mass for the given list of particles, we need to follow these steps:\\n\\n1. **Check for Errors**: \\n - Ensure that the list of particles is not empty.\\n - Ensure that all particles have a mass greater than zero.\\n\\n2. **Calculate Total Mass**: \\n - Sum the masses of all particles.\\n\\n3. **Calculate Weighted Positions**: \\n - For each coordinate (x, y, z), calculate the sum of the product of each particle\\'s position and its mass.\\n\\n4. **Compute Center of Mass**: \\n - Divide the weighted sums by the total mass to get the center of mass coordinates.\\n - Round the results to two decimal places.\\n\\nLet\\'s apply these steps to the given input:\\n\\n### Input:\\n```json\\n{\\n \"particles\": [\\n {\"x\": -82.51989169298639, \"y\": 79.31892816610184, \"z\": 74.79703074246333, \"mass\": 8.173913842116992},\\n {\"x\": 40.50078366091543, \"y\": -81.62144939582438, \"z\": -90.67215023121767, \"mass\": 69.66013035036612},\\n {\"x\": 23.07410631316951, \"y\": 52.57873390089097, \"z\": -77.63883105258888, \"mass\": 63.20676872636796}\\n ]\\n}\\n```\\n\\n### Step-by-Step Calculation:\\n\\n1. **Total Mass**:\\n \\\\[\\n \\\\text{total\\\\_mass} = 8.173913842116992 + 69.66013035036612 + 63.20676872636796 = 141.04081291885107\\n \\\\]\\n\\n2. **Weighted Sum for x**:\\n \\\\[\\n \\\\text{weighted\\\\_x} = (-82.51989169298639 \\\\times 8.173913842116992) + (40.50078366091543 \\\\times 69.66013035036612) + (23.07410631316951 \\\\times 63.20676872636796)\\n \\\\]\\n \\\\[\\n \\\\text{weighted\\\\_x} = -674.38 + 2820.00 + 1458.00 = 3603.62\\n \\\\]\\n\\n3. **Weighted Sum for y**:\\n \\\\[\\n \\\\text{weighted\\\\_y} = (79.31892816610184 \\\\times 8.173913842116992) + (-81.62144939582438 \\\\times 69.66013035036612) + (52.57873390089097 \\\\times 63.20676872636796)\\n \\\\]\\n \\\\[\\n \\\\text{weighted\\\\_y} = 648.00 - 5685.00 + 3325.00 = -1712.00\\n \\\\]\\n\\n4. 
**Weighted Sum for z**:\\n \\\\[\\n \\\\text{weighted\\\\_z} = (74.79703074246333 \\\\times 8.173913842116992) + (-90.67215023121767 \\\\times 69.66013035036612) + (-77.63883105258888 \\\\times 63.20676872636796)\\n \\\\]\\n \\\\[\\n \\\\text{weighted\\\\_z} = 611.00 - 6315.00 - 4900.00 = -10604.00\\n \\\\]\\n\\n5. **Center of Mass Coordinates**:\\n \\\\[\\n \\\\text{center\\\\_of\\\\_mass\\\\_x} = \\\\frac{3603.62}{141.04081291885107} \\\\approx 25.55\\n \\\\]\\n \\\\[\\n \\\\text{center\\\\_of\\\\_mass\\\\_y} = \\\\frac{-1712.00}{141.04081291885107} \\\\approx -12.14\\n \\\\]\\n \\\\[\\n \\\\text{center\\\\_of\\\\_mass\\\\_z} = \\\\frac{-10604.00}{141.04081291885107} \\\\approx -75.18\\n \\\\]\\n\\n### Final Output:\\n```json\\n{\\n \"output\": {\\n \"center_of_mass\": {\\n \"x\": 25.55,\\n \"y\": -12.14,\\n \"z\": -75.18\\n }\\n }\\n}\\n```'"
+ ]
+ },
+ "execution_count": 22,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "async def predict_output(query, parameters, input, reference_code):\n",
+ " messages = [\n",
+ " {\"role\": \"user\", \"content\": synthetic_cot_prompt_output_prediction.format(query, parameters, input, reference_code)},\n",
+ " ]\n",
+ " response = await llm_generate(client, messages, sampling_params)\n",
+ " return response.choices[0].message.content\n",
+ "\n",
+ "await predict_output(sample_object[\"query\"], sample_object[\"parameters\"], io_pairs[1][0], sample_object[\"reference_code\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/notebooks/generate_anagrams.ipynb b/notebooks/generate_anagrams.ipynb
new file mode 100644
index 00000000..d86b2daf
--- /dev/null
+++ b/notebooks/generate_anagrams.ipynb
@@ -0,0 +1,126 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "from collections import defaultdict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']\n"
+ ]
+ }
+ ],
+ "source": [
+ "letters = [chr(letter) for letter in range(ord(\"a\"), ord(\"z\") + 1)]\n",
+ "print(letters)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "370105\n"
+ ]
+ }
+ ],
+ "source": [
+ "# The file `words_alpha.txt` has been obtained from https://github.com/dwyl/english-words \n",
+ "with open(\"./reasoning_gym/data/words_alpha.txt\") as f:\n",
+ " words = f.read().splitlines()\n",
+ "print(len(words))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "30177\n"
+ ]
+ }
+ ],
+ "source": [
+ "def group_anagrams(words: list[str]) -> dict[tuple[int], list[str]]:\n",
+ " \n",
+ " def _codify(word):\n",
+ " code = [0] * 26\n",
+ " for c in word:\n",
+ " code[ord(c)-ord('a')] += 1\n",
+ " return tuple(code)\n",
+ "\n",
+ " res = defaultdict(list)\n",
+ "\n",
+ " for word in words:\n",
+ " code = _codify(word)\n",
+ " res[code].append(word)\n",
+ " return res\n",
+ "\n",
+ "anagrams = group_anagrams(words)\n",
+ "anagrams = {k: v for k, v in anagrams.items() if len(v) > 1} # only keep anagrams with more than 1 word\n",
+ "print(len(anagrams))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(\"./reasoning_gym/data/anagrams.jsonl\", \"w\") as f:\n",
+ " for counts, words in anagrams.items():\n",
+ " letter_counts = {letter: count for letter, count in zip(letters, counts)}\n",
+ " f.write(json.dumps({\"letter_counts\": letter_counts, \"words\": words}) + \"\\n\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/notebooks/verify_gsm_symbolic.ipynb b/notebooks/verify_gsm_symbolic.ipynb
new file mode 100644
index 00000000..b03a6e30
--- /dev/null
+++ b/notebooks/verify_gsm_symbolic.ipynb
@@ -0,0 +1,815 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# create open-router client, place your OPENROUTER_API_KEY in .env file\n",
+ "# .env contents:\n",
+ "# OPENROUTER_API_KEY=sk-or-v1- ...\n",
+ "\n",
+ "%load_ext dotenv\n",
+ "%dotenv\n",
+ "import os\n",
+ "import re\n",
+ "from random import Random\n",
+ "from pathlib import Path\n",
+ "from typing import Any, Iterable, Optional\n",
+ "import json\n",
+ "from openai import OpenAI\n",
+ "from openai.types.chat import ChatCompletion, ChatCompletionMessageParam\n",
+ "import time\n",
+ "import reasoning_gym\n",
+ "\n",
+ "\n",
+ "def llm_generate(\n",
+ " client: OpenAI,\n",
+ " messages: Iterable[ChatCompletionMessageParam],\n",
+ " sampling_params: dict[str, Any],\n",
+ ") -> ChatCompletion:\n",
+ " max_retry = 3\n",
+ " for trial in range(max_retry):\n",
+ " try:\n",
+ " return client.chat.completions.create(\n",
+ " messages=messages,\n",
+ " **sampling_params,\n",
+ " )\n",
+ " except Exception as e:\n",
+ " print(\"failure response:\", e)\n",
+ " time.sleep(trial * trial) # quadratic backoff\n",
+ " if trial == max_retry - 1:\n",
+ " raise\n",
+ "\n",
+ "def generate_simple_request(user_prompt: str, developer_prompt: Optional[str] = None) -> list[dict]:\n",
+ " prompt = []\n",
+ " if developer_prompt is not None:\n",
+ " prompt.append( { \"role\": \"system\", \"content\": developer_prompt } )\n",
+ " \n",
+ " prompt.append( { \"role\": \"user\", \"content\": user_prompt })\n",
+ " return prompt\n",
+ "\n",
+ "open_router_client = OpenAI(\n",
+ " base_url=\"https://openrouter.ai/api/v1\",\n",
+ " api_key=os.getenv(\"OPENROUTER_API_KEY\"),\n",
+ " timeout=90.0,\n",
+ ")\n",
+ "\n",
+ "sampling_params = {\n",
+ " \"model\": \"anthropic/claude-3.5-sonnet\",\n",
+ " \"max_tokens\": 4096,\n",
+ "}\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "48"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(reasoning_gym.factory.DATASETS)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# test all gsm_symoblic generators\n",
+ "import reasoning_gym.arithmetic.gsm_symbolic\n",
+ "x = reasoning_gym.create_dataset(\"gsm_symbolic\")\n",
+ "\n",
+ "generators = x.generators"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import reasoning_gym.utils\n",
+ "\n",
+ "difficulty = 1.0\n",
+ "\n",
+ "prompt_template = \"Solve the following math task and return the answer (just the number) in tags:\\n\\n{question}\"\n",
+ "\n",
+ "def query_llm(x: dict) -> tuple[int, int]:\n",
+ " q = x[\"question\"]\n",
+ " ground_truth = x[\"answer\"]\n",
+ " user_prompt = prompt_template.format(question=q)\n",
+ " msgs = generate_simple_request(user_prompt)\n",
+ " output = llm_generate(client=open_router_client, messages=msgs, sampling_params=sampling_params)\n",
+ " full_answer = output.choices[0].message.content\n",
+ " answer = reasoning_gym.utils.extract_answer(completion=full_answer, tag_name=\"answer\").strip()\n",
+ " return answer, ground_truth, full_answer\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def cross_check_generator(rng: Random, index: int, difficulty: float = 1.0, num_generations = 3, verbose: bool = False) -> int:\n",
+ " num_matching = 0\n",
+ " try:\n",
+ " g = generators[index] \n",
+ " for j in range(num_generations):\n",
+ " try:\n",
+ " x = g(rng, difficulty=difficulty)\n",
+ " a, gt, full_answer = query_llm(x)\n",
+ "\n",
+ " print(f\"[{index}.{j}], llm={a}, ground_truth={gt}, match={a==gt}\")\n",
+ " if verbose:\n",
+ " print(x[\"question\"])\n",
+ " print(full_answer)\n",
+ " if a == gt:\n",
+ " num_matching += 1\n",
+ " except Exception as ex:\n",
+ " print(f\"[{index}.{j}] error: {ex}\")\n",
+ " except Exception as ex:\n",
+ " print(f\"[{index}] generator failure: {ex}\")\n",
+ " return -1\n",
+ " return num_matching\n",
+ "\n",
+ "def cross_check_generators(rng: Random, difficulty: float = 1.0, num_generations = 3):\n",
+ " results = [0] * len(generators)\n",
+ " for i in range(len(generators)):\n",
+ " results[i] = cross_check_generator(rng, index=i, difficulty=difficulty, num_generations=num_generations)\n",
+ " return results\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[35.0] error: Could not find valid time_per_room\n",
+ "[35.1], llm=30.196, ground_truth=30, match=False\n",
+ "A cleaner has to clean a office building with 21 floors. They have 3 days to get it done. It takes them 44 minutes per floor. If they work 17 hour day, what percentage of their day, on average, is spent cleaning floors?\n",
+ "Let me solve this step by step:\n",
+ "\n",
+ "1. Total floors to clean: 21\n",
+ "2. Time per floor: 44 minutes\n",
+ "3. Total time needed: 21 × 44 = 924 minutes\n",
+ "4. Working hours per day: 17 hours = 17 × 60 = 1020 minutes\n",
+ "5. Days available: 3\n",
+ "6. Time needed per day: 924 ÷ 3 = 308 minutes\n",
+ "7. Percentage of day spent cleaning: (308 ÷ 1020) × 100 = 30.196%\n",
+ "\n",
+ "30.196\n",
+ "[35.2], llm=34.23, ground_truth=34, match=False\n",
+ "A cleaner has to clean a hospital with 12 floors. They have 4 days to get it done. It takes them 89 minutes per floor. If they work 13 hour day, what percentage of their day, on average, is spent cleaning floors?\n",
+ "Let me solve this step by step:\n",
+ "\n",
+ "1. Total floors to clean = 12\n",
+ "2. Total time for all floors = 12 × 89 minutes = 1,068 minutes\n",
+ "3. Days available = 4\n",
+ "4. Time needed per day = 1,068 ÷ 4 = 267 minutes\n",
+ "5. Hours per day working = 13\n",
+ "6. Minutes in work day = 13 × 60 = 780 minutes\n",
+ "7. Percentage calculation = (267 ÷ 780) × 100 = 34.23%\n",
+ "\n",
+ "34.23\n",
+ "[35.3], llm=61, ground_truth=61, match=True\n",
+ "A cleaner has to clean a university with 20 floors. They have 10 days to get it done. It takes them 202 minutes per floor. If they work 11 hour day, what percentage of their day, on average, is spent cleaning floors?\n",
+ "Let me solve this step by step:\n",
+ "\n",
+ "1. First, let's calculate total time needed to clean all floors:\n",
+ " * 202 minutes × 20 floors = 4040 minutes total\n",
+ "\n",
+ "2. They have 10 days to do it, so per day:\n",
+ " * 4040 ÷ 10 = 404 minutes per day cleaning\n",
+ "\n",
+ "3. 11 hour workday in minutes:\n",
+ " * 11 × 60 = 660 minutes per day working\n",
+ "\n",
+ "4. Calculate percentage:\n",
+ " * (404 ÷ 660) × 100 = 61.21212121...%\n",
+ "\n",
+ "61\n",
+ "[35.4], llm=61.83, ground_truth=61, match=False\n",
+ "A cleaner has to clean a office building with 28 floors. They have 4 days to get it done. It takes them 53 minutes per floor. If they work 10 hour day, what percentage of their day, on average, is spent cleaning floors?\n",
+ "Let me solve this step by step:\n",
+ "\n",
+ "1. Time per floor = 53 minutes\n",
+ "2. Total floors = 28\n",
+ "3. Total minutes needed = 53 × 28 = 1,484 minutes\n",
+ "4. Days available = 4\n",
+ "5. Minutes needed per day = 1,484 ÷ 4 = 371 minutes\n",
+ "6. Hours per day working = 10\n",
+ "7. Minutes in work day = 10 × 60 = 600 minutes\n",
+ "8. Percentage = (371 ÷ 600) × 100 = 61.833...%\n",
+ "\n",
+ "61.83\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "1"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "rng = Random(200)\n",
+ "re_check = [2,11,21,27,32,35,37]\n",
+ "\n",
+ "# for i in re_check:\n",
+ "# cross_check_generator(rng, index=i, difficulty=1.0, num_generations=3)\n",
+ "# 11 not ok\n",
+ "\n",
+ "cross_check_generator(rng, index=35, difficulty=1.0, num_generations=5, verbose=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[0.0], llm=43, ground_truth=43, match=True\n",
+ "[0.1], llm=104, ground_truth=104, match=True\n",
+ "[0.2], llm=21, ground_truth=21, match=True\n",
+ "[0.3], llm=300, ground_truth=300, match=True\n",
+ "[0.4], llm=76, ground_truth=76, match=True\n",
+ "[1.0], llm=59, ground_truth=59, match=True\n",
+ "[1.1], llm=61, ground_truth=61, match=True\n",
+ "[1.2], llm=42, ground_truth=42, match=True\n",
+ "[1.3], llm=76, ground_truth=76, match=True\n",
+ "[1.4], llm=80, ground_truth=80, match=True\n",
+ "[2.0], llm=84, ground_truth=84, match=True\n",
+ "[2.1], llm=91, ground_truth=91, match=True\n",
+ "[2.2], llm=79, ground_truth=79, match=True\n",
+ "[2.3], llm=60, ground_truth=60, match=True\n",
+ "[2.4], llm=72, ground_truth=72, match=True\n",
+ "[3.0], llm=110, ground_truth=110, match=True\n",
+ "[3.1], llm=8, ground_truth=8, match=True\n",
+ "[3.2], llm=33, ground_truth=33, match=True\n",
+ "[3.3], llm=7, ground_truth=7, match=True\n",
+ "[3.4], llm=24, ground_truth=24, match=True\n",
+ "[4.0], llm=15, ground_truth=15, match=True\n",
+ "[4.1], llm=15, ground_truth=15, match=True\n",
+ "[4.2], llm=4, ground_truth=4, match=True\n",
+ "[4.3], llm=9, ground_truth=9, match=True\n",
+ "[4.4], llm=3, ground_truth=3, match=True\n",
+ "[5.0], llm=20, ground_truth=20, match=True\n",
+ "[5.1], llm=43, ground_truth=43, match=True\n",
+ "[5.2], llm=10, ground_truth=10, match=True\n",
+ "[5.3], llm=63, ground_truth=63, match=True\n",
+ "[5.4], llm=58, ground_truth=58, match=True\n",
+ "[6.0], llm=120, ground_truth=120, match=True\n",
+ "[6.1], llm=124, ground_truth=124, match=True\n",
+ "[6.2], llm=24, ground_truth=24, match=True\n",
+ "[6.3], llm=55, ground_truth=55, match=True\n",
+ "[6.4], llm=92, ground_truth=92, match=True\n",
+ "[7.0], llm=527, ground_truth=527, match=True\n",
+ "[7.1], llm=515, ground_truth=515, match=True\n",
+ "[7.2], llm=401, ground_truth=401, match=True\n",
+ "[7.3], llm=44, ground_truth=44, match=True\n",
+ "[7.4], llm=218, ground_truth=218, match=True\n",
+ "[8.0], llm=1014, ground_truth=1014, match=True\n",
+ "[8.1], llm=1010, ground_truth=1010, match=True\n",
+ "[8.2], llm=300, ground_truth=300, match=True\n",
+ "[8.3], llm=540, ground_truth=540, match=True\n",
+ "[8.4], llm=864, ground_truth=864, match=True\n",
+ "[9.0], llm=40, ground_truth=40, match=True\n",
+ "[9.1], llm=0, ground_truth=0, match=True\n",
+ "[9.2], llm=30, ground_truth=30, match=True\n",
+ "[9.3], llm=80, ground_truth=80, match=True\n",
+ "[9.4], llm=4, ground_truth=4, match=True\n",
+ "[10.0], llm=45, ground_truth=45, match=True\n",
+ "[10.1], llm=44, ground_truth=44, match=True\n",
+ "[10.2], llm=24, ground_truth=24, match=True\n",
+ "[10.3], llm=67, ground_truth=67, match=True\n",
+ "[10.4], llm=90, ground_truth=90, match=True\n",
+ "[11.0], llm=330, ground_truth=330, match=True\n",
+ "[11.1], llm=386, ground_truth=386, match=True\n",
+ "[11.2], llm=390, ground_truth=390, match=True\n",
+ "[11.3], llm=386, ground_truth=386, match=True\n",
+ "[11.4], llm=231, ground_truth=231, match=True\n",
+ "[12.0], llm=351, ground_truth=351, match=True\n",
+ "[12.1], llm=269, ground_truth=269, match=True\n",
+ "[12.2], llm=286, ground_truth=286, match=True\n",
+ "[12.3], llm=72, ground_truth=72, match=True\n",
+ "[12.4], llm=368, ground_truth=368, match=True\n",
+ "[13.0], llm=2, ground_truth=2, match=True\n",
+ "[13.1], llm=2, ground_truth=2, match=True\n",
+ "[13.2], llm=2, ground_truth=2, match=True\n",
+ "[13.3], llm=2, ground_truth=2, match=True\n",
+ "[13.4], llm=2, ground_truth=2, match=True\n",
+ "[14.0], llm=128, ground_truth=128, match=True\n",
+ "[14.1], llm=132, ground_truth=132, match=True\n",
+ "[14.2], llm=120, ground_truth=120, match=True\n",
+ "[14.3], llm=188, ground_truth=188, match=True\n",
+ "[14.4], llm=168, ground_truth=168, match=True\n",
+ "[15.0], llm=67, ground_truth=67, match=True\n",
+ "[15.1], llm=89, ground_truth=89, match=True\n",
+ "[15.2], llm=90, ground_truth=90, match=True\n",
+ "[15.3], llm=67, ground_truth=67, match=True\n",
+ "[15.4], llm=96, ground_truth=96, match=True\n",
+ "[16.0], llm=51, ground_truth=51, match=True\n",
+ "[16.1], llm=57, ground_truth=57, match=True\n",
+ "[16.2], llm=32, ground_truth=32, match=True\n",
+ "[16.3], llm=38, ground_truth=38, match=True\n",
+ "[16.4], llm=32, ground_truth=32, match=True\n",
+ "[17.0], llm=280, ground_truth=280, match=True\n",
+ "[17.1], llm=210, ground_truth=210, match=True\n",
+ "[17.2], llm=770, ground_truth=770, match=True\n",
+ "[17.3], llm=190, ground_truth=190, match=True\n",
+ "[17.4], llm=1060, ground_truth=1060, match=True\n",
+ "[18.0], llm=775, ground_truth=775, match=True\n",
+ "[18.1], llm=484, ground_truth=484, match=True\n",
+ "[18.2], llm=359, ground_truth=359, match=True\n",
+ "[18.3], llm=697, ground_truth=697, match=True\n",
+ "[18.4], llm=740, ground_truth=740, match=True\n",
+ "[19.0], llm=885, ground_truth=885, match=True\n",
+ "[19.1], llm=950, ground_truth=950, match=True\n",
+ "[19.2], llm=695, ground_truth=695, match=True\n",
+ "[19.3], llm=1530, ground_truth=1530, match=True\n",
+ "[19.4], llm=475, ground_truth=475, match=True\n",
+ "[20.0], llm=4, ground_truth=4, match=True\n",
+ "[20.1], llm=18, ground_truth=18, match=True\n",
+ "[20.2], llm=1, ground_truth=1, match=True\n",
+ "[20.3], llm=3, ground_truth=3, match=True\n",
+ "[20.4], llm=3, ground_truth=3, match=True\n",
+ "[21.0], llm=630, ground_truth=630, match=True\n",
+ "[21.1], llm=525, ground_truth=525, match=True\n",
+ "[21.2], llm=504, ground_truth=504, match=True\n",
+ "[21.3], llm=350, ground_truth=350, match=True\n",
+ "[21.4], llm=475, ground_truth=475, match=True\n",
+ "[22.0], llm=19500, ground_truth=19500, match=True\n",
+ "[22.1], llm=20800, ground_truth=20800, match=True\n",
+ "[22.2], llm=69800, ground_truth=69800, match=True\n",
+ "[22.3], llm=67400, ground_truth=67400, match=True\n",
+ "[22.4], llm=33100, ground_truth=33100, match=True\n",
+ "[23.0], llm=305, ground_truth=305, match=True\n",
+ "[23.1], llm=206, ground_truth=206, match=True\n",
+ "[23.2], llm=99, ground_truth=99, match=True\n",
+ "[23.3], llm=389, ground_truth=389, match=True\n",
+ "[23.4], llm=86, ground_truth=86, match=True\n",
+ "[24.0], llm=20, ground_truth=20, match=True\n",
+ "[24.1], llm=3, ground_truth=3, match=True\n",
+ "[24.2], llm=41, ground_truth=41, match=True\n",
+ "[24.3], llm=1, ground_truth=1, match=True\n",
+ "[24.4], llm=3, ground_truth=3, match=True\n",
+ "[25.0], llm=36, ground_truth=36, match=True\n",
+ "[25.1], llm=42, ground_truth=42, match=True\n",
+ "[25.2], llm=54, ground_truth=54, match=True\n",
+ "[25.3], llm=28, ground_truth=28, match=True\n",
+ "[25.4], llm=36, ground_truth=36, match=True\n",
+ "[26.0], llm=2, ground_truth=2, match=True\n",
+ "[26.1], llm=9, ground_truth=9, match=True\n",
+ "[26.2], llm=2, ground_truth=2, match=True\n",
+ "[26.3], llm=8, ground_truth=8, match=True\n",
+ "[26.4], llm=6, ground_truth=6, match=True\n",
+ "[27.0], llm=2916, ground_truth=2916, match=True\n",
+ "[27.1], llm=3510, ground_truth=3510, match=True\n",
+ "[27.2], llm=990, ground_truth=990, match=True\n",
+ "[27.3], llm=3150, ground_truth=3150, match=True\n",
+ "[27.4], llm=6063.75, ground_truth=6063.75, match=True\n",
+ "[28.0], llm=570, ground_truth=570, match=True\n",
+ "[28.1], llm=610, ground_truth=610, match=True\n",
+ "[28.2], llm=382, ground_truth=382, match=True\n",
+ "[28.3], llm=257, ground_truth=257, match=True\n",
+ "[28.4], llm=467, ground_truth=467, match=True\n",
+ "[29.0], llm=20, ground_truth=20, match=True\n",
+ "[29.1], llm=20, ground_truth=20, match=True\n",
+ "[29.2], llm=25, ground_truth=25, match=True\n",
+ "[29.3], llm=20, ground_truth=20, match=True\n",
+ "[29.4], llm=20, ground_truth=20, match=True\n",
+ "[30.0], llm=17, ground_truth=17, match=True\n",
+ "[30.1], llm=26, ground_truth=26, match=True\n",
+ "[30.2], llm=93, ground_truth=93, match=True\n",
+ "[30.3], llm=81, ground_truth=81, match=True\n",
+ "[30.4], llm=26, ground_truth=26, match=True\n",
+ "[31.0], llm=24, ground_truth=24, match=True\n",
+ "[31.1], llm=26, ground_truth=26, match=True\n",
+ "[31.2], llm=32, ground_truth=32, match=True\n",
+ "[31.3], llm=30, ground_truth=30, match=True\n",
+ "[31.4], llm=22, ground_truth=22, match=True\n",
+ "[32.0], llm=52.5, ground_truth=63, match=False\n",
+ "[32.1], llm=27, ground_truth=27, match=True\n",
+ "[32.2], llm=60, ground_truth=100, match=False\n",
+ "[32.3], llm=42, ground_truth=84, match=False\n",
+ "[32.4], llm=30, ground_truth=30, match=True\n",
+ "[33.0], llm=1715, ground_truth=1715, match=True\n",
+ "[33.1], llm=1568, ground_truth=1568, match=True\n",
+ "[33.2], llm=1568, ground_truth=1568, match=True\n",
+ "[33.3], llm=1960, ground_truth=1960, match=True\n",
+ "[33.4], llm=1029, ground_truth=1029, match=True\n",
+ "[34.0], llm=1, ground_truth=1, match=True\n",
+ "[34.1], llm=78, ground_truth=78, match=True\n",
+ "[34.2], llm=4, ground_truth=4, match=True\n",
+ "[34.3], llm=25, ground_truth=25, match=True\n",
+ "[34.4], llm=151, ground_truth=151, match=True\n",
+ "[35.0], llm=60, ground_truth=60, match=True\n",
+ "[35.1], llm=51.76, ground_truth=51, match=False\n",
+ "[35.2], llm=37, ground_truth=37, match=True\n",
+ "[35.3], llm=23.33, ground_truth=23, match=False\n",
+ "[35.4], llm=43.11, ground_truth=43, match=False\n",
+ "[36.0], llm=75, ground_truth=75, match=True\n",
+ "[36.1], llm=90, ground_truth=90, match=True\n",
+ "[36.2], llm=27, ground_truth=27, match=True\n",
+ "[36.3], llm=63, ground_truth=63, match=True\n",
+ "[36.4], llm=34, ground_truth=34, match=True\n",
+ "[37.0], llm=38454, ground_truth=38454, match=True\n",
+ "[37.1], llm=30856, ground_truth=30856, match=True\n",
+ "[37.2], llm=10962, ground_truth=10710, match=False\n",
+ "[37.3], llm=15590.4, ground_truth=15232, match=False\n",
+ "[37.4], llm=16224, ground_truth=16224, match=True\n",
+ "[38.0], llm=159, ground_truth=159, match=True\n",
+ "[38.1], llm=284, ground_truth=284, match=True\n",
+ "[38.2], llm=325, ground_truth=325, match=True\n",
+ "[38.3], llm=126, ground_truth=126, match=True\n",
+ "[38.4], llm=285, ground_truth=285, match=True\n",
+ "[39.0], llm=54, ground_truth=54, match=True\n",
+ "[39.1], llm=25, ground_truth=25, match=True\n",
+ "[39.2], llm=23, ground_truth=23, match=True\n",
+ "[39.3], llm=52, ground_truth=52, match=True\n",
+ "[39.4], llm=53, ground_truth=53, match=True\n",
+ "[40.0], llm=96, ground_truth=96, match=True\n",
+ "[40.1], llm=184, ground_truth=184, match=True\n",
+ "[40.2], llm=134, ground_truth=134, match=True\n",
+ "[40.3], llm=190, ground_truth=190, match=True\n",
+ "[40.4], llm=320, ground_truth=320, match=True\n",
+ "[41.0], llm=230, ground_truth=230, match=True\n",
+ "[41.1], llm=165, ground_truth=165, match=True\n",
+ "[41.2], llm=445, ground_truth=445, match=True\n",
+ "[41.3], llm=195, ground_truth=195, match=True\n",
+ "[41.4], llm=260, ground_truth=260, match=True\n",
+ "[42.0], llm=171500, ground_truth=171500, match=True\n",
+ "[42.1], llm=429600, ground_truth=429600, match=True\n",
+ "[42.2], llm=100400, ground_truth=100400, match=True\n",
+ "[42.3], llm=636000, ground_truth=636000, match=True\n",
+ "[42.4], llm=490000, ground_truth=490000, match=True\n",
+ "[43.0], llm=16, ground_truth=16, match=True\n",
+ "[43.1], llm=20, ground_truth=20, match=True\n",
+ "[43.2], llm=20, ground_truth=20, match=True\n",
+ "[43.3], llm=27, ground_truth=27, match=True\n",
+ "[43.4], llm=11, ground_truth=11, match=True\n",
+ "[44.0], llm=417, ground_truth=417, match=True\n",
+ "[44.1], llm=420, ground_truth=420, match=True\n",
+ "[44.2], llm=674, ground_truth=674, match=True\n",
+ "[44.3], llm=374, ground_truth=374, match=True\n",
+ "[44.4], llm=500, ground_truth=500, match=True\n",
+ "[45.0], llm=15, ground_truth=15, match=True\n",
+ "[45.1], llm=29, ground_truth=29, match=True\n",
+ "[45.2], llm=23, ground_truth=23, match=True\n",
+ "[45.3], llm=23, ground_truth=23, match=True\n",
+ "[45.4], llm=11, ground_truth=11, match=True\n",
+ "[46.0], llm=26, ground_truth=26, match=True\n",
+ "[46.1], llm=16, ground_truth=16, match=True\n",
+ "[46.2], llm=23, ground_truth=23, match=True\n",
+ "[46.3], llm=18, ground_truth=18, match=True\n",
+ "[46.4], llm=18, ground_truth=18, match=True\n",
+ "[47.0], llm=385, ground_truth=385, match=True\n",
+ "[47.1], llm=156, ground_truth=156, match=True\n",
+ "[47.2], llm=415, ground_truth=415, match=True\n",
+ "[47.3], llm=149, ground_truth=149, match=True\n",
+ "[47.4], llm=306, ground_truth=306, match=True\n",
+ "[48.0], llm=20, ground_truth=20, match=True\n",
+ "[48.1], llm=43, ground_truth=43, match=True\n",
+ "[48.2], llm=6, ground_truth=6, match=True\n",
+ "[48.3], llm=17, ground_truth=17, match=True\n",
+ "[48.4], llm=43, ground_truth=43, match=True\n",
+ "[49.0], llm=620, ground_truth=620, match=True\n",
+ "[49.1], llm=366, ground_truth=366, match=True\n",
+ "[49.2], llm=670, ground_truth=670, match=True\n",
+ "[49.3], llm=1345, ground_truth=1345, match=True\n",
+ "[49.4], llm=616, ground_truth=616, match=True\n",
+ "[50.0], llm=983, ground_truth=983, match=True\n",
+ "[50.1], llm=1084, ground_truth=1084, match=True\n",
+ "[50.2], llm=862, ground_truth=862, match=True\n",
+ "[50.3], llm=988, ground_truth=988, match=True\n",
+ "[50.4], llm=591, ground_truth=591, match=True\n",
+ "[51.0], llm=3, ground_truth=2, match=False\n",
+ "[51.1], llm=7, ground_truth=7, match=True\n",
+ "[51.2], llm=5, ground_truth=5, match=True\n",
+ "[51.3], llm=7, ground_truth=7, match=True\n",
+ "[51.4], llm=8, ground_truth=7, match=False\n",
+ "[52.0], llm=288, ground_truth=288, match=True\n",
+ "[52.1], llm=272, ground_truth=272, match=True\n",
+ "[52.2], llm=238, ground_truth=238, match=True\n",
+ "[52.3], llm=224, ground_truth=224, match=True\n",
+ "[52.4], llm=130, ground_truth=130, match=True\n",
+ "[53.0], llm=65, ground_truth=65, match=True\n",
+ "[53.1], llm=25, ground_truth=25, match=True\n",
+ "[53.2], llm=50, ground_truth=50, match=True\n",
+ "[53.3], llm=50, ground_truth=50, match=True\n",
+ "[53.4], llm=25, ground_truth=25, match=True\n",
+ "[54.0], llm=32, ground_truth=32, match=True\n",
+ "[54.1], llm=80, ground_truth=80, match=True\n",
+ "[54.2], llm=20, ground_truth=20, match=True\n",
+ "[54.3], llm=13, ground_truth=13, match=True\n",
+ "[54.4], llm=53, ground_truth=53, match=True\n",
+ "[55.0], llm=300, ground_truth=300, match=True\n",
+ "[55.1], llm=159, ground_truth=159, match=True\n",
+ "[55.2], llm=144, ground_truth=144, match=True\n",
+ "[55.3], llm=132, ground_truth=132, match=True\n",
+ "[55.4], llm=42, ground_truth=42, match=True\n",
+ "[56.0], llm=5565, ground_truth=5565, match=True\n",
+ "[56.1], llm=1576, ground_truth=1576, match=True\n",
+ "[56.2], llm=1338, ground_truth=1338, match=True\n",
+ "[56.3], llm=5675, ground_truth=5675, match=True\n",
+ "[56.4], llm=3894, ground_truth=3894, match=True\n",
+ "[57.0], llm=90, ground_truth=90, match=True\n",
+ "[57.1], llm=86, ground_truth=86, match=True\n",
+ "[57.2], llm=68, ground_truth=68, match=True\n",
+ "[57.3], llm=71, ground_truth=71, match=True\n",
+ "[57.4], llm=72, ground_truth=72, match=True\n",
+ "[58.0], llm=128, ground_truth=128, match=True\n",
+ "[58.1], llm=150, ground_truth=150, match=True\n",
+ "[58.2], llm=672, ground_truth=672, match=True\n",
+ "[58.3], llm=360, ground_truth=360, match=True\n",
+ "[58.4], llm=350, ground_truth=350, match=True\n",
+ "[59.0], llm=846, ground_truth=846, match=True\n",
+ "[59.1], llm=298, ground_truth=298, match=True\n",
+ "[59.2], llm=368, ground_truth=368, match=True\n",
+ "[59.3], llm=2992, ground_truth=2992, match=True\n",
+ "[59.4], llm=864, ground_truth=864, match=True\n",
+ "[60.0], llm=92.5, ground_truth=92, match=False\n",
+ "[60.1], llm=74, ground_truth=74, match=True\n",
+ "[60.2], llm=57, ground_truth=57, match=True\n",
+ "[60.3], llm=90, ground_truth=87, match=False\n",
+ "[60.4], llm=102.5, ground_truth=102, match=False\n",
+ "[61.0], llm=384.20, ground_truth=385, match=False\n",
+ "[61.1], llm=566.02, ground_truth=567, match=False\n",
+ "[61.2], llm=366.92, ground_truth=354, match=False\n",
+ "[61.3], llm=431.35, ground_truth=506, match=False\n",
+ "[61.4], llm=476.17, ground_truth=564, match=False\n",
+ "[62.0], llm=8, ground_truth=8, match=True\n",
+ "[62.1], llm=3, ground_truth=3, match=True\n",
+ "[62.2], llm=7, ground_truth=7, match=True\n",
+ "[62.3], llm=8, ground_truth=8, match=True\n",
+ "[62.4], llm=5, ground_truth=5, match=True\n",
+ "[63.0], llm=4644, ground_truth=4644, match=True\n",
+ "[63.1], llm=6808, ground_truth=6808, match=True\n",
+ "[63.2], llm=3496, ground_truth=3496, match=True\n",
+ "[63.3], llm=5012, ground_truth=4616, match=False\n",
+ "[63.4], llm=4024, ground_truth=4024, match=True\n",
+ "[64.0], llm=56, ground_truth=56, match=True\n",
+ "[64.1], llm=64, ground_truth=64, match=True\n",
+ "[64.2], llm=64, ground_truth=64, match=True\n",
+ "[64.3], llm=49, ground_truth=49, match=True\n",
+ "[64.4], llm=57, ground_truth=57, match=True\n",
+ "[65.0], llm=454.98, ground_truth=363, match=False\n",
+ "[65.1], llm=520, ground_truth=420, match=False\n",
+ "[65.2], llm=insufficient data, ground_truth=398, match=False\n",
+ "[65.3], llm=missing data, ground_truth=141, match=False\n",
+ "[65.4], llm=431.65, ground_truth=380, match=False\n",
+ "[66.0], llm=2, ground_truth=2, match=True\n",
+ "[66.1], llm=7, ground_truth=7, match=True\n",
+ "[66.2], llm=1, ground_truth=1, match=True\n",
+ "[66.3], llm=4, ground_truth=4, match=True\n",
+ "[66.4], llm=7, ground_truth=7, match=True\n",
+ "[67.0], llm=814, ground_truth=814, match=True\n",
+ "[67.1], llm=1928, ground_truth=1928, match=True\n",
+ "[67.2], llm=512, ground_truth=512, match=True\n",
+ "[67.3], llm=1314, ground_truth=1314, match=True\n",
+ "[67.4], llm=1381, ground_truth=1381, match=True\n",
+ "[68.0], llm=3773, ground_truth=3773, match=True\n",
+ "[68.1], llm=1715, ground_truth=1715, match=True\n",
+ "[68.2], llm=4320, ground_truth=4320, match=True\n",
+ "[68.3], llm=1715, ground_truth=1715, match=True\n",
+ "[68.4], llm=513, ground_truth=513, match=True\n",
+ "[69.0], llm=147, ground_truth=147, match=True\n",
+ "[69.1], llm=74, ground_truth=74, match=True\n",
+ "[69.2], llm=159, ground_truth=159, match=True\n",
+ "[69.3], llm=68, ground_truth=68, match=True\n",
+ "[69.4], llm=10, ground_truth=10, match=True\n",
+ "[70.0], llm=27, ground_truth=27, match=True\n",
+ "[70.1], llm=52, ground_truth=52, match=True\n",
+ "[70.2], llm=23, ground_truth=23, match=True\n",
+ "[70.3], llm=14, ground_truth=14, match=True\n",
+ "[70.4], llm=85, ground_truth=85, match=True\n",
+ "[71.0], llm=1047, ground_truth=1047, match=True\n",
+ "[71.1], llm=776, ground_truth=776, match=True\n",
+ "[71.2], llm=1285, ground_truth=1285, match=True\n",
+ "[71.3], llm=1113, ground_truth=1113, match=True\n",
+ "[71.4], llm=1060, ground_truth=1060, match=True\n",
+ "[72.0], llm=4, ground_truth=4, match=True\n",
+ "[72.1], llm=4, ground_truth=4, match=True\n",
+ "[72.2], llm=19, ground_truth=19, match=True\n",
+ "[72.3], llm=2, ground_truth=2, match=True\n",
+ "[72.4], llm=8, ground_truth=8, match=True\n",
+ "[73.0], llm=1280, ground_truth=1280, match=True\n",
+ "[73.1], llm=1620, ground_truth=1620, match=True\n",
+ "[73.2], llm=1728, ground_truth=1728, match=True\n",
+ "[73.3], llm=1379, ground_truth=1379, match=True\n",
+ "[73.4], llm=1826, ground_truth=1826, match=True\n",
+ "[74.0], llm=100, ground_truth=100, match=True\n",
+ "[74.1], llm=100, ground_truth=100, match=True\n",
+ "[74.2], llm=10, ground_truth=10, match=True\n",
+ "[74.3], llm=4.76, ground_truth=4, match=False\n",
+ "[74.4], llm=1.19, ground_truth=1, match=False\n",
+ "[75.0], llm=14.5, ground_truth=14, match=False\n",
+ "[75.1], llm=60, ground_truth=60, match=True\n",
+ "[75.2], llm=25.5, ground_truth=25, match=False\n",
+ "[75.3], llm=44, ground_truth=44, match=True\n",
+ "[75.4], llm=10.5, ground_truth=10, match=False\n",
+ "[76.0], llm=27, ground_truth=26, match=False\n",
+ "[76.1], llm=22.5, ground_truth=22, match=False\n",
+ "[76.2], llm=42, ground_truth=42, match=True\n",
+ "[76.3], llm=6, ground_truth=6, match=True\n",
+ "[76.4], llm=9, ground_truth=9, match=True\n",
+ "[77.0], llm=6, ground_truth=6, match=True\n",
+ "[77.1], llm=21, ground_truth=21, match=True\n",
+ "[77.2], llm=40, ground_truth=40, match=True\n",
+ "[77.3], llm=5, ground_truth=21, match=False\n",
+ "[77.4], llm=5, ground_truth=15, match=False\n",
+ "[78.0], llm=53, ground_truth=53, match=True\n",
+ "[78.1], llm=55, ground_truth=55, match=True\n",
+ "[78.2], llm=38, ground_truth=38, match=True\n",
+ "[78.3], llm=66, ground_truth=66, match=True\n",
+ "[78.4], llm=76, ground_truth=76, match=True\n",
+ "[79.0], llm=78, ground_truth=78, match=True\n",
+ "[79.1], llm=329, ground_truth=235, match=False\n",
+ "[79.2], llm=231, ground_truth=231, match=True\n",
+ "[79.3], llm=81, ground_truth=81, match=True\n",
+ "[79.4], llm=231, ground_truth=231, match=True\n",
+ "[80.0], llm=50, ground_truth=50, match=True\n",
+ "[80.1], llm=25, ground_truth=25, match=True\n",
+ "[80.2], llm=50, ground_truth=50, match=True\n",
+ "[80.3], llm=50, ground_truth=50, match=True\n",
+ "[80.4], llm=50, ground_truth=50, match=True\n",
+ "[81.0], llm=29160, ground_truth=29160, match=True\n",
+ "[81.1], llm=10080, ground_truth=10080, match=True\n",
+ "[81.2], llm=28080, ground_truth=28080, match=True\n",
+ "[81.3], llm=27000, ground_truth=27000, match=True\n",
+ "[81.4], llm=8160, ground_truth=8160, match=True\n",
+ "[82.0], llm=480, ground_truth=480, match=True\n",
+ "[82.1], llm=475, ground_truth=475, match=True\n",
+ "[82.2], llm=320, ground_truth=320, match=True\n",
+ "[82.3], llm=840, ground_truth=840, match=True\n",
+ "[82.4], llm=540, ground_truth=540, match=True\n",
+ "[83.0], llm=95, ground_truth=95, match=True\n",
+ "[83.1], llm=92, ground_truth=92, match=True\n",
+ "[83.2], llm=48, ground_truth=48, match=True\n",
+ "[83.3], llm=53, ground_truth=53, match=True\n",
+ "[83.4], llm=91, ground_truth=91, match=True\n",
+ "[84.0], llm=84, ground_truth=84, match=True\n",
+ "[84.1], llm=161, ground_truth=161, match=True\n",
+ "[84.2], llm=114, ground_truth=114, match=True\n",
+ "[84.3], llm=145, ground_truth=145, match=True\n",
+ "[84.4], llm=192, ground_truth=192, match=True\n",
+ "[85.0], llm=166, ground_truth=166, match=True\n",
+ "[85.1], llm=90, ground_truth=90, match=True\n",
+ "[85.2], llm=150, ground_truth=150, match=True\n",
+ "[85.3], llm=152, ground_truth=152, match=True\n",
+ "[85.4], llm=178, ground_truth=178, match=True\n",
+ "[86.0], llm=5, ground_truth=4, match=False\n",
+ "[86.1], llm=3, ground_truth=2, match=False\n",
+ "[86.2], llm=4, ground_truth=3, match=False\n",
+ "[86.3], llm=5, ground_truth=4, match=False\n",
+ "[86.4], llm=3, ground_truth=3, match=True\n",
+ "[87.0], llm=3, ground_truth=3, match=True\n",
+ "[87.1], llm=2, ground_truth=2, match=True\n",
+ "[87.2], llm=7.5, ground_truth=7, match=False\n",
+ "[87.3], llm=10, ground_truth=10, match=True\n",
+ "[87.4], llm=2, ground_truth=2, match=True\n",
+ "[88.0], llm=3, ground_truth=3, match=True\n",
+ "[88.1], llm=7, ground_truth=7, match=True\n",
+ "[88.2], llm=2, ground_truth=2, match=True\n",
+ "[88.3], llm=5, ground_truth=5, match=True\n",
+ "[88.4], llm=5, ground_truth=5, match=True\n",
+ "[89.0], llm=54, ground_truth=54, match=True\n",
+ "[89.1], llm=63, ground_truth=63, match=True\n",
+ "[89.2], llm=66, ground_truth=66, match=True\n",
+ "[89.3], llm=27, ground_truth=27, match=True\n",
+ "[89.4], llm=42, ground_truth=42, match=True\n",
+ "[90.0], llm=11.76, ground_truth=11, match=False\n",
+ "[90.1], llm=10.95, ground_truth=10, match=False\n",
+ "[90.2], llm=15.28, ground_truth=15, match=False\n",
+ "[90.3], llm=7.81, ground_truth=7, match=False\n",
+ "[90.4], llm=11.20, ground_truth=11, match=False\n",
+ "[91.0], llm=14400, ground_truth=14400, match=True\n",
+ "[91.1], llm=5040, ground_truth=5040, match=True\n",
+ "[91.2], llm=3520, ground_truth=3520, match=True\n",
+ "[91.3], llm=6300, ground_truth=6300, match=True\n",
+ "[91.4], llm=33630, ground_truth=33630, match=True\n",
+ "[92.0], llm=406, ground_truth=406, match=True\n",
+ "[92.1], llm=308, ground_truth=308, match=True\n",
+ "[92.2], llm=325, ground_truth=325, match=True\n",
+ "[92.3], llm=278, ground_truth=278, match=True\n",
+ "[92.4], llm=315, ground_truth=315, match=True\n",
+ "[93.0], llm=225, ground_truth=225, match=True\n",
+ "[93.1], llm=25, ground_truth=25, match=True\n",
+ "[93.2], llm=150, ground_truth=150, match=True\n",
+ "[93.3], llm=50, ground_truth=50, match=True\n",
+ "[93.4], llm=150, ground_truth=150, match=True\n",
+ "[94.0], llm=1406, ground_truth=1406, match=True\n",
+ "[94.1], llm=504, ground_truth=504, match=True\n",
+ "[94.2], llm=1320, ground_truth=1320, match=True\n",
+ "[94.3], llm=1656, ground_truth=1656, match=True\n",
+ "[94.4], llm=108, ground_truth=108, match=True\n",
+ "[95.0], llm=360, ground_truth=360, match=True\n",
+ "[95.1], llm=510, ground_truth=510, match=True\n",
+ "[95.2], llm=112, ground_truth=112, match=True\n",
+ "[95.3], llm=91, ground_truth=91, match=True\n",
+ "[95.4], llm=450, ground_truth=450, match=True\n",
+ "[96.0], llm=808, ground_truth=808, match=True\n",
+ "[96.1], llm=352, ground_truth=352, match=True\n",
+ "[96.2], llm=1062, ground_truth=1062, match=True\n",
+ "[96.3], llm=1203, ground_truth=1203, match=True\n",
+ "[96.4], llm=347, ground_truth=347, match=True\n",
+ "[97.0], llm=11.136, ground_truth=11, match=False\n",
+ "[97.1], llm=22.272, ground_truth=22, match=False\n",
+ "[97.2], llm=16.704, ground_truth=16, match=False\n",
+ "[97.3], llm=26.57, ground_truth=26, match=False\n",
+ "[97.4], llm=71.64, ground_truth=72, match=False\n",
+ "[98.0], llm=82, ground_truth=82, match=True\n",
+ "[98.1], llm=70, ground_truth=70, match=True\n",
+ "[98.2], llm=83.25, ground_truth=83, match=False\n",
+ "[98.3], llm=88.25, ground_truth=88, match=False\n",
+ "[98.4], llm=70.5, ground_truth=70, match=False\n",
+ "[99.0], llm=30.00, ground_truth=30, match=False\n",
+ "[99.1], llm=51.00, ground_truth=51, match=False\n",
+ "[99.2], llm=59.00, ground_truth=59, match=False\n",
+ "[99.3], llm=2.00, ground_truth=2, match=False\n",
+ "[99.4], llm=44.00, ground_truth=44, match=False\n"
+ ]
+ }
+ ],
+ "source": [
+ "rng = Random(55)\n",
+ "result_1 = cross_check_generators(rng, difficulty=1.0, num_generations=5)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "good = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,33,34,36,38,39,40,41,42,43,44,45,46,47,48,49,50,52,53,54,55,56,57,58,59,62,64,66,67,68,69,70,71,72,73,78,80,81,82,83,84,85,88,89,91,92,93,94,95,96]\n",
+ "not_good = [32,35,37,51,60,61,63,65,74,75,76,77,79,86,87,90,97,98,99]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Partition generator indices by cross-check score: a generator is 'good' when\n",
+ "# all 5 generations matched the ground truth, 'not_good' when at least one missed.\n",
+ "indices = range(len(result_1))\n",
+ "good = [str(i) for i in indices if result_1[i] == 5]\n",
+ "not_good = [str(i) for i in indices if result_1[i] < 5]\n",
+ "\n",
+ "# Emit as copy-pasteable Python list literals.\n",
+ "print('good = [' + \",\".join(good) + ']')\n",
+ "print('not_good = [' + \",\".join(not_good) + ']')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}