{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e81636e0-fa3c-462d-ae2e-2bb35cafd544",
   "metadata": {},
   "source": [
    "## Investigating collisions in reasoning-gym datasets\n",
    "\n",
    "This notebook investigates collisions between training and validation datasets that reasoning-gym generates from different seeds, as they would be used for RL training. Here a *collision* is a validation question that also appears anywhere in the training set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "42323371-404e-4e86-b8b8-a420b4c79303",
   "metadata": {},
   "outputs": [],
   "source": [
    "import reasoning_gym"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f06e7932-6c77-4609-8a33-7c4d815841d6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of data: 15\n"
     ]
    }
   ],
   "source": [
    "# One dataset name per line; blank lines are skipped so no empty name reaches create_dataset.\n",
    "with open(\"data.txt\") as f:\n",
    "    data_names = [line.strip() for line in f if line.strip()]\n",
    "\n",
    "print(f\"Total number of data: {len(data_names)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d7a5a5bf-7428-46f5-a7f5-a46238df2543",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of examples generated per dataset and per seed.\n",
    "TOTAL = 10000\n",
    "# Seeds for the simulated training and validation splits.\n",
    "TRAIN_SEED = 1\n",
    "VAL_SEED = 2\n",
    "# Output file; the report section below reads 'collisions.txt' back.\n",
    "COLLISIONS_FILE = \"collisions.txt\"\n",
    "# (dataset name, collision count) pairs, filled by the next cell.\n",
    "collisions = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7138aced-d61a-4e2a-9935-a9b251e6d554",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# Count validation questions that occur anywhere in the training set. Comparing items\n",
    "# pairwise by index (zip) only catches duplicates that happen to share a position.\n",
    "collisions = []  # reset so re-running this cell does not duplicate results\n",
    "# \"w\" mode: a re-run replaces the file instead of appending duplicate rows.\n",
    "# Rows are written as they are computed, and the `with` block closes (and flushes)\n",
    "# the file even if a dataset fails, so partial results are kept.\n",
    "with open(COLLISIONS_FILE, \"w\") as file:\n",
    "    for name in data_names:\n",
    "        train_data = reasoning_gym.create_dataset(name, size=TOTAL, seed=TRAIN_SEED)\n",
    "        val_data = reasoning_gym.create_dataset(name, size=TOTAL, seed=VAL_SEED)\n",
    "        train_questions = {item[\"question\"] for item in train_data}\n",
    "        count = sum(item[\"question\"] in train_questions for item in val_data)\n",
    "        collisions.append((name, count))\n",
    "        file.write(f\"{name}, {count}\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86d9535a-73fd-4930-a2df-393efd73cb75",
   "metadata": {},
   "source": [
    "## Report on the generated collision data"
   ]
  },
  {
   "cell_type": "code",
"execution_count": 17, "id": "d41efd7a-3d23-4433-98d5-617b5aa66f07", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 18, "id": "e732e125-e232-4e1a-bc66-1ee37405d6ff", "metadata": {}, "outputs": [], "source": [ "with open('collisions.txt', 'r') as file:\n", " data = [line.strip().split(\",\") for line in file.readlines()]" ] }, { "cell_type": "code", "execution_count": 19, "id": "77a90340-f760-41d4-ab2a-0b14b6dbfc08", "metadata": {}, "outputs": [], "source": [ "# Clean data\n", "data = [(name, collision.strip()) for name, collision in data]" ] }, { "cell_type": "code", "execution_count": 20, "id": "36054684-470c-4ea4-8a60-25efb8218926", "metadata": {}, "outputs": [], "source": [ "df = pd.DataFrame([[name, collision] for name, collision in data], columns=[\"name\", \"collisions\"])" ] }, { "cell_type": "code", "execution_count": 21, "id": "0d2e6752-8d04-46c9-8057-3e71a39819f9", "metadata": {}, "outputs": [], "source": [ "# Change collision to int\n", "df[\"collisions\"] = df[\"collisions\"].astype(int)" ] }, { "cell_type": "code", "execution_count": 22, "id": "a2b05dc7-4319-4709-ae57-b3a6637bc66a", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| \n", " | name | \n", "collisions | \n", "
|---|---|---|
| 0 | \n", "complex_arithmetic | \n", "0 | \n", "
| 1 | \n", "intermediate_integration | \n", "12 | \n", "
| 2 | \n", "polynomial_equations | \n", "0 | \n", "
| 3 | \n", "polynomial_multiplication | \n", "0 | \n", "
| 4 | \n", "simple_equations | \n", "0 | \n", "