reasoning-gym/notebooks/codeio.ipynb
2025-02-24 15:58:06 +01:00

533 lines
38 KiB
Text

{
"cells": [
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"import abc\n",
"import os\n",
"from typing import Union\n",
"import re\n",
"import random\n",
"from random import Random\n",
"import requests\n",
"import json\n",
"from tqdm import tqdm\n",
"import datasets\n",
"import numpy as np\n",
"import torch\n",
"from sentence_transformers import SentenceTransformer\n",
"import asyncio"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"dataset = datasets.load_dataset(\n",
"    \"hkust-nlp/CodeIO-PyEdu-Reasoning\"\n",
")[\"train\"]  # only the train split is used below"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Extract the relevant parts of each prompt and drop duplicates"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1630607/1630607 [01:20<00:00, 20302.13it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"There were 1489543 out of 1630607 duplicate entries\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Split each prompt into (task description, input/output spec, reference code).\n",
"pattern = re.compile(\n",
"    r'(?s)'  # DOTALL so . matches newlines\n",
"    r'You are given a question that requires some input and output variables as follows:\\s*(.*?)'\n",
"    r'\\s*The input and output requirements are as follows:\\s*(.*?)'\n",
"    r'\\s*Given the following.*?Tip: Here is a reference code snippet for this question\\. '\n",
"    r'You can refer to this code to guide your reasoning but not copy spans of code directly\\.\\s*(.*)'\n",
")\n",
"\n",
"EXTRACTED_PATH = \"data/codeio-pyedu-extracted.jsonl\"\n",
"# open() below raises FileNotFoundError if the data/ directory does not exist yet.\n",
"os.makedirs(os.path.dirname(EXTRACTED_PATH), exist_ok=True)\n",
"\n",
"seen = set()\n",
"duplicate = 0\n",
"unmatched = 0\n",
"\n",
"with open(EXTRACTED_PATH, \"w\", encoding=\"utf-8\") as f:\n",
"    for i, item in tqdm(enumerate(dataset), total=len(dataset)):\n",
"        match = pattern.search(item[\"prompt\"])\n",
"        if match is None:\n",
"            unmatched += 1\n",
"            print(f\"No match found for item {i}\")\n",
"            continue\n",
"\n",
"        task_description = match.group(1).strip()\n",
"        input_output_spec = match.group(2).strip()\n",
"        code_sample = match.group(3).strip()\n",
"\n",
"        # Deduplicate on all three fields at once. hash() of str is salted per\n",
"        # process, so these keys are only comparable within a single run.\n",
"        hash_entry = (hash(task_description), hash(input_output_spec), hash(code_sample))\n",
"        if hash_entry in seen:\n",
"            duplicate += 1\n",
"            continue\n",
"        seen.add(hash_entry)\n",
"\n",
"        json.dump({\n",
"            \"task_description\": task_description,\n",
"            \"input_output_spec\": input_output_spec,\n",
"            \"code_sample\": code_sample,\n",
"        }, f)\n",
"        f.write(\"\\n\")\n",
"\n",
"print(f\"There were {duplicate} out of {len(dataset)} duplicate entries\")\n",
"print(f\"{unmatched} out of {len(dataset)} entries did not match the prompt pattern\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Subsample the data for maximum coverage (greedy coreset over embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class IdentitySampler:\n",
" def run(\n",
" self, features: Union[torch.Tensor, np.ndarray]\n",
" ) -> Union[torch.Tensor, np.ndarray]:\n",
" return features\n",
"\n",
"\n",
"class BaseSampler(abc.ABC):\n",
" def __init__(self, percentage: float):\n",
" if not 0 < percentage < 1:\n",
" raise ValueError(\"Percentage value not in (0, 1).\")\n",
" self.percentage = percentage\n",
"\n",
" @abc.abstractmethod\n",
" def run(\n",
" self, features: Union[torch.Tensor, np.ndarray]\n",
" ) -> Union[torch.Tensor, np.ndarray]:\n",
" pass\n",
"\n",
" def _store_type(self, features: Union[torch.Tensor, np.ndarray]) -> None:\n",
" self.features_is_numpy = isinstance(features, np.ndarray)\n",
" if not self.features_is_numpy:\n",
" self.features_device = features.device\n",
"\n",
" def _restore_type(self, features: torch.Tensor) -> Union[torch.Tensor, np.ndarray]:\n",
" if self.features_is_numpy:\n",
" return features.cpu().numpy()\n",
" return features.to(self.features_device)\n",
"\n",
"\n",
"class GreedyCoresetSampler(BaseSampler):\n",
"    def __init__(\n",
"        self,\n",
"        percentage: float,\n",
"        device: torch.device,\n",
"        dtype: torch.dtype = torch.float32,\n",
"        dimension_to_project_features_to=128,\n",
"    ):\n",
"        \"\"\"Greedy coreset sampler.\n",
"\n",
"        Args:\n",
"            percentage: fraction of rows to select; must lie in (0, 1) (checked by BaseSampler).\n",
"            device: device the projection and distance computations run on.\n",
"            dtype: dtype the features are cast to before projection and distances.\n",
"            dimension_to_project_features_to: output width of the random linear\n",
"                projection; no projection is applied when the input already has this width.\n",
"        \"\"\"\n",
"        super().__init__(percentage)\n",
"\n",
"        self.device = device\n",
"        self.dtype = dtype\n",
"        self.dimension_to_project_features_to = dimension_to_project_features_to\n",
"\n",
"    def _reduce_features(self, features: torch.Tensor) -> torch.Tensor:\n",
"        \"\"\"Move `features` to `self.device`/`self.dtype`, then randomly project them.\n",
"\n",
"        The projection is a freshly initialised bias-free `torch.nn.Linear`, so its\n",
"        output depends on the torch RNG state; seed torch for reproducible results.\n",
"        \"\"\"\n",
"        # Cast before projecting: a Linear layer created with dtype=self.dtype raises on\n",
"        # inputs of any other dtype. The no-projection path gets the same device/dtype.\n",
"        features = features.to(device=self.device, dtype=self.dtype)\n",
"        if features.shape[1] == self.dimension_to_project_features_to:\n",
"            return features\n",
"        mapper = torch.nn.Linear(\n",
"            features.shape[1], self.dimension_to_project_features_to, bias=False, dtype=self.dtype,\n",
"        )\n",
"        _ = mapper.to(self.device)\n",
"        return mapper(features)\n",
"\n",
"    def run(\n",
"        self, features: Union[torch.Tensor, np.ndarray]\n",
"    ) -> Union[torch.Tensor, np.ndarray]:\n",
"        \"\"\"Select a greedy coreset of the rows of `features`.\n",
"\n",
"        Args:\n",
"            features: [N x D]\n",
"\n",
"        Returns:\n",
"            int64 tensor holding the int(N * percentage) selected row indices, in\n",
"            selection order, on `self.device` -- indices, not the sampled features.\n",
"        \"\"\"\n",
"        self._store_type(features)\n",
"        if isinstance(features, np.ndarray):\n",
"            features = torch.from_numpy(features)\n",
"        # Nothing is trained here. Without no_grad, a projection's weights\n",
"        # (requires_grad=True) would make autograd keep every intermediate distance tensor.\n",
"        with torch.no_grad():\n",
"            reduced_features = self._reduce_features(features)\n",
"            sample_indices = self._compute_greedy_coreset_indices(reduced_features)\n",
"        return sample_indices\n",
"\n",
"    @staticmethod\n",
"    def _compute_batchwise_differences(\n",
"        matrix_a: torch.Tensor, matrix_b: torch.Tensor\n",
"    ) -> torch.Tensor:\n",
"        \"\"\"Pairwise Euclidean distances between rows: [Na x D], [Nb x D] -> [Na x Nb].\n",
"\n",
"        Expands ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2; small negatives caused by\n",
"        rounding are clamped to 0 before the square root.\n",
"        \"\"\"\n",
"        a_times_a = matrix_a.unsqueeze(1).bmm(matrix_a.unsqueeze(2)).reshape(-1, 1)\n",
"        b_times_b = matrix_b.unsqueeze(1).bmm(matrix_b.unsqueeze(2)).reshape(1, -1)\n",
"        a_times_b = matrix_a.mm(matrix_b.T)\n",
"\n",
"        return (-2 * a_times_b + a_times_a + b_times_b).clamp(0, None).sqrt()\n",
"\n",
"    def _compute_greedy_coreset_indices(self, features: torch.Tensor) -> torch.Tensor:\n",
"        \"\"\"Runs iterative greedy coreset selection.\n",
"\n",
"        Materialises the full N x N distance matrix, so memory grows quadratically in N.\n",
"\n",
"        Args:\n",
"            features: [NxD] input feature bank to sample.\n",
"\n",
"        Returns:\n",
"            int64 tensor of selected indices on `features.device`.\n",
"        \"\"\"\n",
"        distance_matrix = self._compute_batchwise_differences(features, features)\n",
"        # Initial score of each row is the norm of its distance row; after every pick\n",
"        # it becomes the distance to the closest row selected so far.\n",
"        coreset_anchor_distances = torch.norm(distance_matrix, dim=1)\n",
"\n",
"        coreset_indices = []\n",
"        num_coreset_samples = int(len(features) * self.percentage)\n",
"\n",
"        for _ in range(num_coreset_samples):\n",
"            # Pick the row that is farthest from the current selection.\n",
"            select_idx = torch.argmax(coreset_anchor_distances).item()\n",
"            coreset_indices.append(select_idx)\n",
"\n",
"            coreset_select_distance = distance_matrix[\n",
"                :, select_idx : select_idx + 1  # noqa E203\n",
"            ]\n",
"            coreset_anchor_distances = torch.cat(\n",
"                [coreset_anchor_distances.unsqueeze(-1), coreset_select_distance], dim=1\n",
"            )\n",
"            coreset_anchor_distances = torch.min(coreset_anchor_distances, dim=1).values\n",
"\n",
"        return torch.tensor(coreset_indices, device=features.device, dtype=torch.int64)\n",
"\n",
"\n",
"class ApproximateGreedyCoresetSampler(GreedyCoresetSampler):\n",
"    def __init__(\n",
"        self,\n",
"        percentage: float,\n",
"        device: torch.device,\n",
"        dtype: torch.dtype = torch.float32,\n",
"        number_of_starting_points: int = 10,\n",
"        dimension_to_project_features_to: int = 128,\n",
"    ):\n",
"        \"\"\"Approximate greedy coreset sampler.\n",
"\n",
"        Same arguments as GreedyCoresetSampler, plus:\n",
"            number_of_starting_points: number of randomly chosen rows whose mean\n",
"                distance seeds the initial scores (capped at the number of rows).\n",
"        \"\"\"\n",
"        self.number_of_starting_points = number_of_starting_points\n",
"        super().__init__(percentage, device, dtype, dimension_to_project_features_to)\n",
"\n",
"    def _compute_greedy_coreset_indices(self, features: torch.Tensor) -> torch.Tensor:\n",
"        \"\"\"Runs approximate iterative greedy coreset selection.\n",
"\n",
"        This greedy coreset implementation does not require computation of the\n",
"        full N x N distance matrix and thus requires a lot less memory, however\n",
"        at the cost of increased sampling times.\n",
"\n",
"        Args:\n",
"            features: [NxD] input feature bank to sample.\n",
"\n",
"        Returns:\n",
"            int64 tensor of selected indices on `features.device`.\n",
"        \"\"\"\n",
"        number_of_starting_points = min(self.number_of_starting_points, len(features))\n",
"        # Drawn from NumPy's global RNG: seed np.random for reproducible selections.\n",
"        start_points = np.random.choice(\n",
"            len(features), number_of_starting_points, replace=False\n",
"        ).tolist()\n",
"\n",
"        with torch.no_grad():\n",
"            approximate_distance_matrix = self._compute_batchwise_differences(\n",
"                features, features[start_points]\n",
"            )\n",
"            approximate_coreset_anchor_distances = torch.mean(\n",
"                approximate_distance_matrix, dim=-1\n",
"            ).reshape(-1, 1)\n",
"            coreset_indices = []\n",
"            num_coreset_samples = int(len(features) * self.percentage)\n",
"\n",
"            # `tqdm` is imported via `from tqdm import tqdm`, i.e. it already is the\n",
"            # progress-bar class; `tqdm.tqdm(...)` would raise AttributeError.\n",
"            for _ in tqdm(range(num_coreset_samples), desc=\"Subsampling...\"):\n",
"                select_idx = torch.argmax(approximate_coreset_anchor_distances).item()\n",
"                coreset_indices.append(select_idx)\n",
"                # Only the distances to the newly selected row are computed (N x 1).\n",
"                coreset_select_distance = self._compute_batchwise_differences(\n",
"                    features, features[select_idx : select_idx + 1]  # noqa: E203\n",
"                )\n",
"                approximate_coreset_anchor_distances = torch.cat(\n",
"                    [approximate_coreset_anchor_distances, coreset_select_distance],\n",
"                    dim=-1,\n",
"                )\n",
"                approximate_coreset_anchor_distances = torch.min(\n",
"                    approximate_coreset_anchor_distances, dim=1\n",
"                ).values.reshape(-1, 1)\n",
"\n",
"        return torch.tensor(coreset_indices, device=features.device, dtype=torch.int64)\n",
"\n",
"\n",
"class RandomSampler(BaseSampler):\n",
"    def __init__(self, percentage: float):\n",
"        super().__init__(percentage)\n",
"\n",
"    def run(\n",
"        self, features: Union[torch.Tensor, np.ndarray]\n",
"    ) -> Union[torch.Tensor, np.ndarray]:\n",
"        \"\"\"Randomly samples input feature collection.\n",
"\n",
"        Args:\n",
"            features: [N x D] tensor or array.\n",
"\n",
"        Returns:\n",
"            int64 tensor of int(N * percentage) distinct row indices drawn from\n",
"            NumPy's global RNG; on `features.device` for tensors, on CPU otherwise.\n",
"        \"\"\"\n",
"        num_random_samples = int(len(features) * self.percentage)\n",
"        subset_indices = np.random.choice(\n",
"            len(features), num_random_samples, replace=False\n",
"        )\n",
"        # NumPy < 2 arrays have no `.device`, so only read it from torch tensors.\n",
"        device = features.device if isinstance(features, torch.Tensor) else torch.device(\"cpu\")\n",
"        return torch.tensor(subset_indices, device=device, dtype=torch.int64)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# I ran this cell on Google Colab because I don't have a GPU on my local machine,\n",
"# hence why you see the Google Drive paths. The next section reads\n",
"# data/codeio-pyedu-best-coverage.jsonl, so copy the output file there afterwards.\n",
"DRIVE_EXTRACTED_PATH = \"./drive/MyDrive/reasoning-gym/codeio-pyedu-extracted.jsonl\"\n",
"DRIVE_BEST_COVERAGE_PATH = \"./drive/MyDrive/reasoning-gym/codeio-pyedu-best-coverage.jsonl\"\n",
"SEED = 42\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"model = SentenceTransformer(\"nomic-ai/modernbert-embed-base\")\n",
"print(model)\n",
"\n",
"def get_entry_info(entry) -> str:\n",
"    \"\"\"Text that gets embedded for an entry: its task description only.\"\"\"\n",
"    return entry['task_description']\n",
"\n",
"def get_embeddings(text) -> torch.Tensor:\n",
"    \"\"\"Embed one string ([D]) or a list of strings ([N x D]) as a bfloat16 tensor.\"\"\"\n",
"    embedded = model.encode(text, show_progress_bar=not isinstance(text, str))\n",
"    return torch.from_numpy(embedded).to(torch.bfloat16)\n",
"\n",
"with open(DRIVE_EXTRACTED_PATH) as f:\n",
"    entry_infos = [get_entry_info(json.loads(line)) for line in tqdm(f)]\n",
"\n",
"# One batched encode call instead of one model call per entry.\n",
"embeddings = get_embeddings(entry_infos).to(device)\n",
"print(embeddings.shape)\n",
"\n",
"# The coreset start points come from NumPy's RNG and any feature projection from\n",
"# torch's, so seed both to make the selected subset reproducible.\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"\n",
"sampler = ApproximateGreedyCoresetSampler(\n",
"    percentage=0.05,\n",
"    device=device,\n",
"    dtype=torch.bfloat16,\n",
"    dimension_to_project_features_to=768,\n",
")\n",
"subsampled = sampler.run(embeddings)\n",
"\n",
"indices = set(subsampled.cpu().tolist())\n",
"with open(DRIVE_EXTRACTED_PATH, \"r\") as f_in, open(DRIVE_BEST_COVERAGE_PATH, \"w\") as f_out:\n",
"    for i, line in enumerate(f_in):\n",
"        if i in indices:\n",
"            f_out.write(line)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generate an input generator for each problem with an LLM"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 2%|▏ | 137/7053 [58:17<31:00:25, 16.14s/it] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Error: 'NoneType' object has no attribute 'group'\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 2%|▏ | 152/7053 [1:05:09<49:18:10, 25.72s/it]\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[20], line 40\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, line \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28menumerate\u001b[39m(f_in), total\u001b[38;5;241m=\u001b[39mtotal_entries):\n\u001b[1;32m 39\u001b[0m entry \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mloads(line)\n\u001b[0;32m---> 40\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[43mrequests\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpost\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 41\u001b[0m \u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mhttps://openrouter.ai/api/v1/chat/completions\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 42\u001b[0m \u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m{\u001b[49m\n\u001b[1;32m 43\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mAuthorization\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mBearer \u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgetenv\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mOPENROUTER_API_KEY\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 44\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mContent-Type\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mapplication/json\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 45\u001b[0m \u001b[43m 
\u001b[49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 46\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mjson\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdumps\u001b[49m\u001b[43m(\u001b[49m\u001b[43m{\u001b[49m\n\u001b[1;32m 47\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmodel\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mdeepseek/deepseek-chat\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 48\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmessages\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43m[\u001b[49m\n\u001b[1;32m 49\u001b[0m \u001b[43m \u001b[49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mrole\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43msystem\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcontent\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mSYSTEM_PROMPT\u001b[49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 50\u001b[0m \u001b[43m \u001b[49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mrole\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43muser\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcontent\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mUSER_PROMPT\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mformat\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mentry\u001b[49m\u001b[43m)\u001b[49m\u001b[43m}\u001b[49m\n\u001b[1;32m 51\u001b[0m \u001b[43m \u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 52\u001b[0m \u001b[43m \u001b[49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 53\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 56\u001b[0m full_response \u001b[38;5;241m=\u001b[39m response\u001b[38;5;241m.\u001b[39mjson()[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mchoices\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;241m0\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmessage\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcontent\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n",
"File \u001b[0;32m~/miniconda3/envs/reasoning_gym/lib/python3.11/site-packages/requests/api.py:115\u001b[0m, in \u001b[0;36mpost\u001b[0;34m(url, data, json, **kwargs)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mpost\u001b[39m(url, data\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, json\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 104\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Sends a POST request.\u001b[39;00m\n\u001b[1;32m 105\u001b[0m \n\u001b[1;32m 106\u001b[0m \u001b[38;5;124;03m :param url: URL for the new :class:`Request` object.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;124;03m :rtype: requests.Response\u001b[39;00m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mrequest\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpost\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjson\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mjson\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniconda3/envs/reasoning_gym/lib/python3.11/site-packages/requests/api.py:59\u001b[0m, in \u001b[0;36mrequest\u001b[0;34m(method, url, **kwargs)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[38;5;66;03m# By using the 'with' statement we are sure the session is closed, thus we\u001b[39;00m\n\u001b[1;32m 56\u001b[0m \u001b[38;5;66;03m# avoid leaving sockets open which can trigger a ResourceWarning in some\u001b[39;00m\n\u001b[1;32m 57\u001b[0m \u001b[38;5;66;03m# cases, and look like a memory leak in others.\u001b[39;00m\n\u001b[1;32m 58\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m sessions\u001b[38;5;241m.\u001b[39mSession() \u001b[38;5;28;01mas\u001b[39;00m session:\n\u001b[0;32m---> 59\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43msession\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrequest\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmethod\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniconda3/envs/reasoning_gym/lib/python3.11/site-packages/requests/sessions.py:589\u001b[0m, in \u001b[0;36mSession.request\u001b[0;34m(self, method, url, params, data, headers, cookies, files, auth, timeout, allow_redirects, proxies, hooks, stream, verify, cert, json)\u001b[0m\n\u001b[1;32m 584\u001b[0m send_kwargs \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 585\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtimeout\u001b[39m\u001b[38;5;124m\"\u001b[39m: timeout,\n\u001b[1;32m 586\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mallow_redirects\u001b[39m\u001b[38;5;124m\"\u001b[39m: allow_redirects,\n\u001b[1;32m 587\u001b[0m }\n\u001b[1;32m 588\u001b[0m send_kwargs\u001b[38;5;241m.\u001b[39mupdate(settings)\n\u001b[0;32m--> 589\u001b[0m resp \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msend\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprep\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43msend_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 591\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m resp\n",
"File \u001b[0;32m~/miniconda3/envs/reasoning_gym/lib/python3.11/site-packages/requests/sessions.py:746\u001b[0m, in \u001b[0;36mSession.send\u001b[0;34m(self, request, **kwargs)\u001b[0m\n\u001b[1;32m 743\u001b[0m \u001b[38;5;28;01mpass\u001b[39;00m\n\u001b[1;32m 745\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m stream:\n\u001b[0;32m--> 746\u001b[0m \u001b[43mr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcontent\u001b[49m\n\u001b[1;32m 748\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m r\n",
"File \u001b[0;32m~/miniconda3/envs/reasoning_gym/lib/python3.11/site-packages/requests/models.py:902\u001b[0m, in \u001b[0;36mResponse.content\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 900\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_content \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 901\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 902\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_content \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39miter_content(CONTENT_CHUNK_SIZE)) \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 904\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_content_consumed \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 905\u001b[0m \u001b[38;5;66;03m# don't need to release the connection; that's been handled by urllib3\u001b[39;00m\n\u001b[1;32m 906\u001b[0m \u001b[38;5;66;03m# since we exhausted the data.\u001b[39;00m\n",
"File \u001b[0;32m~/miniconda3/envs/reasoning_gym/lib/python3.11/site-packages/requests/models.py:820\u001b[0m, in \u001b[0;36mResponse.iter_content.<locals>.generate\u001b[0;34m()\u001b[0m\n\u001b[1;32m 818\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mraw, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstream\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 819\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 820\u001b[0m \u001b[38;5;28;01myield from\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mraw\u001b[38;5;241m.\u001b[39mstream(chunk_size, decode_content\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 821\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ProtocolError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 822\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m ChunkedEncodingError(e)\n",
"File \u001b[0;32m~/miniconda3/envs/reasoning_gym/lib/python3.11/site-packages/urllib3/response.py:1063\u001b[0m, in \u001b[0;36mHTTPResponse.stream\u001b[0;34m(self, amt, decode_content)\u001b[0m\n\u001b[1;32m 1047\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1048\u001b[0m \u001b[38;5;124;03mA generator wrapper for the read() method. A call will block until\u001b[39;00m\n\u001b[1;32m 1049\u001b[0m \u001b[38;5;124;03m``amt`` bytes have been read from the connection or until the\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1060\u001b[0m \u001b[38;5;124;03m 'content-encoding' header.\u001b[39;00m\n\u001b[1;32m 1061\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1062\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchunked \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msupports_chunked_reads():\n\u001b[0;32m-> 1063\u001b[0m \u001b[38;5;28;01myield from\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mread_chunked(amt, decode_content\u001b[38;5;241m=\u001b[39mdecode_content)\n\u001b[1;32m 1064\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1065\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_fp_closed(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_fp) \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_decoded_buffer) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
"File \u001b[0;32m~/miniconda3/envs/reasoning_gym/lib/python3.11/site-packages/urllib3/response.py:1219\u001b[0m, in \u001b[0;36mHTTPResponse.read_chunked\u001b[0;34m(self, amt, decode_content)\u001b[0m\n\u001b[1;32m 1216\u001b[0m amt \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1218\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m-> 1219\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_update_chunk_length\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1220\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchunk_left \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 1221\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n",
"File \u001b[0;32m~/miniconda3/envs/reasoning_gym/lib/python3.11/site-packages/urllib3/response.py:1138\u001b[0m, in \u001b[0;36mHTTPResponse._update_chunk_length\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1136\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchunk_left \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1137\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m-> 1138\u001b[0m line \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_fp\u001b[38;5;241m.\u001b[39mfp\u001b[38;5;241m.\u001b[39mreadline() \u001b[38;5;66;03m# type: ignore[union-attr]\u001b[39;00m\n\u001b[1;32m 1139\u001b[0m line \u001b[38;5;241m=\u001b[39m line\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m;\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;241m1\u001b[39m)[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1140\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n",
"File \u001b[0;32m~/miniconda3/envs/reasoning_gym/lib/python3.11/socket.py:718\u001b[0m, in \u001b[0;36mSocketIO.readinto\u001b[0;34m(self, b)\u001b[0m\n\u001b[1;32m 716\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[1;32m 717\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 718\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sock\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrecv_into\u001b[49m\u001b[43m(\u001b[49m\u001b[43mb\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 719\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m timeout:\n\u001b[1;32m 720\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_timeout_occurred \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n",
"File \u001b[0;32m~/miniconda3/envs/reasoning_gym/lib/python3.11/ssl.py:1314\u001b[0m, in \u001b[0;36mSSLSocket.recv_into\u001b[0;34m(self, buffer, nbytes, flags)\u001b[0m\n\u001b[1;32m 1310\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m flags \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 1311\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 1312\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnon-zero flags not allowed in calls to recv_into() on \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m\n\u001b[1;32m 1313\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m)\n\u001b[0;32m-> 1314\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnbytes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbuffer\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1315\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1316\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39mrecv_into(buffer, nbytes, flags)\n",
"File \u001b[0;32m~/miniconda3/envs/reasoning_gym/lib/python3.11/ssl.py:1166\u001b[0m, in \u001b[0;36mSSLSocket.read\u001b[0;34m(self, len, buffer)\u001b[0m\n\u001b[1;32m 1164\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1165\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m buffer \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sslobj\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbuffer\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1167\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1168\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sslobj\u001b[38;5;241m.\u001b[39mread(\u001b[38;5;28mlen\u001b[39m)\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"SYSTEM_PROMPT = \"\"\"You are a helpful assistant that generates valid Python functions that act as input generators for a given code snippet.\n",
"\n",
"You have access to `random.Random`, therefore you SHOULD NOT import it again. You should use this random number generator to make the input generation process stochastic on each call.\n",
"\n",
"When the user asks you to generate an input for a code snippet, you should strictly respond in the following format:\n",
"<function>\n",
"def generate_input(rng: Random) -> dict:\n",
" # Your code here\n",
" pass\n",
"</function>\n",
"\n",
"The output of the function should be a dictionary where the keys are the variable names and the values are the generated values.\n",
"\n",
"It must contain all the variables that listed in the user's input specification, or more precisely in the `main_solution` function signature. \n",
"\"\"\"\n",
"\n",
"USER_PROMPT = \"\"\"Following are a task description, input/output specification, and relevant code snippet for a Python programming task.\n",
"\n",
"<task_description>\n",
"{task_description}\n",
"</task_description>\n",
"\n",
"<input_output_spec>\n",
"{input_output_spec}\n",
"</input_output_spec>\n",
"\n",
"<code_sample>\n",
"{code_sample}\n",
"</code_sample>\n",
"\n",
"Your task is to write a Python function `def generate_input(rng: Random) -> dict:` that generates valid inputs for the given code snippet, based on the provided information.\n",
"\"\"\"\n",
"\n",
"OPENROUTER_URL = \"https://openrouter.ai/api/v1/chat/completions\"\n",
"# Without a timeout, a stalled connection can block the loop indefinitely.\n",
"REQUEST_TIMEOUT_S = 120\n",
"\n",
"api_key = os.getenv('OPENROUTER_API_KEY')\n",
"if not api_key:\n",
"    raise RuntimeError(\"OPENROUTER_API_KEY is not set\")\n",
"\n",
"with open(\"data/codeio-pyedu-best-coverage.jsonl\", \"r\") as f_count:\n",
"    total_entries = sum(1 for _ in f_count)\n",
"\n",
"with open(\"data/codeio-pyedu-best-coverage.jsonl\", \"r\") as f_in, \\\n",
"     open(\"data/codeio-pyedu-with-input-generator.jsonl\", \"w\") as f_out, \\\n",
"     requests.Session() as session:\n",
"    for i, line in tqdm(enumerate(f_in), total=total_entries):\n",
"        entry = json.loads(line)\n",
"        payload = json.dumps({\n",
"            \"model\": \"deepseek/deepseek-chat\",\n",
"            \"messages\": [\n",
"                {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
"                {\"role\": \"user\", \"content\": USER_PROMPT.format(**entry)},\n",
"            ],\n",
"        })\n",
"\n",
"        try:\n",
"            response = session.post(\n",
"                url=OPENROUTER_URL,\n",
"                headers={\n",
"                    \"Authorization\": f\"Bearer {api_key}\",\n",
"                    \"Content-Type\": \"application/json\",\n",
"                },\n",
"                data=payload,\n",
"                timeout=REQUEST_TIMEOUT_S,\n",
"            )\n",
"            response.raise_for_status()\n",
"            full_response = response.json()[\"choices\"][0][\"message\"][\"content\"] or \"\"\n",
"        except (requests.RequestException, ValueError, KeyError, IndexError, TypeError) as e:\n",
"            # Network/HTTP failures and malformed API payloads skip this entry.\n",
"            print(f\"Error on entry {i}: {e!r}\")\n",
"            continue\n",
"\n",
"        match = re.search(r\"<function>(.*?)</function>\", full_response, re.DOTALL)\n",
"        if match is None:\n",
"            print(f\"Error on entry {i}: response has no <function>...</function> block\")\n",
"            continue\n",
"\n",
"        entry['input_generator'] = match.group(1).strip()\n",
"        f_out.write(json.dumps(entry))\n",
"        f_out.write(\"\\n\")\n",
"        # Persist each result immediately so an interrupted run keeps its progress.\n",
"        f_out.flush()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0]: {'p_meta': 'T(x)'}\n",
"[1]: {'p_meta': 'T(x)'}\n",
"[2]: {'p_meta': 'S(x)'}\n",
"[3]: {'p_meta': 'T(x)'}\n",
"[4]: {'p_meta': 'S(x)'}\n"
]
}
],
"source": [
"# Example of how to execute a generated input generator.\n",
"# SECURITY: exec() runs model-generated code with full interpreter privileges;\n",
"# only run it on reviewed output or inside a sandbox.\n",
"with open(\"data/codeio-pyedu-with-input-generator.jsonl\", \"r\") as f:\n",
"    data = json.loads(next(f))\n",
"\n",
"# One dict serves as both globals and locals: with separate dicts, helpers or\n",
"# constants defined next to generate_input would land in the locals dict, which\n",
"# generate_input (resolving names through its globals) cannot see.\n",
"namespace = {\"Random\": Random, \"random\": random}\n",
"exec(data['input_generator'], namespace)\n",
"generate_input_func = namespace['generate_input']\n",
"rng = random.Random(42)\n",
"\n",
"for i in range(5):\n",
"    random_input = generate_input_func(rng)\n",
"    print(f\"[{i}]: {random_input}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "reasoning_gym",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}