greedy coreset sampling

This commit is contained in:
Zafir Stojanovski 2025-02-22 16:15:14 +01:00
parent e9ff3a1ee2
commit e84cec26ed

View file

@ -100,6 +100,219 @@
"print(f\"There were {duplicate} out of {len(dataset)} duplicate entries\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Subsample the data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import abc\n",
"from typing import Union\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import tqdm\n",
"\n",
"\n",
"class IdentitySampler:\n",
" def run(\n",
" self, features: Union[torch.Tensor, np.ndarray]\n",
" ) -> Union[torch.Tensor, np.ndarray]:\n",
" return features\n",
"\n",
"\n",
"class BaseSampler(abc.ABC):\n",
" def __init__(self, percentage: float):\n",
" if not 0 < percentage < 1:\n",
" raise ValueError(\"Percentage value not in (0, 1).\")\n",
" self.percentage = percentage\n",
"\n",
" @abc.abstractmethod\n",
" def run(\n",
" self, features: Union[torch.Tensor, np.ndarray]\n",
" ) -> Union[torch.Tensor, np.ndarray]:\n",
" pass\n",
"\n",
" def _store_type(self, features: Union[torch.Tensor, np.ndarray]) -> None:\n",
" self.features_is_numpy = isinstance(features, np.ndarray)\n",
" if not self.features_is_numpy:\n",
" self.features_device = features.device\n",
"\n",
" def _restore_type(self, features: torch.Tensor) -> Union[torch.Tensor, np.ndarray]:\n",
" if self.features_is_numpy:\n",
" return features.cpu().numpy()\n",
" return features.to(self.features_device)\n",
"\n",
"\n",
"class GreedyCoresetSampler(BaseSampler):\n",
" def __init__(\n",
" self,\n",
" percentage: float,\n",
" device: torch.device,\n",
" dimension_to_project_features_to=128,\n",
" ):\n",
" \"\"\"Greedy Coreset sampling base class.\"\"\"\n",
" super().__init__(percentage)\n",
"\n",
" self.device = device\n",
" self.dimension_to_project_features_to = dimension_to_project_features_to\n",
"\n",
" def _reduce_features(self, features):\n",
" if features.shape[1] == self.dimension_to_project_features_to:\n",
" return features\n",
" mapper = torch.nn.Linear(\n",
" features.shape[1], self.dimension_to_project_features_to, bias=False\n",
" )\n",
" _ = mapper.to(self.device)\n",
" features = features.to(self.device)\n",
" return mapper(features)\n",
"\n",
" def run(\n",
" self, features: Union[torch.Tensor, np.ndarray]\n",
" ) -> Union[torch.Tensor, np.ndarray]:\n",
" \"\"\"Subsamples features using Greedy Coreset.\n",
"\n",
" Args:\n",
" features: [N x D]\n",
" \"\"\"\n",
" if self.percentage == 1:\n",
" return features\n",
" self._store_type(features)\n",
" if isinstance(features, np.ndarray):\n",
" features = torch.from_numpy(features)\n",
" reduced_features = self._reduce_features(features)\n",
" sample_indices = self._compute_greedy_coreset_indices(reduced_features)\n",
" features = features[sample_indices]\n",
" return self._restore_type(features)\n",
"\n",
" @staticmethod\n",
" def _compute_batchwise_differences(\n",
" matrix_a: torch.Tensor, matrix_b: torch.Tensor\n",
" ) -> torch.Tensor:\n",
" \"\"\"Computes batchwise Euclidean distances using PyTorch.\"\"\"\n",
" a_times_a = matrix_a.unsqueeze(1).bmm(matrix_a.unsqueeze(2)).reshape(-1, 1)\n",
" b_times_b = matrix_b.unsqueeze(1).bmm(matrix_b.unsqueeze(2)).reshape(1, -1)\n",
" a_times_b = matrix_a.mm(matrix_b.T)\n",
"\n",
" return (-2 * a_times_b + a_times_a + b_times_b).clamp(0, None).sqrt()\n",
"\n",
" def _compute_greedy_coreset_indices(self, features: torch.Tensor) -> np.ndarray:\n",
" \"\"\"Runs iterative greedy coreset selection.\n",
"\n",
" Args:\n",
" features: [NxD] input feature bank to sample.\n",
" \"\"\"\n",
" distance_matrix = self._compute_batchwise_differences(features, features)\n",
" coreset_anchor_distances = torch.norm(distance_matrix, dim=1)\n",
"\n",
" coreset_indices = []\n",
" num_coreset_samples = int(len(features) * self.percentage)\n",
"\n",
" for _ in range(num_coreset_samples):\n",
" select_idx = torch.argmax(coreset_anchor_distances).item()\n",
" coreset_indices.append(select_idx)\n",
"\n",
" coreset_select_distance = distance_matrix[\n",
" :, select_idx : select_idx + 1 # noqa E203\n",
" ]\n",
" coreset_anchor_distances = torch.cat(\n",
" [coreset_anchor_distances.unsqueeze(-1), coreset_select_distance], dim=1\n",
" )\n",
" coreset_anchor_distances = torch.min(coreset_anchor_distances, dim=1).values\n",
"\n",
" return np.array(coreset_indices)\n",
"\n",
"\n",
"class ApproximateGreedyCoresetSampler(GreedyCoresetSampler):\n",
" def __init__(\n",
" self,\n",
" percentage: float,\n",
" device: torch.device,\n",
" number_of_starting_points: int = 10,\n",
" dimension_to_project_features_to: int = 128,\n",
" ):\n",
" \"\"\"Approximate Greedy Coreset sampling base class.\"\"\"\n",
" self.number_of_starting_points = number_of_starting_points\n",
" super().__init__(percentage, device, dimension_to_project_features_to)\n",
"\n",
" def _compute_greedy_coreset_indices(self, features: torch.Tensor) -> np.ndarray:\n",
" \"\"\"Runs approximate iterative greedy coreset selection.\n",
"\n",
" This greedy coreset implementation does not require computation of the\n",
" full N x N distance matrix and thus requires a lot less memory, however\n",
" at the cost of increased sampling times.\n",
"\n",
" Args:\n",
" features: [NxD] input feature bank to sample.\n",
" \"\"\"\n",
" number_of_starting_points = np.clip(\n",
" self.number_of_starting_points, None, len(features)\n",
" )\n",
" start_points = np.random.choice(\n",
" len(features), number_of_starting_points, replace=False\n",
" ).tolist()\n",
"\n",
" approximate_distance_matrix = self._compute_batchwise_differences(\n",
" features, features[start_points]\n",
" )\n",
" approximate_coreset_anchor_distances = torch.mean(\n",
" approximate_distance_matrix, axis=-1\n",
" ).reshape(-1, 1)\n",
" coreset_indices = []\n",
" num_coreset_samples = int(len(features) * self.percentage)\n",
"\n",
" with torch.no_grad():\n",
" for _ in tqdm.tqdm(range(num_coreset_samples), desc=\"Subsampling...\"):\n",
" select_idx = torch.argmax(approximate_coreset_anchor_distances).item()\n",
" coreset_indices.append(select_idx)\n",
" coreset_select_distance = self._compute_batchwise_differences(\n",
" features, features[select_idx : select_idx + 1] # noqa: E203\n",
" )\n",
" approximate_coreset_anchor_distances = torch.cat(\n",
" [approximate_coreset_anchor_distances, coreset_select_distance],\n",
" dim=-1,\n",
" )\n",
" approximate_coreset_anchor_distances = torch.min(\n",
" approximate_coreset_anchor_distances, dim=1\n",
" ).values.reshape(-1, 1)\n",
"\n",
" return np.array(coreset_indices)\n",
"\n",
"\n",
"class RandomSampler(BaseSampler):\n",
" def __init__(self, percentage: float):\n",
" super().__init__(percentage)\n",
"\n",
" def run(\n",
" self, features: Union[torch.Tensor, np.ndarray]\n",
" ) -> Union[torch.Tensor, np.ndarray]:\n",
" \"\"\"Randomly samples input feature collection.\n",
"\n",
" Args:\n",
" features: [N x D]\n",
" \"\"\"\n",
" num_random_samples = int(len(features) * self.percentage)\n",
" subset_indices = np.random.choice(\n",
" len(features), num_random_samples, replace=False\n",
" )\n",
" subset_indices = np.array(subset_indices)\n",
" return features[subset_indices]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
@ -109,7 +322,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -164,101 +377,21 @@
" })\n",
" )\n",
" full_response = response.json()[\"choices\"][0][\"message\"][\"content\"]\n",
" input_generator = re.search(r\"<function>(.*?)</function>\", full_response, re.DOTALL).group(1).strip()"
" input_generator = re.search(r\"<function>(.*?)</function>\", full_response, re.DOTALL).group(1).strip()\n",
"\n",
" # local_dict = {}\n",
" # exec(input_generator, globals(), local_dict)\n",
" # generate_input_func = local_dict['generate_input']\n",
" # rng = random.Random()\n",
"\n",
" # for i in range(5):\n",
" # random_input = generate_input_func(rng)\n",
" # print(f\"[{i}]: {random_input}\")"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"In the context of Conway's Game of Life, a cellular automaton devised by John Horton Conway, consider a board with `m` by `n` cells, where each cell can be either live (1) or dead (0). The state of each cell evolves based on its neighbors according to specific rules. Given the current state of the board, what will be the state of the board after one iteration of the game?\n",
"----------------\n",
"Input:\n",
" `board` (List[List[int]]): A 2D list representing the state of the board. Each element in the list is either `0` (dead cell) or `1` (live cell).\n",
"\n",
"Output:\n",
" `return` (List[List[int]]): A 2D list representing the next state of the board after applying the rules of Conway's Game of Life. Each element in the list is either `0` (dead cell) or `1` (live cell).\n",
"----------------\n",
"# import necessary packages\n",
"from collections import Counter\n",
"\n",
"# all class and function definitions in the code file, if any\n",
"class Solution(object):\n",
" def gameOfLifeInfinite(self, live):\n",
" ctr = Counter((I, J)\n",
" for i, j in live\n",
" for I in range(i-1, i+2)\n",
" for J in range(j-1, j+2)\n",
" if I != i or J != j)\n",
"\n",
" return {ij\n",
" for ij in ctr\n",
" if ctr[ij] == 3 or ctr[ij] == 2 and ij in live}\n",
"\n",
" def gameOfLife(self, board):\n",
" live_cell = {(row, col) for row in range(len(board)) for col in range(len(board[0])) if board[row][col]}\n",
" live_cell_next = self.gameOfLifeInfinite(live_cell)\n",
" for i, row in enumerate(board):\n",
" for j in range(len(row)):\n",
" board[i][j] = int((i, j) in live_cell_next)\n",
" return board\n",
"\n",
"# main function\n",
"def main_solution(board):\n",
" # Convert the input board to a list of lists if it's not already\n",
" if not isinstance(board, list) or not all(isinstance(row, list) for row in board):\n",
" raise ValueError(\"Input board must be a list of lists\")\n",
" \n",
" # Call the gameOfLife function to get the next state of the board\n",
" solution = Solution()\n",
" next_state = solution.gameOfLife(board)\n",
" \n",
" # Return the next state of the board\n",
" return next_state\n"
]
}
],
"source": [
"print(entry[\"task_description\"])\n",
"print(\"----------------\")\n",
"print(entry[\"input_output_spec\"])\n",
"print(\"----------------\")\n",
"print(entry[\"code_sample\"])"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"def generate_input(rng: Random) -> dict:\n",
" # Generate random dimensions for the board\n",
" m = rng.randint(1, 10) # Number of rows\n",
" n = rng.randint(1, 10) # Number of columns\n",
" \n",
" # Generate the board with random 0s and 1s\n",
" board = [[rng.choice([0, 1]) for _ in range(n)] for _ in range(m)]\n",
" \n",
" return {'board': board}\n"
]
}
],
"source": [
"print(input_generator)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -273,16 +406,7 @@
]
}
],
"source": [
"local_dict = {}\n",
"exec(input_generator, globals(), local_dict)\n",
"generate_input_func = local_dict['generate_input']\n",
"rng = random.Random()\n",
"\n",
"for i in range(5):\n",
" random_input = generate_input_func(rng)\n",
" print(f\"[{i}]: {random_input}\")"
]
"source": []
},
{
"cell_type": "markdown",